"""Utilities for persisting per-episode training artifacts."""

import csv
import json
import os
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import xml.etree.ElementTree as ET

import matplotlib
import numpy as np

# Select the non-interactive Agg backend before the project plotting helpers are
# imported (they may pull in pyplot), so figures can be saved without a display.
matplotlib.use("Agg")

from envs.reward_system import REWARD_COMPONENT_COLUMNS  # noqa: E402
from utils.heatmap_plotting import (  # noqa: E402
    build_action_panel,
    build_occupancy_panel,
    build_speed_panel,
    save_heatmap_panels,
)


# ---------------------------------------------------------------------------
# Conversion helpers
# ---------------------------------------------------------------------------


def _safe_float(value, default: float = float("nan")) -> float:
    """Return ``float(value)``, or ``default`` when the conversion fails.

    ``None``, non-numeric strings and integers too large for a float all fall
    back to ``default`` (NaN unless overridden).
    """
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        return default


def _safe_int(value, default: int = 0) -> int:
    """Return ``int(value)``, or ``default`` when the conversion fails.

    Covers ``None``, non-numeric strings and non-finite floats (NaN raises
    ``ValueError`` and infinity raises ``OverflowError`` in ``int()``).
    """
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return default


def _safe_str(value) -> str:
    """Return ``str(value)``, mapping ``None`` to ``""`` instead of ``"None"``."""
    return "" if value is None else str(value)


def _as_list(values) -> list:
    """Return ``list(values)``, treating ``None`` as an empty sequence."""
    return [] if values is None else list(values)


def _value_at(values: Sequence, index: int):
    """Return ``values[index]``, or NaN when ``index`` is past the end."""
    return values[index] if index < len(values) else np.nan


# ---------------------------------------------------------------------------
# Step / edge rows from the environment's per-step info dicts
# ---------------------------------------------------------------------------


def _build_step_row(episode: int, step_value: int, info: Dict) -> Dict:
    """Build the flat per-step CSV row from one step's ``info`` dict.

    Missing or unparseable numeric fields become NaN. The exceptions are counts
    (0), ``ttc_risk_rate`` (0.0) and ``incident_lane_index`` (-1). The key order
    here defines the column order of ``step_metrics.csv``.
    """
    row = {
        "episode": episode,
        "step": step_value,
        "sim_time": _safe_float(info.get("sim_time")),
        "reward": _safe_float(info.get("reward")),
        "throughput": _safe_float(info.get("throughput")),
        "arrived_count": _safe_int(info.get("arrived_count")),
        "departed_count": _safe_int(info.get("departed_count")),
        "mean_speed_kmh": _safe_float(info.get("mean_speed_kmh")),
        "speed_variance_norm": _safe_float(info.get("speed_variance_norm")),
        "mean_occupancy": _safe_float(info.get("mean_occupancy")),
        "density": _safe_float(info.get("density")),
        "num_vehicles": _safe_int(info.get("num_vehicles")),
        "ttc_risk_rate": _safe_float(info.get("ttc_risk_rate", 0.0), default=0.0),
        "ttc_min_s": _safe_float(info.get("ttc_min_s")),
        "incident_enabled": bool(info.get("incident_enabled", False)),
        "incident_pending": bool(info.get("incident_pending", False)),
        "incident_commanded": bool(info.get("incident_commanded", False)),
        "incident_active": bool(info.get("incident_active", False)),
        "incident_completed": bool(info.get("incident_completed", False)),
        "incident_vehicle_id": _safe_str(info.get("incident_vehicle_id")),
        "incident_target_edge_id": _safe_str(info.get("incident_target_edge_id")),
        "incident_target_position_m": _safe_float(info.get("incident_target_position_m")),
        "incident_target_distance_m": _safe_float(info.get("incident_target_distance_m")),
        "incident_trigger_time_s": _safe_float(info.get("incident_trigger_time_s")),
        "incident_duration_s": _safe_float(info.get("incident_duration_s")),
        "incident_command_time_s": _safe_float(info.get("incident_command_time_s")),
        "incident_blocking_start_time_s": _safe_float(info.get("incident_blocking_start_time_s")),
        "incident_release_time_s": _safe_float(info.get("incident_release_time_s")),
        "incident_released_time_s": _safe_float(info.get("incident_released_time_s")),
        "incident_lane_index": _safe_int(info.get("incident_lane_index"), default=-1),
    }
    for column in REWARD_COMPONENT_COLUMNS:
        row[column] = _safe_float(info.get(column))
    return row


def _normalize_step_rows(episode: int, episode_metrics: Sequence[Dict]) -> Tuple[List[Dict], List[Dict], List[str]]:
    """Flatten per-step ``info`` dicts into step-level and edge-level rows.

    Args:
        episode: Episode number written into every row.
        episode_metrics: One ``info`` dict per decision step, in step order. A
            missing or unparseable ``"step"`` falls back to the 1-based position.

    Returns:
        ``(step_rows, edge_rows, edge_ids)``. ``edge_ids`` holds placeholder
        names ``edge_00``, ``edge_01``, ... covering the largest per-step edge
        count. Callers may replace them with real edge ids.
    """
    step_rows: List[Dict] = []
    edge_rows: List[Dict] = []
    edge_ids: List[str] = []

    for step_idx, info in enumerate(episode_metrics, start=1):
        step_value = _safe_int(info.get("step"), default=step_idx)
        step_rows.append(_build_step_row(episode, step_value, info))

        action_speeds = _as_list(info.get("edge_speeds_kmh"))
        measured_speeds_ms = _as_list(info.get("edge_speeds_ms"))
        occupancies = _as_list(info.get("edge_occupancies"))
        action_applied_mask = _as_list(info.get("action_applied_mask"))
        edge_count = max(len(action_speeds), len(measured_speeds_ms), len(occupancies), len(action_applied_mask))

        # Grow the placeholder ids when this step reports more edges than any
        # earlier step did; indexing a shorter list below would raise IndexError.
        edge_ids.extend(f"edge_{idx:02d}" for idx in range(len(edge_ids), edge_count))

        for edge_idx in range(edge_count):
            edge_rows.append(
                {
                    "episode": episode,
                    "step": step_value,
                    "edge_index": edge_idx,
                    "edge_id": edge_ids[edge_idx],
                    "action_speed_kmh": _safe_float(_value_at(action_speeds, edge_idx)),
                    # Edges without a mask entry are reported as applied.
                    "action_applied": (
                        bool(action_applied_mask[edge_idx]) if edge_idx < len(action_applied_mask) else True
                    ),
                    # Convert first so a None/garbage entry becomes NaN instead of
                    # raising on the multiplication (m/s -> km/h).
                    "measured_speed_kmh": _safe_float(_value_at(measured_speeds_ms, edge_idx)) * 3.6,
                    "occupancy": _safe_float(_value_at(occupancies, edge_idx)),
                }
            )
    return step_rows, edge_rows, edge_ids


# ---------------------------------------------------------------------------
# File helpers
# ---------------------------------------------------------------------------


def _ensure_parent_dir(path: str) -> None:
    """Create the directory containing ``path``; a bare filename needs none."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _remove_if_exists(path: str) -> None:
    """Delete the file at ``path``; do nothing if it does not exist."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def _write_csv(path: str, rows: Iterable[Dict], fieldnames: Sequence[str]):
    """Write ``rows`` to ``path`` as UTF-8 CSV with a header row, creating parent dirs."""
    _ensure_parent_dir(path)
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


# ---------------------------------------------------------------------------
# SUMO induction-loop (detector) output
# ---------------------------------------------------------------------------


def _lane_to_edge_id(lane_id: str) -> str:
    """Strip the trailing ``_<lane index>`` from a SUMO-style lane id.

    Ids without an underscore are returned unchanged.
    """
    if "_" not in lane_id:
        return lane_id
    return lane_id.rsplit("_", 1)[0]


def _parse_detector_layout(additional_path: str) -> List[Dict]:
    """Group the ``inductionLoop`` definitions of an additional file into cells.

    Loops on the same edge at the same position (rounded to 3 decimals) form one
    cell, but only when they are consecutive in the file. Presumably this means
    one loop per lane across a cross-section; confirm against the generator.
    Each cell is a dict with these keys:

    - ``cell_id``: ``"<edge>@<n>"``.
    - ``edge_id``.
    - ``pos_index``: counts cells per edge in file order.
    - ``position_m``.
    - ``detector_ids``.

    Returns ``[]`` when the file is missing or is not well-formed XML. Loops
    lacking an ``id``, ``lane`` or ``pos`` attribute are ignored.
    """
    if not os.path.isfile(additional_path):
        return []
    try:
        root = ET.parse(additional_path).getroot()
    except ET.ParseError:
        return []

    cells: List[Dict] = []
    current_key: Optional[Tuple[str, float]] = None
    current_cell: Optional[Dict] = None
    edge_pos_counts: Dict[str, int] = {}
    for elem in root.findall("inductionLoop"):
        det_id = elem.get("id")
        lane_id = elem.get("lane")
        pos_raw = elem.get("pos")
        if not det_id or not lane_id or pos_raw is None:
            continue
        edge_id = _lane_to_edge_id(lane_id)
        position_m = _safe_float(pos_raw)
        key = (edge_id, round(position_m, 3))
        if current_cell is None or key != current_key:
            pos_index = edge_pos_counts.get(edge_id, 0)
            edge_pos_counts[edge_id] = pos_index + 1
            current_cell = {
                "cell_id": f"{edge_id}@{pos_index}",
                "edge_id": edge_id,
                "pos_index": pos_index,
                "position_m": position_m,
                "detector_ids": [],
            }
            cells.append(current_cell)
            current_key = key
        current_cell["detector_ids"].append(det_id)
    return cells


def _read_detector_intervals(metrics_path: str) -> Dict[Tuple[float, float], Dict[str, Dict[str, float]]]:
    """Parse induction-loop output into ``{(begin, end): {detector_id: readings}}``.

    Readings are:

    - ``speed_ms``: NaN when absent.
    - ``flow_vehph``, ``n_veh_contrib`` and ``occupancy``: 0.0 when absent.

    Intervals without an ``id`` or with a non-finite ``begin``/``end`` are
    skipped. Malformed XML yields ``{}``.
    """
    try:
        root = ET.parse(metrics_path).getroot()
    except ET.ParseError:
        return {}

    interval_map: Dict[Tuple[float, float], Dict[str, Dict[str, float]]] = {}
    for elem in root.findall("interval"):
        det_id = elem.get("id")
        if not det_id:
            continue
        begin = _safe_float(elem.get("begin"))
        end = _safe_float(elem.get("end"))
        if not np.isfinite(begin) or not np.isfinite(end):
            continue
        interval_map.setdefault((begin, end), {})[det_id] = {
            "speed_ms": _safe_float(elem.get("speed")),
            "flow_vehph": _safe_float(elem.get("flow"), default=0.0),
            "n_veh_contrib": _safe_float(elem.get("nVehContrib"), default=0.0),
            "occupancy": _safe_float(elem.get("occupancy"), default=0.0),
        }
    return interval_map


def _aggregate_detector_cell(cell: Dict, interval_values: Dict[str, Dict[str, float]]) -> Tuple[float, float, float]:
    """Combine one interval's per-loop readings into cell-level values.

    Returns:
        ``(measured_speed_kmh, occupancy_pct, density_vehpkm)``:

        - Speed is the weighted mean over loops with a positive speed. The
          weight is ``nVehContrib``, else flow, else 1.
        - Occupancy is the plain mean over reporting loops.
        - Density sums ``flow / speed`` over the speed-contributing loops.

        Speed and occupancy are NaN when no loop contributed.
    """
    speed_weight_sum = 0.0
    speed_weight_total = 0.0
    density_sum = 0.0
    occupancy_sum = 0.0
    occupancy_count = 0

    for det_id in cell["detector_ids"]:
        lane_metrics = interval_values.get(det_id)
        if not lane_metrics:
            continue
        speed_ms = lane_metrics["speed_ms"]
        flow_vehph = max(lane_metrics["flow_vehph"], 0.0)
        n_veh_contrib = max(lane_metrics["n_veh_contrib"], 0.0)
        occupancy = max(lane_metrics["occupancy"], 0.0)

        if np.isfinite(occupancy):
            occupancy_sum += occupancy
            occupancy_count += 1

        # A non-positive speed is treated as "no measurement" (SUMO reports a
        # negative speed for intervals without vehicles).
        if np.isfinite(speed_ms) and speed_ms > 0.0:
            speed_kmh = speed_ms * 3.6
            speed_weight = n_veh_contrib if n_veh_contrib > 0.0 else flow_vehph
            if speed_weight <= 0.0:
                speed_weight = 1.0
            speed_weight_sum += speed_kmh * speed_weight
            speed_weight_total += speed_weight
            # (veh/h) / (km/h) = veh/km per lane; summed across the cell.
            density_sum += flow_vehph / max(speed_kmh, 1e-6)

    measured_speed_kmh = speed_weight_sum / speed_weight_total if speed_weight_total > 0.0 else np.nan
    occupancy_pct = occupancy_sum / occupancy_count if occupancy_count > 0 else np.nan
    return measured_speed_kmh, occupancy_pct, density_sum


def _build_detector_rows_from_xml(log_dir: str, episode: int) -> List[Dict]:
    """Build per-interval, per-cell detector rows from the episode's SUMO XML files.

    Reads two files under ``<log_dir>/sumo_metrics``:

    - ``runtime_metrics_il_ep<NNNN>.add.xml``: the detector layout.
    - ``metrics_il_output_ep<NNNN>.xml``: the readings.

    Intervals are sorted by ``(begin, end)`` and numbered from 1 as ``step``.
    Returns ``[]`` when the layout has no cells or the output file is missing or
    unparseable.
    """
    metrics_dir = os.path.join(log_dir, "sumo_metrics")
    suffix = f"ep{episode:04d}"
    additional_path = os.path.join(metrics_dir, f"runtime_metrics_il_{suffix}.add.xml")
    metrics_path = os.path.join(metrics_dir, f"metrics_il_output_{suffix}.xml")

    cells = _parse_detector_layout(additional_path)
    if not cells or not os.path.isfile(metrics_path):
        return []
    interval_map = _read_detector_intervals(metrics_path)

    detector_rows: List[Dict] = []
    for step_idx, interval_key in enumerate(sorted(interval_map), start=1):
        interval_values = interval_map[interval_key]
        for cell_order, cell in enumerate(cells):
            measured_speed_kmh, occupancy_pct, density_vehpkm = _aggregate_detector_cell(cell, interval_values)
            detector_rows.append(
                {
                    "episode": episode,
                    "step": step_idx,
                    "cell_order": cell_order,
                    "cell_id": cell["cell_id"],
                    "edge_id": cell["edge_id"],
                    "pos_index": int(cell["pos_index"]),
                    "position_m": float(cell["position_m"]),
                    "measured_speed_kmh": float(measured_speed_kmh) if np.isfinite(measured_speed_kmh) else np.nan,
                    "occupancy": float(occupancy_pct) if np.isfinite(occupancy_pct) else np.nan,
                    "density_vehpkm": float(density_vehpkm),
                }
            )
    return detector_rows


def _align_detector_rows_to_decision_steps(
    detector_rows: Sequence[Dict],
    step_rows: Sequence[Dict],
) -> List[Dict]:
    """Renumber detector rows so their ``step`` values match the decision steps.

    When there are at least as many detector intervals as decision steps, the
    earliest surplus intervals are dropped. This is presumably warm-up before the
    first decision. The remaining intervals are then mapped, in order, onto the
    sorted decision-step values.

    With fewer intervals than decision steps, the rows are returned unchanged
    because the correct alignment is ambiguous. Input rows are never mutated.
    """
    if not detector_rows or not step_rows:
        return list(detector_rows)

    detector_step_values = sorted({int(row["step"]) for row in detector_rows})
    decision_step_values = sorted({int(row["step"]) for row in step_rows})
    excess_steps = len(detector_step_values) - len(decision_step_values)
    # Equal counts still need remapping: detector steps are always 1..N, while
    # decision steps come from info["step"] and may be numbered differently.
    if excess_steps < 0:
        return list(detector_rows)

    step_mapping = dict(zip(detector_step_values[excess_steps:], decision_step_values))
    return [
        {**row, "step": step_mapping[int(row["step"])]}
        for row in detector_rows
        if int(row["step"]) in step_mapping
    ]


# ---------------------------------------------------------------------------
# Heatmap
# ---------------------------------------------------------------------------


def _row_step(row: Dict) -> Optional[int]:
    """Return the row's ``step`` as ``int``, or ``None`` when it has none."""
    step = row.get("step")
    return None if step is None else int(step)


def _build_action_grid(
    edge_rows: Sequence[Dict],
    edge_ids: Sequence[str],
    step_to_col: Dict[int, int],
) -> Tuple[np.ndarray, List[str]]:
    """Arrange commanded speeds into an ``(edges, steps)`` float32 grid (NaN = no data).

    Row order follows ``edge_ids`` when it is non-empty, otherwise ``edge_index``
    order from ``edge_rows``. Duplicate ids share one row. Rows for unknown edges
    or steps are ignored.
    """
    ordered_edge_ids = list(edge_ids) or [
        str(row["edge_id"]) for row in sorted(edge_rows, key=lambda item: int(item["edge_index"]))
    ]
    ordered_edge_ids = list(dict.fromkeys(ordered_edge_ids))
    edge_to_row = {edge_id: idx for idx, edge_id in enumerate(ordered_edge_ids)}

    grid = np.full((len(ordered_edge_ids), len(step_to_col)), np.nan, dtype=np.float32)
    for row in edge_rows:
        row_idx = edge_to_row.get(str(row["edge_id"]))
        col_idx = step_to_col.get(_row_step(row))
        if row_idx is None or col_idx is None:
            continue
        grid[row_idx, col_idx] = _safe_float(row["action_speed_kmh"])
    return grid, ordered_edge_ids


def _build_detector_grids(
    detector_rows: Sequence[Dict],
    step_to_col: Dict[int, int],
) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """Arrange detector speed and occupancy into ``(cells, steps)`` float32 grids.

    Cells are ordered by ``(cell_order, cell_id)``. Missing entries are NaN.
    """
    ordered_cells = list(
        dict.fromkeys(
            str(row["cell_id"])
            for row in sorted(detector_rows, key=lambda item: (int(item["cell_order"]), str(item["cell_id"])))
        )
    )
    cell_to_row = {cell_id: idx for idx, cell_id in enumerate(ordered_cells)}

    shape = (len(ordered_cells), len(step_to_col))
    speed_grid = np.full(shape, np.nan, dtype=np.float32)
    occupancy_grid = np.full(shape, np.nan, dtype=np.float32)
    for row in detector_rows:
        col_idx = step_to_col.get(_row_step(row))
        if col_idx is None:
            continue
        row_idx = cell_to_row[str(row["cell_id"])]
        speed_grid[row_idx, col_idx] = _safe_float(row["measured_speed_kmh"])
        occupancy_grid[row_idx, col_idx] = _safe_float(row["occupancy"])
    return speed_grid, occupancy_grid, ordered_cells


def _plot_episode_heatmap(
    path: str,
    edge_rows: Sequence[Dict],
    edge_ids: Sequence[str],
    detector_rows: Sequence[Dict],
    title_prefix: str,
):
    """Render the episode heatmap figure to ``path``.

    Panels are passed to ``save_heatmap_panels`` in this order, as available:

    1. Detector speed.
    2. Commanded segment speed limit.
    3. Detector occupancy.

    Columns are the union of the steps in ``edge_rows`` and ``detector_rows``.
    Nothing is written when there are no steps to plot.
    """
    step_values = sorted(
        {step for step in map(_row_step, list(edge_rows) + list(detector_rows)) if step is not None}
    )
    if not step_values:
        return
    step_to_col = {step: idx for idx, step in enumerate(step_values)}

    panels = []
    if edge_rows:
        action_grid, ordered_edge_ids = _build_action_grid(edge_rows, edge_ids, step_to_col)
        panels.append(
            build_action_panel(
                action_grid,
                ordered_edge_ids,
                f"{title_prefix} Segment Speed Limit (km/h)",
                ylabel="Corridor Segment",
            )
        )

    if detector_rows:
        speed_grid, occupancy_grid, ordered_cells = _build_detector_grids(detector_rows, step_to_col)
        speed_panel = build_speed_panel(
            speed_grid,
            ordered_cells,
            f"{title_prefix} Measured Speed (km/h)",
            "Detector Cell",
        )
        occupancy_panel = build_occupancy_panel(
            occupancy_grid,
            ordered_cells,
            f"{title_prefix} Occupancy (%)",
            "Detector Cell",
        )
        panels = [speed_panel, *panels, occupancy_panel]

    save_heatmap_panels(path, panels, xlabel="Decision Step")


# ---------------------------------------------------------------------------
# Summary / bundle / public entry point
# ---------------------------------------------------------------------------


def _json_default(value):
    """Serialize NumPy scalars and arrays, which ``json`` cannot encode natively."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def _write_summary(path: str, summary: Optional[Dict]):
    """Write ``summary`` as indented UTF-8 JSON; do nothing when it is empty or None."""
    if not summary:
        return
    _ensure_parent_dir(path)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2, default=_json_default)


def _save_episode_bundle(
    bundle_dir: str,
    episode: int,
    step_rows: Sequence[Dict],
    edge_rows: Sequence[Dict],
    edge_ids: Sequence[str],
    detector_rows: Sequence[Dict],
    summary: Optional[Dict],
):
    """Write the three CSVs, the heatmap and the summary for one episode into ``bundle_dir``.

    The heatmap and summary are not always produced: there may be no plottable
    steps, or the summary may be empty. Any copies left by an earlier episode are
    therefore deleted first, so that a reused directory such as ``latest`` never
    mixes files from different episodes.
    """
    step_csv_path = os.path.join(bundle_dir, "step_metrics.csv")
    edge_csv_path = os.path.join(bundle_dir, "edge_metrics.csv")
    detector_csv_path = os.path.join(bundle_dir, "detector_metrics.csv")
    heatmap_path = os.path.join(bundle_dir, "episode_heatmap.png")
    summary_path = os.path.join(bundle_dir, "summary.json")

    _remove_if_exists(heatmap_path)
    _remove_if_exists(summary_path)

    _write_csv(step_csv_path, step_rows, fieldnames=list(step_rows[0].keys()) if step_rows else ["episode", "step"])
    _write_csv(
        edge_csv_path,
        edge_rows,
        fieldnames=list(edge_rows[0].keys()) if edge_rows else ["episode", "step", "edge_index"],
    )
    _write_csv(
        detector_csv_path,
        detector_rows,
        fieldnames=list(detector_rows[0].keys()) if detector_rows else ["episode", "step", "cell_id"],
    )
    _plot_episode_heatmap(heatmap_path, edge_rows, edge_ids, detector_rows, title_prefix=f"Episode {episode}")
    _write_summary(summary_path, summary)


def save_training_episode_artifacts(
    log_dir: str,
    episode: int,
    episode_metrics: Sequence[Dict],
    control_edges: Sequence[str],
    summary: Optional[Dict] = None,
    snapshot_interval: int = 50,
):
    """Persist the step, edge and detector metrics, heatmap and summary of one episode.

    The bundle is always written to ``<log_dir>/episode_artifacts/latest``. It is
    also written to ``<log_dir>/episode_artifacts/snapshots/episode_<NNNN>`` when
    ``snapshot_interval > 0`` and ``episode`` is a multiple of it.

    Args:
        log_dir: Run log directory. SUMO detector XML is read from its
            ``sumo_metrics`` subdirectory.
        episode: Episode number.
        episode_metrics: One ``info`` dict per decision step. Nothing is written
            when this is empty.
        control_edges: Real edge ids used to label edge rows. They are applied
            only when there are at least as many as edges seen in
            ``episode_metrics``.
        summary: Optional episode summary written as JSON. NumPy scalars and
            arrays are converted.
        snapshot_interval: Episode period for permanent snapshots. A value of
            ``<= 0`` disables them.
    """
    if not episode_metrics:
        return

    step_rows, edge_rows, edge_ids = _normalize_step_rows(episode, episode_metrics)
    detector_rows = _build_detector_rows_from_xml(log_dir, episode)
    detector_rows = _align_detector_rows_to_decision_steps(detector_rows, step_rows)

    # Replace placeholder ids with the real controlled edges when enough are known.
    if control_edges and len(control_edges) >= len(edge_ids):
        edge_ids = list(control_edges[: len(edge_ids)])
        for row in edge_rows:
            row["edge_id"] = edge_ids[int(row["edge_index"])]

    artifacts_root = os.path.join(log_dir, "episode_artifacts")
    latest_dir = os.path.join(artifacts_root, "latest")
    _save_episode_bundle(latest_dir, episode, step_rows, edge_rows, edge_ids, detector_rows, summary)

    if snapshot_interval > 0 and episode % snapshot_interval == 0:
        snapshot_dir = os.path.join(artifacts_root, "snapshots", f"episode_{episode:04d}")
        _save_episode_bundle(snapshot_dir, episode, step_rows, edge_rows, edge_ids, detector_rows, summary)