ctm-dqn/scripts/plot_fundamental_diagram.py

380 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import xml.etree.ElementTree as ET
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Repository root: the parent of the directory that contains this script.
PROJECT_ROOT = Path(__file__).resolve().parent.parent

# Default input/output locations, all resolved relative to the repository root.
DEFAULT_SPEED_CSV = PROJECT_ROOT.joinpath("speed.csv")
DEFAULT_ROUTE_DIR = PROJECT_ROOT.joinpath("sumo_resource")
DEFAULT_OUTPUT_DIR = PROJECT_ROOT.joinpath("results", "fundamental_diagram")

# SUMO edge id used as the representative road section.
DEFAULT_EDGE_ID = "G1523_AM7.1"

# Figure size in inches and raster resolution used for every saved plot.
PAPER_FIGSIZE = (6.6, 4.8)
PAPER_DPI = 300
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Every option takes a string value. Defaults come from the module-level
    constants, except ``--date``, which defaults to ``"all"``. Each help text
    echoes its default.

    Returns:
        argparse.Namespace with attributes ``speed_csv``, ``route_dir``,
        ``edge_id``, ``date`` and ``output_dir``.
    """
    cli = argparse.ArgumentParser(
        description="Plot q-k, v-k, and q-v fundamental diagrams from SUMO route files and speed.csv."
    )
    # (flag, default value, help text), registered in this order so --help lists them the same way.
    option_table = [
        (
            "--speed-csv",
            str(DEFAULT_SPEED_CSV),
            f"Path to speed CSV. Default: {DEFAULT_SPEED_CSV}",
        ),
        (
            "--route-dir",
            str(DEFAULT_ROUTE_DIR),
            f"Directory containing routes_YYYYMMDD.xml. Default: {DEFAULT_ROUTE_DIR}",
        ),
        (
            "--edge-id",
            DEFAULT_EDGE_ID,
            f"Representative section edge id. Default: {DEFAULT_EDGE_ID}",
        ),
        (
            "--date",
            "all",
            "Use one day in YYYY-MM-DD format or all overlapping full dates. Default: all",
        ),
        (
            "--output-dir",
            str(DEFAULT_OUTPUT_DIR),
            f"Output directory. Default: {DEFAULT_OUTPUT_DIR}",
        ),
    ]
    for flag, fallback, help_text in option_table:
        cli.add_argument(flag, default=fallback, help=help_text)
    return cli.parse_args()
def load_speed_profile(speed_csv: Path) -> pd.DataFrame:
    """Load the speed time series from a CSV file.

    Only the first two columns are used, whatever their headers are. The first
    is parsed as a timestamp and the second as a numeric speed (stored as
    ``v_kmh``). Non-numeric speeds become NaN. Rows with a missing timestamp or
    speed are then dropped. A non-empty timestamp that cannot be parsed raises.

    Args:
        speed_csv: path to the CSV. It is read as UTF-8, and a leading BOM is tolerated.

    Returns:
        DataFrame with columns ``datetime``, ``date`` ("YYYY-MM-DD"),
        ``clock_time`` ("HH:MM") and ``v_kmh``.
    """
    raw = pd.read_csv(speed_csv, encoding="utf-8-sig")
    time_name, value_name = raw.columns[0], raw.columns[1]

    profile = raw[[time_name, value_name]].copy()
    profile.columns = ["datetime", "v_kmh"]
    profile = profile.assign(
        datetime=pd.to_datetime(profile["datetime"]),
        v_kmh=pd.to_numeric(profile["v_kmh"], errors="coerce"),
    ).dropna()

    # String keys used later to align speed rows with the route-derived flow rows.
    stamps = profile["datetime"].dt
    profile = profile.assign(
        date=stamps.strftime("%Y-%m-%d"),
        clock_time=stamps.strftime("%H:%M"),
    )
    return profile[["datetime", "date", "clock_time", "v_kmh"]]
def extract_edge_flow_from_route(route_path: Path, edge_id: str) -> pd.DataFrame:
    """Count vehicles whose route uses ``edge_id``, per 30-minute departure bucket.

    The route file is parsed as SUMO-style XML. A vehicle's edge list comes from
    one of two places:

    * the top-level ``<route id=...>`` named by its ``route`` attribute, or
    * when the vehicle has no ``route`` attribute, an embedded
      ``<route edges=...>`` child element.

    Vehicles that reference an unknown route id, or whose route does not
    contain ``edge_id``, are ignored.

    The vehicle's *departure* time is used as the timestamp, not the time it
    actually reaches ``edge_id``. The resulting "section equivalent flow" is
    therefore an approximation. Departures are folded into one day with
    ``% 86400``, so times past 24 h land in that day's early buckets.

    Args:
        route_path: path to a ``routes_YYYYMMDD.xml`` file.
        edge_id: edge identifier to look for in each route's ``edges`` list.

    Returns:
        48 rows, one per 30-minute bucket of the day, with columns:

        * ``clock_time``: bucket start as "HH:MM";
        * ``veh_30min``: vehicle count;
        * ``q_vph``: the count doubled, i.e. expressed per hour.

    Raises:
        KeyError: a top-level ``<route>`` lacks ``id``, or a matching vehicle
            lacks ``depart``.
        ValueError: a matching vehicle's ``depart`` is not numeric.
    """
    root = ET.parse(route_path).getroot()
    route_edges = {
        route.attrib["id"]: route.attrib.get("edges", "").split()
        for route in root.findall("route")
    }

    buckets = []
    for vehicle in root.findall("vehicle"):
        route_id = vehicle.attrib.get("route")
        if route_id is not None:
            edges = route_edges.get(route_id)
        else:
            # SUMO also allows the route to be embedded directly inside the vehicle.
            inline_route = vehicle.find("route")
            edges = inline_route.attrib.get("edges", "").split() if inline_route is not None else None
        if not edges or edge_id not in edges:
            continue
        depart = float(vehicle.attrib["depart"]) % 86400.0
        bucket = int(depart // 1800.0)
        # Defensive range check: floating-point rounding can make `x % 86400.0`
        # equal 86400.0 for tiny negative x, which would give bucket 48.
        if 0 <= bucket < 48:
            buckets.append(bucket)

    counts = np.bincount(np.asarray(buckets, dtype=int), minlength=48)
    flow = pd.DataFrame({"bucket": np.arange(48), "veh_30min": counts})
    flow["q_vph"] = flow["veh_30min"] * 2.0  # 30-minute count -> vehicles per hour
    flow["clock_time"] = pd.to_datetime(flow["bucket"] * 1800, unit="s").dt.strftime("%H:%M")
    return flow[["clock_time", "veh_30min", "q_vph"]]
def available_route_dates(route_dir: Path) -> set[str]:
    """Return the ISO dates (YYYY-MM-DD) that have a ``routes_YYYYMMDD.xml`` file.

    Only files matching ``routes_*.xml`` are considered. A file counts only if
    its stem splits on ``_`` into exactly two parts and the second part is
    eight digits. The digits are not checked to be a valid calendar date.
    """

    def _as_iso(stem: str):
        # Returns "YYYY-MM-DD" for a well-formed stem, otherwise None.
        pieces = stem.split("_")
        if len(pieces) != 2:
            return None
        digits = pieces[1]
        if len(digits) != 8 or not digits.isdigit():
            return None
        return f"{digits[:4]}-{digits[4:6]}-{digits[6:8]}"

    candidates = (_as_iso(path.stem) for path in route_dir.glob("routes_*.xml"))
    return {iso for iso in candidates if iso is not None}
def select_dates(speed: pd.DataFrame, route_dir: Path, date_arg: str) -> list[str]:
    """Choose the calendar dates to analyse.

    When ``date_arg`` is ``"all"``, return every date (sorted ascending) that
    meets both conditions:

    * it has at least 48 rows in ``speed`` (presumably one full day of
      30-minute samples);
    * it has a route file in ``route_dir``.

    Otherwise ``date_arg`` is a single "YYYY-MM-DD" date. It must have a route
    file and at least one speed row.

    Raises:
        FileNotFoundError: the requested single date has no route file.
        ValueError: the requested single date has no speed rows, or, in "all"
            mode, no date qualifies.
    """
    per_day = speed.groupby("date").size()
    with_routes = available_route_dates(route_dir)

    if date_arg == "all":
        complete = sorted(
            day for day, n_obs in per_day.items() if n_obs >= 48 and day in with_routes
        )
        if not complete:
            raise ValueError("No overlapping full dates between speed.csv and route files.")
        return complete

    if date_arg not in with_routes:
        raise FileNotFoundError(f"Route file for {date_arg} not found under {route_dir}")
    if int(per_day.get(date_arg, 0)) == 0:
        raise ValueError(f"No speed observations found for {date_arg}")
    return [date_arg]
def build_flow_dataset(route_dir: Path, edge_id: str, dates: list[str]) -> pd.DataFrame:
    """Stack the per-bucket edge flow of every requested date into one table.

    For each "YYYY-MM-DD" in ``dates``, ``routes_YYYYMMDD.xml`` is read from
    ``route_dir`` and reduced to 30-minute counts on ``edge_id``. The result is
    sorted by timestamp and re-indexed from 0.

    Returns:
        DataFrame with columns ``datetime``, ``date``, ``clock_time``,
        ``veh_30min`` and ``q_vph``.

    Raises:
        ValueError: from ``pd.concat`` when ``dates`` is empty.
    """

    def _one_day(day: str) -> pd.DataFrame:
        # Flow table for a single date, tagged with its full timestamp.
        route_file = route_dir / f"routes_{day.replace('-', '')}.xml"
        daily = extract_edge_flow_from_route(route_file, edge_id)
        daily["date"] = day
        daily["datetime"] = pd.to_datetime(daily["date"] + " " + daily["clock_time"])
        return daily[["datetime", "date", "clock_time", "veh_30min", "q_vph"]]

    combined = pd.concat([_one_day(day) for day in dates], ignore_index=True)
    return combined.sort_values("datetime").reset_index(drop=True)
def build_merged_dataset(flow: pd.DataFrame, speed: pd.DataFrame, dates: list[str]) -> pd.DataFrame:
    """Align flow and speed rows by timestamp and derive density k = q / v.

    Speed rows are restricted to ``dates`` and inner-joined with ``flow`` on
    ``datetime``, ``date`` and ``clock_time``. If ``speed`` holds several rows
    for one timestamp, each yields its own merged row. Only rows with positive
    flow and positive speed are kept, since density is undefined or zero
    otherwise.

    Returns:
        DataFrame with columns ``datetime``, ``date``, ``clock_time``,
        ``veh_30min``, ``q_vph``, ``v_kmh`` and ``k_vpkm``.

    Raises:
        ValueError: no timestamps align, or no aligned row has both positive
            flow and positive speed. An empty table would otherwise break the
            downstream fit with an obscure error.
    """
    speed_use = speed[speed["date"].isin(dates)]
    merged = flow.merge(
        speed_use[["datetime", "date", "clock_time", "v_kmh"]],
        on=["datetime", "date", "clock_time"],
        how="inner",
    )
    if merged.empty:
        raise ValueError("No aligned flow-speed observations after merging route flow and speed data.")

    merged = merged[(merged["q_vph"] > 0) & (merged["v_kmh"] > 0)].copy()
    if merged.empty:
        raise ValueError(
            "No aligned flow-speed observations with positive flow and speed; "
            "cannot estimate density."
        )

    merged["k_vpkm"] = merged["q_vph"] / merged["v_kmh"]
    return merged[["datetime", "date", "clock_time", "veh_30min", "q_vph", "v_kmh", "k_vpkm"]]
def fit_greenshields(merged: pd.DataFrame) -> dict[str, float]:
    """Fit the Greenshields linear speed-density model ``v = vf + slope * k``.

    A least-squares straight line of ``v_kmh`` on ``k_vpkm`` gives the model
    parameters:

    * free-flow speed ``vf = intercept``;
    * jam density ``kj = -intercept / slope``;
    * critical density ``kc = kj / 2``;
    * capacity ``vf * kj / 4``.

    Args:
        merged: observations with ``k_vpkm``, ``v_kmh`` and ``q_vph`` columns.

    Returns:
        dict with ``slope``, ``intercept``, ``vf``, ``kj``, ``kc``,
        ``capacity``, ``r2`` (of the v-k line; NaN if speed is constant),
        ``observed_max_flow_veh_h`` and ``sample_count``.

    Raises:
        ValueError: fewer than two samples or all densities identical (the
            line is undetermined). Also raised when the fitted line does not
            decrease from a positive free-flow speed, because ``kj`` and the
            capacity would then be infinite or negative.
    """
    density = merged["k_vpkm"].to_numpy(dtype=float)
    speed = merged["v_kmh"].to_numpy(dtype=float)
    if density.size < 2 or density.min() == density.max():
        raise ValueError(
            "Greenshields fit needs at least two observations with distinct densities "
            f"(got {density.size} sample(s))."
        )

    coef = np.polyfit(density, speed, 1)
    slope, intercept = float(coef[0]), float(coef[1])
    if slope >= 0.0 or intercept <= 0.0:
        raise ValueError(
            "Greenshields fit is not physically meaningful: expected speed to decrease with "
            f"density from a positive free-flow speed (slope={slope:.6g}, intercept={intercept:.6g})."
        )

    vf = intercept
    kj = -intercept / slope
    kc = kj / 2.0
    capacity = vf * kj / 4.0

    pred_speed = np.polyval(coef, density)
    ss_res = float(np.sum((speed - pred_speed) ** 2))
    ss_tot = float(np.sum((speed - speed.mean()) ** 2))
    r2 = 1.0 - ss_res / ss_tot if ss_tot else float("nan")

    return {
        "slope": slope,
        "intercept": intercept,
        "vf": vf,
        "kj": kj,
        "kc": kc,
        "capacity": capacity,
        "r2": r2,
        "observed_max_flow_veh_h": float(merged["q_vph"].max()),
        "sample_count": int(len(merged)),
    }
def configure_paper_style():
    """Set global matplotlib rcParams for publication-style figures.

    Text uses the serif family with Chinese fonts listed first, because the plot
    labels are Chinese. Sizes are compact.
    """
    font_choices = {
        "font.family": "serif",
        "font.serif": ["SimSun", "STSong", "Times New Roman", "DejaVu Serif"],
        "font.sans-serif": ["SimHei", "Microsoft YaHei", "DejaVu Sans"],
        # Draw negative signs as an ASCII hyphen. This is a common workaround
        # for CJK fonts that lack the Unicode minus glyph.
        "axes.unicode_minus": False,
    }
    size_choices = {
        "font.size": 11,
        "axes.labelsize": 11,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "legend.fontsize": 9,
        "figure.titlesize": 12,
    }
    plt.rcParams.update({**font_choices, **size_choices})
def style_axis(ax):
    """Apply the shared look to an Axes.

    Adds a light solid grid, short inward-pointing ticks, and uniform black
    0.9-pt frame lines on every spine.
    """
    grid_style = {"color": "#d9d9d9", "linewidth": 0.6, "linestyle": "-", "alpha": 0.9}
    ax.grid(True, **grid_style)
    ax.tick_params(direction="in", length=4, width=0.8)
    for frame_edge in ax.spines.values():
        frame_edge.set_linewidth(0.9)
        frame_edge.set_color("black")
def build_curve(fit: dict[str, float]) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Sample the fitted Greenshields relations on a 300-point density grid.

    Density runs evenly from 0 to ``fit["kj"]``. Speed is
    ``fit["vf"] + fit["slope"] * k`` and flow is ``k * v``.

    Returns:
        (density, speed, flow) arrays, each of length 300.
    """
    k = np.linspace(0.0, fit["kj"], 300)
    v = fit["vf"] + fit["slope"] * k
    return k, v, k * v
def save_flow_density_plot(
    merged: pd.DataFrame,
    fit: dict[str, float],
    edge_id: str,
    date_label: str,
    output_png: Path,
):
    """Save the flow-density (q-k) diagram with the fitted Greenshields parabola.

    Elements drawn:

    * observations as hollow markers;
    * the fitted curve as a solid line;
    * dashed reference lines at the critical density ``kc`` and the capacity ``C``;
    * an annotated marker at the point where those two lines meet.

    Args:
        merged: observations with ``k_vpkm`` and ``q_vph`` columns.
        fit: fitted parameters. ``kj``, ``vf`` and ``slope`` feed the curve;
            ``kc`` and ``capacity`` feed the reference lines and annotation.
        edge_id: not currently drawn on the figure; kept so the call signature
            stays stable.
        date_label: not currently drawn on the figure; kept for the same reason.
        output_png: destination image path.
    """
    density_grid, _, flow_grid = build_curve(fit)
    fig, ax = plt.subplots(figsize=PAPER_FIGSIZE)
    try:
        ax.scatter(
            merged["k_vpkm"],
            merged["q_vph"],
            s=18,
            facecolors="white",
            edgecolors="#1f4e79",
            linewidths=0.8,
            alpha=0.95,
            label="观测值",
        )
        ax.plot(density_grid, flow_grid, color="#111111", linewidth=1.6, label="拟合曲线")
        ax.axvline(fit["kc"], color="#666666", linestyle="--", linewidth=1.0)
        ax.axhline(fit["capacity"], color="#666666", linestyle="--", linewidth=1.0)
        ax.scatter([fit["kc"]], [fit["capacity"]], color="#aa3a38", s=36, zorder=5)
        ax.annotate(
            f"$C$={fit['capacity']:.1f} veh/h\n$k_c$={fit['kc']:.1f} veh/km",
            xy=(fit["kc"], fit["capacity"]),
            xytext=(10, -18),
            textcoords="offset points",
            fontsize=8.8,
            bbox={"boxstyle": "square,pad=0.18", "facecolor": "white", "edgecolor": "#999999", "linewidth": 0.6},
        )
        # Keep symbol and unit visually separated: "k (veh/km)", not "kveh/km".
        ax.set_xlabel("交通密度 k (veh/km)")
        ax.set_ylabel("交通流率 q (veh/h)")
        style_axis(ax)
        ax.legend(
            loc="upper left",
            frameon=True,
            fancybox=False,
            edgecolor="#666666",
            borderpad=0.25,
            handlelength=1.8,
            handletextpad=0.5,
            labelspacing=0.3,
        )
        fig.tight_layout()
        fig.savefig(output_png, dpi=PAPER_DPI, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
def save_speed_density_plot(
    merged: pd.DataFrame,
    fit: dict[str, float],
    output_png: Path,
):
    """Save the speed-density (v-k) diagram with the fitted Greenshields line.

    Elements drawn:

    * observations as hollow markers;
    * the fitted line, from ``vf`` at zero density down to zero speed at ``kj``;
    * a dashed vertical line at ``kc``;
    * a dotted horizontal line at ``vf``.

    Args:
        merged: observations with ``k_vpkm`` and ``v_kmh`` columns.
        fit: fitted parameters. ``kj``, ``vf`` and ``slope`` feed the line;
            ``kc`` and ``vf`` feed the reference lines.
        output_png: destination image path.
    """
    density_grid, speed_grid, _ = build_curve(fit)
    fig, ax = plt.subplots(figsize=PAPER_FIGSIZE)
    try:
        ax.scatter(
            merged["k_vpkm"],
            merged["v_kmh"],
            s=18,
            facecolors="white",
            edgecolors="#2d6a4f",
            linewidths=0.8,
            alpha=0.95,
            label="观测值",
        )
        ax.plot(density_grid, speed_grid, color="#111111", linewidth=1.6, label="拟合曲线")
        ax.axvline(fit["kc"], color="#666666", linestyle="--", linewidth=1.0)
        ax.axhline(fit["vf"], color="#666666", linestyle=":", linewidth=1.0)
        # Keep symbol and unit visually separated: "k (veh/km)", not "kveh/km".
        ax.set_xlabel("交通密度 k (veh/km)")
        ax.set_ylabel("平均速度 v (km/h)")
        style_axis(ax)
        ax.legend(
            loc="upper right",
            frameon=True,
            fancybox=False,
            edgecolor="#666666",
            borderpad=0.25,
            handlelength=1.8,
            handletextpad=0.5,
            labelspacing=0.3,
        )
        fig.tight_layout()
        fig.savefig(output_png, dpi=PAPER_DPI, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
def save_flow_speed_plot(
    merged: pd.DataFrame,
    fit: dict[str, float],
    output_png: Path,
):
    """Save the flow-speed (q-v) diagram with the fitted Greenshields curve.

    Observations are drawn as hollow markers. The fitted (v, q) pairs from the
    density grid are re-ordered by ascending speed, so the curve is drawn as a
    single left-to-right line.

    Args:
        merged: observations with ``v_kmh`` and ``q_vph`` columns.
        fit: fitted parameters; ``kj``, ``vf`` and ``slope`` feed the curve.
        output_png: destination image path.
    """
    _, speed_grid, flow_grid = build_curve(fit)
    order = np.argsort(speed_grid)
    fig, ax = plt.subplots(figsize=PAPER_FIGSIZE)
    try:
        ax.scatter(
            merged["v_kmh"],
            merged["q_vph"],
            s=18,
            facecolors="white",
            edgecolors="#b26a00",
            linewidths=0.8,
            alpha=0.95,
            label="观测值",
        )
        ax.plot(speed_grid[order], flow_grid[order], color="#111111", linewidth=1.6, label="拟合曲线")
        # Keep symbol and unit visually separated: "v (km/h)", not "vkm/h".
        ax.set_xlabel("平均速度 v (km/h)")
        ax.set_ylabel("交通流率 q (veh/h)")
        style_axis(ax)
        ax.legend(
            loc="upper left",
            frameon=True,
            fancybox=False,
            edgecolor="#666666",
            borderpad=0.25,
            handlelength=1.8,
            handletextpad=0.5,
            labelspacing=0.3,
        )
        fig.tight_layout()
        fig.savefig(output_png, dpi=PAPER_DPI, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
def build_tag(dates: list[str]) -> str:
    """Return a filename-friendly label for the selected dates.

    A single date is returned unchanged. Several dates yield
    ``"<first>_to_<last>"``, which assumes the list is sorted, as
    ``select_dates`` returns it. An empty list raises IndexError.
    """
    first, last = dates[0], dates[-1]
    return first if len(dates) == 1 else f"{first}_to_{last}"
def main():
    """Run the full pipeline end to end.

    Steps:

    1. Parse arguments and create the output directory.
    2. Load speed data, pick the dates, and build the aligned flow/speed/density table.
    3. Fit Greenshields.
    4. Write the data CSV, the three diagrams and a ``key=value`` metrics file.
    5. Print the chosen edge, the dates and every output path.
    """
    args = parse_args()
    edge = args.edge_id
    route_dir = Path(args.route_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    speed = load_speed_profile(Path(args.speed_csv))
    dates = select_dates(speed, route_dir, args.date)
    merged = build_merged_dataset(build_flow_dataset(route_dir, edge, dates), speed, dates)
    fit = fit_greenshields(merged)
    configure_paper_style()

    tag = build_tag(dates)

    def _artifact(kind: str, extension: str) -> Path:
        # Every output shares the "fundamental_diagram_<kind>_<edge>_<tag>" naming scheme.
        return output_dir / f"fundamental_diagram_{kind}_{edge}_{tag}.{extension}"

    merged_csv = _artifact("data", "csv")
    metrics_txt = _artifact("metrics", "txt")
    qk_png, vk_png, qv_png = (_artifact(kind, "png") for kind in ("qk", "vk", "qv"))

    merged.to_csv(merged_csv, index=False, encoding="utf-8-sig")
    save_flow_density_plot(merged, fit, edge, tag, qk_png)
    save_speed_density_plot(merged, fit, vk_png)
    save_flow_speed_plot(merged, fit, qv_png)

    # (metrics-file key, key in the fit dict) for values written with 6 decimals.
    float_metrics = (
        ("vf_kmh", "vf"),
        ("kj_veh_per_km", "kj"),
        ("kc_veh_per_km", "kc"),
        ("capacity_veh_per_h", "capacity"),
        ("observed_max_flow_veh_per_h", "observed_max_flow_veh_h"),
        ("vk_r2", "r2"),
    )
    report = [f"edge_id={edge}", f"period={tag}", f"samples={fit['sample_count']}"]
    report += [f"{name}={fit[key]:.6f}" for name, key in float_metrics]
    metrics_txt.write_text("\n".join(report), encoding="utf-8")

    print(f"Using edge: {edge}")
    print(f"Using dates: {', '.join(dates)}")
    saved = (
        ("merged data", merged_csv),
        ("metrics", metrics_txt),
        ("q-k plot", qk_png),
        ("v-k plot", vk_png),
        ("q-v plot", qv_png),
    )
    for label, path in saved:
        print(f"Saved {label}: {path}")
# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()