From c126e4a9f5654cba9244062ef6cf7863cd5d3dab Mon Sep 17 00:00:00 2001 From: Maple-YZ Date: Thu, 16 Apr 2026 11:01:53 +0800 Subject: [PATCH] =?UTF-8?q?=E7=BB=98=E5=88=B6=E8=AE=AD=E7=BB=83=E8=BF=87?= =?UTF-8?q?=E7=A8=8B=E7=83=AD=E5=8A=9B=E5=9B=BE=E6=97=B6=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=E4=BC=A0=E6=84=9F=E5=99=A8=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- utils/episode_artifacts.py | 203 +++++++++++++++++++++++++++++++++---- 1 file changed, 181 insertions(+), 22 deletions(-) diff --git a/utils/episode_artifacts.py b/utils/episode_artifacts.py index e2140d2..390b8c3 100644 --- a/utils/episode_artifacts.py +++ b/utils/episode_artifacts.py @@ -3,6 +3,7 @@ import csv import json import os from typing import Dict, Iterable, List, Optional, Sequence, Tuple +import xml.etree.ElementTree as ET import matplotlib import numpy as np @@ -80,32 +81,173 @@ def _write_csv(path: str, rows: Iterable[Dict], fieldnames: Sequence[str]): writer.writerows(rows) -def _plot_episode_heatmap(path: str, edge_rows: Sequence[Dict], edge_ids: Sequence[str], title_prefix: str): - if not edge_rows or not edge_ids: +def _lane_to_edge_id(lane_id: str) -> str: + if "_" not in lane_id: + return lane_id + return lane_id.rsplit("_", 1)[0] + + +def _parse_detector_layout(additional_path: str) -> List[Dict]: + if not os.path.isfile(additional_path): + return [] + + try: + root = ET.parse(additional_path).getroot() + except ET.ParseError: + return [] + + cells: List[Dict] = [] + current_key = None + current_cell: Optional[Dict] = None + edge_pos_counts: Dict[str, int] = {} + + for elem in root.findall("inductionLoop"): + det_id = elem.get("id") + lane_id = elem.get("lane") + pos_raw = elem.get("pos") + if not det_id or not lane_id or pos_raw is None: + continue + + edge_id = _lane_to_edge_id(lane_id) + position_m = _safe_float(pos_raw) + key = (edge_id, round(position_m, 3)) + + if key != current_key: + pos_index = edge_pos_counts.get(edge_id, 0) + edge_pos_counts[edge_id] = pos_index + 1 + current_cell = { + "cell_id": f"{edge_id}@{pos_index}", + "edge_id": edge_id, + "pos_index": pos_index, + "position_m": position_m, + "detector_ids": [], + } + cells.append(current_cell) + current_key = key + + current_cell["detector_ids"].append(det_id) + + return cells + + +def _build_detector_rows_from_xml(log_dir: str, episode: int) -> List[Dict]: + metrics_dir = os.path.join(log_dir, "sumo_metrics") + suffix = f"ep{episode:04d}" + additional_path = os.path.join(metrics_dir, f"runtime_metrics_il_{suffix}.add.xml") + metrics_path = os.path.join(metrics_dir, f"metrics_il_output_{suffix}.xml") + cells = _parse_detector_layout(additional_path) + if not cells or not os.path.isfile(metrics_path): + return [] + + try: + root = ET.parse(metrics_path).getroot() + except ET.ParseError: + return [] + + interval_map: Dict[Tuple[float, float], Dict[str, Dict[str, float]]] = {} + for elem in root.findall("interval"): + det_id = elem.get("id") + if not det_id: + continue + + begin = _safe_float(elem.get("begin")) + end = _safe_float(elem.get("end")) + if not np.isfinite(begin) or not np.isfinite(end): + continue + + interval_key = (begin, end) + interval_values = interval_map.setdefault(interval_key, {}) + interval_values[det_id] = { + "speed_ms": _safe_float(elem.get("speed")), + "flow_vehph": _safe_float(elem.get("flow"), default=0.0), + "n_veh_contrib": _safe_float(elem.get("nVehContrib"), default=0.0), + "occupancy": _safe_float(elem.get("occupancy"), default=0.0), + } + + detector_rows: List[Dict] = [] + sorted_intervals = sorted(interval_map.keys()) + for step_idx, interval_key in enumerate(sorted_intervals, start=1): + interval_values = interval_map[interval_key] + for cell_order, cell in enumerate(cells): + speed_weight_sum = 0.0 + speed_weight_total = 0.0 + density_sum = 0.0 + + for det_id in cell["detector_ids"]: + lane_metrics = interval_values.get(det_id) + if not lane_metrics: + continue + + speed_ms = lane_metrics["speed_ms"] + flow_vehph = max(lane_metrics["flow_vehph"], 0.0) + n_veh_contrib = max(lane_metrics["n_veh_contrib"], 0.0) + + if np.isfinite(speed_ms) and speed_ms > 0.0: + speed_weight = n_veh_contrib if n_veh_contrib > 0.0 else flow_vehph + if speed_weight <= 0.0: + speed_weight = 1.0 + speed_weight_sum += speed_ms * 3.6 * speed_weight + speed_weight_total += speed_weight + density_sum += flow_vehph / max(speed_ms * 3.6, 1e-6) + + measured_speed_kmh = ( + speed_weight_sum / speed_weight_total + if speed_weight_total > 0.0 + else np.nan + ) + + detector_rows.append( + { + "episode": episode, + "step": step_idx, + "cell_order": cell_order, + "cell_id": cell["cell_id"], + "edge_id": cell["edge_id"], + "pos_index": int(cell["pos_index"]), + "position_m": float(cell["position_m"]), + "measured_speed_kmh": float(measured_speed_kmh) if np.isfinite(measured_speed_kmh) else np.nan, + "density_vehpkm": float(density_sum), + } + ) + + return detector_rows + + +def _plot_episode_heatmap(path: str, detector_rows: Sequence[Dict], title_prefix: str): + if not detector_rows: return - step_values = sorted({int(row["step"]) for row in edge_rows}) - num_edges = len(edge_ids) + step_values = sorted({int(row["step"]) for row in detector_rows}) + ordered_cells = [] + seen_cells = set() + for row in sorted( + detector_rows, + key=lambda item: (int(item["cell_order"]), str(item["cell_id"])), + ): + cell_id = str(row["cell_id"]) + if cell_id in seen_cells: + continue + seen_cells.add(cell_id) + ordered_cells.append(cell_id) + + num_cells = len(ordered_cells) num_steps = len(step_values) step_to_col = {step: idx for idx, step in enumerate(step_values)} - edge_to_row = {edge_id: idx for idx, edge_id in enumerate(edge_ids)} + cell_to_row = {cell_id: idx for idx, cell_id in enumerate(ordered_cells)} - action_grid = np.full((num_edges, num_steps), np.nan, dtype=np.float32) - speed_grid = np.full((num_edges, num_steps), np.nan, dtype=np.float32) - occupancy_grid = np.full((num_edges, num_steps), np.nan, dtype=np.float32) + speed_grid = np.full((num_cells, num_steps), np.nan, dtype=np.float32) + density_grid = np.full((num_cells, num_steps), np.nan, dtype=np.float32) - for row in edge_rows: - row_idx = edge_to_row[row["edge_id"]] + for row in detector_rows: + row_idx = cell_to_row[str(row["cell_id"])] col_idx = step_to_col[int(row["step"])] - action_grid[row_idx, col_idx] = _safe_float(row["action_speed_kmh"]) speed_grid[row_idx, col_idx] = _safe_float(row["measured_speed_kmh"]) - occupancy_grid[row_idx, col_idx] = _safe_float(row["occupancy"]) + density_grid[row_idx, col_idx] = _safe_float(row["density_vehpkm"]) - fig, axes = plt.subplots(1, 3, figsize=(18, 7), sharex=True, sharey=True) + fig, axes = plt.subplots(1, 2, figsize=(18, 7), sharex=True, sharey=True) plots = [ - (action_grid, "viridis", "Applied VSL (km/h)"), (speed_grid, "RdYlGn", "Measured Speed (km/h)"), - (occupancy_grid, "magma", "Occupancy (%)"), + (density_grid, "YlOrRd", "Density (veh/km)"), ] for ax, (grid, cmap, title) in zip(axes, plots): @@ -119,9 +261,11 @@ def _plot_episode_heatmap(path: str, edge_rows: Sequence[Dict], edge_ids: Sequen ) ax.set_title(f"{title_prefix} {title}") ax.set_xlabel("Decision Step") - ax.set_ylabel("Controlled Edge") - ax.set_yticks(np.arange(num_edges)) - ax.set_yticklabels(edge_ids) + ax.set_ylabel("Detector Cell") + tick_step = max(num_cells // 12, 1) + tick_idx = np.arange(0, num_cells, tick_step) + ax.set_yticks(tick_idx) + ax.set_yticklabels([ordered_cells[idx] for idx in tick_idx], fontsize=8) plt.colorbar(image, ax=ax, fraction=0.046, pad=0.04) plt.tight_layout() @@ -137,15 +281,29 @@ def _write_summary(path: str, summary: Optional[Dict]): json.dump(summary, f, ensure_ascii=False, indent=2) -def _save_episode_bundle(bundle_dir: str, episode: int, step_rows: Sequence[Dict], edge_rows: Sequence[Dict], edge_ids: Sequence[str], summary: Optional[Dict]): +def _save_episode_bundle( + bundle_dir: str, + episode: int, + step_rows: Sequence[Dict], + edge_rows: Sequence[Dict], + edge_ids: Sequence[str], + detector_rows: Sequence[Dict], + summary: Optional[Dict], +): step_csv_path = os.path.join(bundle_dir, "step_metrics.csv") edge_csv_path = os.path.join(bundle_dir, "edge_metrics.csv") + detector_csv_path = os.path.join(bundle_dir, "detector_metrics.csv") heatmap_path = os.path.join(bundle_dir, "episode_heatmap.png") summary_path = os.path.join(bundle_dir, "summary.json") _write_csv(step_csv_path, step_rows, fieldnames=list(step_rows[0].keys()) if step_rows else ["episode", "step"]) _write_csv(edge_csv_path, edge_rows, fieldnames=list(edge_rows[0].keys()) if edge_rows else ["episode", "step", "edge_index"]) - _plot_episode_heatmap(heatmap_path, edge_rows, edge_ids, title_prefix=f"Episode {episode}") + _write_csv( + detector_csv_path, + detector_rows, + fieldnames=list(detector_rows[0].keys()) if detector_rows else ["episode", "step", "cell_id"], + ) + _plot_episode_heatmap(heatmap_path, detector_rows, title_prefix=f"Episode {episode}") _write_summary(summary_path, summary) @@ -161,6 +319,7 @@ def save_training_episode_artifacts( return step_rows, edge_rows, edge_ids = _normalize_step_rows(episode, episode_metrics) + detector_rows = _build_detector_rows_from_xml(log_dir, episode) if control_edges and len(control_edges) >= len(edge_ids): edge_ids = list(control_edges[: len(edge_ids)]) for row in edge_rows: @@ -168,8 +327,8 @@ def save_training_episode_artifacts( artifacts_root = os.path.join(log_dir, "episode_artifacts") latest_dir = os.path.join(artifacts_root, "latest") - _save_episode_bundle(latest_dir, episode, step_rows, edge_rows, edge_ids, summary) + _save_episode_bundle(latest_dir, episode, step_rows, edge_rows, edge_ids, detector_rows, summary) if snapshot_interval > 0 and episode % snapshot_interval == 0: snapshot_dir = os.path.join(artifacts_root, "snapshots", f"episode_{episode:04d}") - _save_episode_bundle(snapshot_dir, episode, step_rows, edge_rows, edge_ids, summary) + _save_episode_bundle(snapshot_dir, episode, step_rows, edge_rows, edge_ids, detector_rows, summary)