"""Utilities for persisting per-episode training artifacts.

For each episode a bundle (step/edge/detector CSVs, an optional heatmap PNG and
an optional JSON summary) is written to ``<log_dir>/episode_artifacts/latest``;
every ``snapshot_interval`` episodes the same bundle is also written to
``<log_dir>/episode_artifacts/snapshots/episode_NNNN``.
"""

import csv
import json
import os
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import xml.etree.ElementTree as ET

import matplotlib
import numpy as np

# Select the non-interactive backend before the plotting helpers below are imported.
matplotlib.use("Agg")

from envs.reward_system import REWARD_COMPONENT_COLUMNS  # noqa: E402
from utils.heatmap_plotting import (  # noqa: E402
    build_action_panel,
    build_density_panel,
    build_speed_panel,
    save_heatmap_panels,
)

# Factor converting a speed in m/s to km/h.
_MS_TO_KMH = 3.6


def _safe_float(value, default: float = float("nan")) -> float:
    """Return ``float(value)``, or *default* when the conversion raises TypeError/ValueError."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _ensure_parent_dir(path: str) -> None:
    """Create the directory containing *path* if *path* has a directory component."""
    parent = os.path.dirname(path)
    # os.makedirs("") raises, so a bare filename (current directory) needs no mkdir.
    if parent:
        os.makedirs(parent, exist_ok=True)


def _value_at(values: Sequence, index: int):
    """Return ``values[index]``, or NaN when *index* is past the end of *values*."""
    return values[index] if index < len(values) else np.nan


def _normalize_step_rows(
    episode: int, episode_metrics: Sequence[Dict]
) -> Tuple[List[Dict], List[Dict], List[str]]:
    """Flatten per-step info dicts into step-level rows and per-edge rows.

    Args:
        episode: Episode number stamped onto every row.
        episode_metrics: One info dict per decision step. Missing float fields
            become NaN; missing count fields become 0. A missing ``"step"`` key
            falls back to the 1-based position in the sequence.

    Returns:
        ``(step_rows, edge_rows, edge_ids)``. ``edge_ids`` holds placeholder ids
        (``edge_00``, ``edge_01``, ...) covering the largest number of edges
        reported by any single step.
    """
    step_rows: List[Dict] = []
    edge_rows: List[Dict] = []
    edge_ids: List[str] = []
    for step_idx, info in enumerate(episode_metrics, start=1):
        step_value = int(info.get("step", step_idx))
        step_row = {
            "episode": episode,
            "step": step_value,
            "sim_time": _safe_float(info.get("sim_time")),
            "reward": _safe_float(info.get("reward")),
            "throughput": _safe_float(info.get("throughput")),
            "arrived_count": int(info.get("arrived_count", 0)),
            "departed_count": int(info.get("departed_count", 0)),
            "mean_speed_kmh": _safe_float(info.get("mean_speed_kmh")),
            "speed_variance_norm": _safe_float(info.get("speed_variance_norm")),
            "mean_occupancy": _safe_float(info.get("mean_occupancy")),
            "density": _safe_float(info.get("density")),
            "num_vehicles": int(info.get("num_vehicles", 0)),
            "num_stops": _safe_float(info.get("num_stops", 0.0), default=0.0),
        }
        for column in REWARD_COMPONENT_COLUMNS:
            step_row[column] = _safe_float(info.get(column))
        step_rows.append(step_row)

        action_speeds = list(info.get("edge_speeds_kmh", []))
        measured_speeds_ms = list(info.get("edge_speeds_ms", []))
        occupancies = list(info.get("edge_occupancies", []))
        action_applied_mask = list(info.get("action_applied_mask", []))
        edge_count = max(len(action_speeds), len(measured_speeds_ms), len(occupancies), len(action_applied_mask))
        # Grow the id list whenever a step reports more edges than any earlier step;
        # sizing it only from the first step made later, wider steps raise IndexError.
        if edge_count > len(edge_ids):
            edge_ids.extend(f"edge_{idx:02d}" for idx in range(len(edge_ids), edge_count))

        for edge_idx in range(edge_count):
            edge_rows.append(
                {
                    "episode": episode,
                    "step": step_value,
                    "edge_index": edge_idx,
                    "edge_id": edge_ids[edge_idx],
                    "action_speed_kmh": _safe_float(_value_at(action_speeds, edge_idx)),
                    # Edges without a mask entry are treated as having the action applied.
                    "action_applied": (
                        bool(action_applied_mask[edge_idx]) if edge_idx < len(action_applied_mask) else True
                    ),
                    # Convert before scaling so a None/non-numeric entry yields NaN, not TypeError.
                    "measured_speed_kmh": _safe_float(_value_at(measured_speeds_ms, edge_idx)) * _MS_TO_KMH,
                    "occupancy": _safe_float(_value_at(occupancies, edge_idx)),
                }
            )
    return step_rows, edge_rows, edge_ids


def _write_csv(path: str, rows: Iterable[Dict], fieldnames: Sequence[str]):
    """Write *rows* to *path* as a UTF-8 CSV with a header row, overwriting any existing file.

    ``csv.DictWriter`` raises ``ValueError`` if a row contains a key missing from *fieldnames*.
    """
    _ensure_parent_dir(path)
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def _lane_to_edge_id(lane_id: str) -> str:
    """Drop the trailing ``_<suffix>`` from *lane_id* (``"e1_0"`` -> ``"e1"``); ids without ``_`` are returned as-is."""
    edge_id, separator, _ = lane_id.rpartition("_")
    return edge_id if separator else lane_id


def _parse_detector_layout(additional_path: str) -> List[Dict]:
    """Group the ``inductionLoop`` elements of an additional-file into detector cells.

    Detectors on the same edge at the same position (rounded to 1 mm) form one
    cell, regardless of their order in the file. ``pos_index`` numbers the cells
    of each edge in order of first appearance.

    Returns:
        Cell dicts with ``cell_id`` (``"<edge>@<pos_index>"``), ``edge_id``,
        ``pos_index``, ``position_m`` and ``detector_ids``, in first-appearance
        order. ``[]`` if the file is missing or is not well-formed XML.
        Detectors lacking an id, a lane, or a finite position are skipped.
    """
    if not os.path.isfile(additional_path):
        return []
    try:
        root = ET.parse(additional_path).getroot()
    except ET.ParseError:
        return []

    cells: List[Dict] = []
    cell_by_key: Dict[Tuple[str, float], Dict] = {}
    edge_pos_counts: Dict[str, int] = {}
    for elem in root.findall("inductionLoop"):
        det_id = elem.get("id")
        lane_id = elem.get("lane")
        pos_raw = elem.get("pos")
        if not det_id or not lane_id or pos_raw is None:
            continue
        position_m = _safe_float(pos_raw)
        # NaN never compares equal, so an unparsable position could not be grouped.
        if not np.isfinite(position_m):
            continue
        edge_id = _lane_to_edge_id(lane_id)
        key = (edge_id, round(position_m, 3))
        cell = cell_by_key.get(key)
        if cell is None:
            pos_index = edge_pos_counts.get(edge_id, 0)
            edge_pos_counts[edge_id] = pos_index + 1
            cell = {
                "cell_id": f"{edge_id}@{pos_index}",
                "edge_id": edge_id,
                "pos_index": pos_index,
                "position_m": position_m,
                "detector_ids": [],
            }
            cell_by_key[key] = cell
            cells.append(cell)
        cell["detector_ids"].append(det_id)
    return cells


def _aggregate_cell(
    detector_ids: Sequence[str], interval_values: Dict[str, Dict[str, float]]
) -> Tuple[float, float]:
    """Combine one interval's readings for a cell's detectors into ``(speed_kmh, density_vehpkm)``.

    Only detectors with a finite, positive speed contribute. Speed is their
    weighted mean (weight = ``nVehContrib``, else flow, else 1). Density is the
    sum of flow (veh/h) divided by speed (km/h). Speed is NaN when no detector
    contributes; density is then 0.
    """
    speed_weight_sum = 0.0
    speed_weight_total = 0.0
    density_sum = 0.0
    for det_id in detector_ids:
        lane_metrics = interval_values.get(det_id)
        if not lane_metrics:
            continue
        speed_ms = lane_metrics["speed_ms"]
        if not (np.isfinite(speed_ms) and speed_ms > 0.0):
            continue
        flow_vehph = max(lane_metrics["flow_vehph"], 0.0)
        n_veh_contrib = max(lane_metrics["n_veh_contrib"], 0.0)
        speed_kmh = speed_ms * _MS_TO_KMH
        weight = n_veh_contrib if n_veh_contrib > 0.0 else flow_vehph
        if weight <= 0.0:
            weight = 1.0
        speed_weight_sum += speed_kmh * weight
        speed_weight_total += weight
        density_sum += flow_vehph / max(speed_kmh, 1e-6)
    mean_speed_kmh = speed_weight_sum / speed_weight_total if speed_weight_total > 0.0 else np.nan
    return mean_speed_kmh, density_sum


def _build_detector_rows_from_xml(log_dir: str, episode: int) -> List[Dict]:
    """Build per-interval, per-cell detector rows from the episode's induction-loop XML files.

    Reads the layout from ``<log_dir>/sumo_metrics/runtime_metrics_il_epNNNN.add.xml``
    and the measurements from ``<log_dir>/sumo_metrics/metrics_il_output_epNNNN.xml``.
    Each distinct ``(begin, end)`` interval becomes one step, numbered from 1 in
    sorted order. Returns ``[]`` when there are no cells, or when the output file
    is missing or not well-formed XML.
    """
    metrics_dir = os.path.join(log_dir, "sumo_metrics")
    suffix = f"ep{episode:04d}"
    additional_path = os.path.join(metrics_dir, f"runtime_metrics_il_{suffix}.add.xml")
    metrics_path = os.path.join(metrics_dir, f"metrics_il_output_{suffix}.xml")

    cells = _parse_detector_layout(additional_path)
    if not cells or not os.path.isfile(metrics_path):
        return []
    try:
        root = ET.parse(metrics_path).getroot()
    except ET.ParseError:
        return []

    # (begin, end) -> detector id -> readings parsed from that <interval> element.
    interval_map: Dict[Tuple[float, float], Dict[str, Dict[str, float]]] = {}
    for elem in root.findall("interval"):
        det_id = elem.get("id")
        if not det_id:
            continue
        begin = _safe_float(elem.get("begin"))
        end = _safe_float(elem.get("end"))
        if not np.isfinite(begin) or not np.isfinite(end):
            continue
        interval_map.setdefault((begin, end), {})[det_id] = {
            "speed_ms": _safe_float(elem.get("speed")),
            "flow_vehph": _safe_float(elem.get("flow"), default=0.0),
            "n_veh_contrib": _safe_float(elem.get("nVehContrib"), default=0.0),
            "occupancy": _safe_float(elem.get("occupancy"), default=0.0),
        }

    detector_rows: List[Dict] = []
    for step_idx, interval_key in enumerate(sorted(interval_map), start=1):
        interval_values = interval_map[interval_key]
        for cell_order, cell in enumerate(cells):
            measured_speed_kmh, density_vehpkm = _aggregate_cell(cell["detector_ids"], interval_values)
            detector_rows.append(
                {
                    "episode": episode,
                    "step": step_idx,
                    "cell_order": cell_order,
                    "cell_id": cell["cell_id"],
                    "edge_id": cell["edge_id"],
                    "pos_index": int(cell["pos_index"]),
                    "position_m": float(cell["position_m"]),
                    "measured_speed_kmh": float(measured_speed_kmh) if np.isfinite(measured_speed_kmh) else np.nan,
                    "density_vehpkm": float(density_vehpkm),
                }
            )
    return detector_rows


def _build_action_grid(
    edge_rows: Sequence[Dict], edge_ids: Sequence[str], step_to_col: Dict[int, int]
) -> np.ndarray:
    """Return an (edges x steps) float32 grid of applied action speeds; NaN where not applied or missing."""
    edge_to_row = {edge_id: idx for idx, edge_id in enumerate(edge_ids)}
    grid = np.full((len(edge_ids), len(step_to_col)), np.nan, dtype=np.float32)
    for row in edge_rows:
        row_idx = edge_to_row.get(str(row["edge_id"]))
        if row_idx is None or not bool(row.get("action_applied", True)):
            continue
        grid[row_idx, step_to_col[int(row["step"])]] = _safe_float(row["action_speed_kmh"])
    return grid


def _build_detector_grids(
    detector_rows: Sequence[Dict], step_to_col: Dict[int, int]
) -> Tuple[List[str], np.ndarray, np.ndarray]:
    """Return ``(cell_ids, speed_grid, density_grid)``; the grids are (cells x steps) float32, NaN where missing.

    Cells are ordered by ``(cell_order, cell_id)``.
    """
    ordered_rows = sorted(detector_rows, key=lambda item: (int(item["cell_order"]), str(item["cell_id"])))
    # dict.fromkeys de-duplicates while keeping first-seen order.
    cell_ids = list(dict.fromkeys(str(row["cell_id"]) for row in ordered_rows))
    cell_to_row = {cell_id: idx for idx, cell_id in enumerate(cell_ids)}
    shape = (len(cell_ids), len(step_to_col))
    speed_grid = np.full(shape, np.nan, dtype=np.float32)
    density_grid = np.full(shape, np.nan, dtype=np.float32)
    for row in detector_rows:
        row_idx = cell_to_row[str(row["cell_id"])]
        col_idx = step_to_col[int(row["step"])]
        speed_grid[row_idx, col_idx] = _safe_float(row["measured_speed_kmh"])
        density_grid[row_idx, col_idx] = _safe_float(row["density_vehpkm"])
    return cell_ids, speed_grid, density_grid


def _plot_episode_heatmap(
    path: str,
    edge_rows: Sequence[Dict],
    edge_ids: Sequence[str],
    detector_rows: Sequence[Dict],
    title_prefix: str,
):
    """Render the episode heatmap to *path*; does nothing when there is no plottable data.

    Panel order: detector speed (if any), applied action (if any), detector
    density (if any). Edge and detector rows share one step axis made of every
    distinct ``step`` value found in either set of rows.
    """
    has_action = bool(edge_rows and edge_ids)
    has_detector = bool(detector_rows)
    if not has_action and not has_detector:
        return

    step_values = sorted(
        {int(row["step"]) for row in [*edge_rows, *detector_rows] if row.get("step") is not None}
    )
    if not step_values:
        return
    step_to_col = {step: idx for idx, step in enumerate(step_values)}

    panels = []
    if has_action:
        action_grid = _build_action_grid(edge_rows, edge_ids, step_to_col)
        panels.append(build_action_panel(action_grid, list(edge_ids), f"{title_prefix} Applied VSL (km/h)"))

    if has_detector:
        cell_ids, speed_grid, density_grid = _build_detector_grids(detector_rows, step_to_col)
        speed_panel = build_speed_panel(
            speed_grid,
            cell_ids,
            f"{title_prefix} Measured Speed (km/h)",
            "Detector Cell",
        )
        density_panel = build_density_panel(
            density_grid,
            cell_ids,
            f"{title_prefix} Density (veh/km)",
            "Detector Cell",
        )
        panels = [speed_panel, *panels, density_panel]

    save_heatmap_panels(path, panels, xlabel="Decision Step")


def _json_default(value: object):
    """``json.dump`` fallback that converts numpy scalars/arrays to native Python; other types still raise."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def _write_summary(path: str, summary: Optional[Dict]):
    """Write *summary* to *path* as indented UTF-8 JSON; does nothing when *summary* is empty or None."""
    if not summary:
        return
    _ensure_parent_dir(path)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2, default=_json_default)


def _remove_if_exists(path: str) -> None:
    """Delete *path* if it exists."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def _fieldnames(rows: Sequence[Dict], fallback: Sequence[str]) -> List[str]:
    """Use the first row's keys as CSV columns, or *fallback* when there are no rows."""
    return list(rows[0].keys()) if rows else list(fallback)


def _save_episode_bundle(
    bundle_dir: str,
    episode: int,
    step_rows: Sequence[Dict],
    edge_rows: Sequence[Dict],
    edge_ids: Sequence[str],
    detector_rows: Sequence[Dict],
    summary: Optional[Dict],
):
    """Write the CSVs, heatmap and summary for one episode into *bundle_dir*."""
    step_csv_path = os.path.join(bundle_dir, "step_metrics.csv")
    edge_csv_path = os.path.join(bundle_dir, "edge_metrics.csv")
    detector_csv_path = os.path.join(bundle_dir, "detector_metrics.csv")
    heatmap_path = os.path.join(bundle_dir, "episode_heatmap.png")
    summary_path = os.path.join(bundle_dir, "summary.json")

    # The heatmap and summary are written conditionally. Remove older copies so a
    # reused directory (e.g. "latest") never mixes files from different episodes.
    _remove_if_exists(heatmap_path)
    _remove_if_exists(summary_path)

    _write_csv(step_csv_path, step_rows, fieldnames=_fieldnames(step_rows, ["episode", "step"]))
    _write_csv(edge_csv_path, edge_rows, fieldnames=_fieldnames(edge_rows, ["episode", "step", "edge_index"]))
    _write_csv(
        detector_csv_path,
        detector_rows,
        fieldnames=_fieldnames(detector_rows, ["episode", "step", "cell_id"]),
    )
    _plot_episode_heatmap(heatmap_path, edge_rows, edge_ids, detector_rows, title_prefix=f"Episode {episode}")
    _write_summary(summary_path, summary)


def save_training_episode_artifacts(
    log_dir: str,
    episode: int,
    episode_metrics: Sequence[Dict],
    control_edges: Sequence[str],
    summary: Optional[Dict] = None,
    snapshot_interval: int = 50,
):
    """Persist one episode's metrics, detector data, heatmap and summary under *log_dir*.

    Args:
        log_dir: Run directory. Detector XML is read from ``<log_dir>/sumo_metrics``.
            Artifacts are written to ``<log_dir>/episode_artifacts``.
        episode: Episode number, used in file names and plot titles.
        episode_metrics: Per-step info dicts. Nothing is written when this is empty.
        control_edges: Real edge ids, positionally matched to the per-step edge
            arrays. They replace the ``edge_NN`` placeholders only when at least
            as many ids are provided as edges were observed.
        summary: Optional dict written to ``summary.json``.
        snapshot_interval: When > 0, episodes divisible by it are also copied to
            ``snapshots/episode_NNNN``.
    """
    if not episode_metrics:
        return

    step_rows, edge_rows, edge_ids = _normalize_step_rows(episode, episode_metrics)
    detector_rows = _build_detector_rows_from_xml(log_dir, episode)

    if control_edges and len(control_edges) >= len(edge_ids):
        # str() keeps ids consistent with the str(edge_id) lookups done when plotting.
        edge_ids = [str(edge_id) for edge_id in control_edges[: len(edge_ids)]]
        for row in edge_rows:
            row["edge_id"] = edge_ids[int(row["edge_index"])]

    artifacts_root = os.path.join(log_dir, "episode_artifacts")
    latest_dir = os.path.join(artifacts_root, "latest")
    _save_episode_bundle(latest_dir, episode, step_rows, edge_rows, edge_ids, detector_rows, summary)

    if snapshot_interval > 0 and episode % snapshot_interval == 0:
        snapshot_dir = os.path.join(artifacts_root, "snapshots", f"episode_{episode:04d}")
        _save_episode_bundle(snapshot_dir, episode, step_rows, edge_rows, edge_ids, detector_rows, summary)