"""Utilities for persisting per-episode training artifacts."""

import csv
import json
import os
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import matplotlib
import numpy as np

# The non-interactive Agg backend must be selected before pyplot is imported so
# figures can be rendered without a display (e.g. on headless training hosts).
matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402  (must come after matplotlib.use)

# Multiply a speed in m/s by this factor to obtain km/h.
_MS_TO_KMH = 3.6


def _safe_float(value, default: float = float("nan")) -> float:
    """Return ``float(value)``, or *default* when the conversion fails."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _safe_int(value, default: int = 0) -> int:
    """Return ``int(value)``, or *default* when the conversion fails.

    Covers ``None`` (``TypeError``), NaN / unparsable strings (``ValueError``)
    and infinities (``OverflowError``), so one malformed metric does not abort
    the whole artifact dump.
    """
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return default


def _as_list(values) -> list:
    """Materialize an optional sequence as a list; ``None`` becomes ``[]``."""
    return [] if values is None else list(values)


def _float_at(values: Sequence, index: int) -> float:
    """Return ``values[index]`` as a float, or NaN when *index* is out of range."""
    return _safe_float(values[index]) if index < len(values) else float("nan")


def _ensure_parent_dir(path: str) -> None:
    """Create the directory containing *path* if it has one.

    ``os.makedirs("")`` raises, so bare file names (no directory part) are
    skipped instead of crashing.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _remove_if_exists(path: str) -> None:
    """Delete *path*; a missing file is not an error."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def _json_default(obj):
    """``json`` fallback that converts NumPy scalars and arrays to Python types."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _normalize_step_rows(
    episode: int, episode_metrics: Sequence[Dict]
) -> Tuple[List[Dict], List[Dict], List[str]]:
    """Flatten per-step ``info`` dicts into step-level and edge-level rows.

    Args:
        episode: Episode index stamped onto every row.
        episode_metrics: One dict per decision step. Scalar metrics are read
            by key. Per-edge values come from the ``edge_speeds_kmh`` (applied
            action), ``edge_speeds_ms`` (measured; converted to km/h) and
            ``edge_occupancies`` lists.

    Returns:
        ``(step_rows, edge_rows, edge_ids)``. ``edge_ids`` holds a placeholder
        name (``edge_00``, ``edge_01``, ...) for every edge index seen in any
        step. Missing or unparsable floats become NaN and missing counts
        become 0. A missing step number falls back to the 1-based position.
    """
    step_rows: List[Dict] = []
    edge_rows: List[Dict] = []
    edge_ids: List[str] = []

    for step_idx, info in enumerate(episode_metrics, start=1):
        step_value = _safe_int(info.get("step"), default=step_idx)
        step_rows.append(
            {
                "episode": episode,
                "step": step_value,
                "sim_time": _safe_float(info.get("sim_time")),
                "reward": _safe_float(info.get("reward")),
                "throughput": _safe_float(info.get("throughput")),
                "arrived_count": _safe_int(info.get("arrived_count", 0)),
                "departed_count": _safe_int(info.get("departed_count", 0)),
                "mean_speed_kmh": _safe_float(info.get("mean_speed_kmh")),
                "speed_std_kmh": _safe_float(info.get("speed_std")) * _MS_TO_KMH,
                "mean_occupancy": _safe_float(info.get("mean_occupancy")),
                "density": _safe_float(info.get("density")),
                "num_vehicles": _safe_int(info.get("num_vehicles", 0)),
                "num_hard_brakes": _safe_int(info.get("num_hard_brakes", 0)),
                "r_flow": _safe_float(info.get("r_flow")),
                "r_var": _safe_float(info.get("r_var")),
                "r_brake": _safe_float(info.get("r_brake")),
                "r_penalty": _safe_float(info.get("r_penalty")),
            }
        )

        action_speeds = _as_list(info.get("edge_speeds_kmh"))
        measured_speeds_ms = _as_list(info.get("edge_speeds_ms"))
        occupancies = _as_list(info.get("edge_occupancies"))
        edge_count = max(len(action_speeds), len(measured_speeds_ms), len(occupancies))

        # Grow the placeholder id list whenever a step reports more edges than
        # any earlier one, so a wider later step cannot index past the end.
        if edge_count > len(edge_ids):
            edge_ids.extend(f"edge_{idx:02d}" for idx in range(len(edge_ids), edge_count))

        for edge_idx in range(edge_count):
            edge_rows.append(
                {
                    "episode": episode,
                    "step": step_value,
                    "edge_index": edge_idx,
                    "edge_id": edge_ids[edge_idx],
                    "action_speed_kmh": _float_at(action_speeds, edge_idx),
                    # Convert to float *before* scaling so None/strings become
                    # NaN instead of raising TypeError on the multiplication.
                    "measured_speed_kmh": _float_at(measured_speeds_ms, edge_idx) * _MS_TO_KMH,
                    "occupancy": _float_at(occupancies, edge_idx),
                }
            )

    return step_rows, edge_rows, edge_ids


def _write_csv(path: str, rows: Iterable[Dict], fieldnames: Sequence[str]):
    """Write *rows* to *path* as UTF-8 CSV with a header row of *fieldnames*."""
    _ensure_parent_dir(path)
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def _plot_episode_heatmap(
    path: str, edge_rows: Sequence[Dict], edge_ids: Sequence[str], title_prefix: str
) -> bool:
    """Render applied-VSL, measured-speed and occupancy heatmaps side by side.

    Each panel is an (edge x step-column) grid. Cells with no data are masked,
    so they render blank rather than as a colour.

    Returns:
        ``True`` if a figure was written to *path*; ``False`` when there was
        nothing to plot.
    """
    if not edge_rows or not edge_ids:
        return False

    step_values = sorted({int(row["step"]) for row in edge_rows})
    num_edges = len(edge_ids)
    step_to_col = {step: col for col, step in enumerate(step_values)}
    edge_to_row = {edge_id: idx for idx, edge_id in enumerate(edge_ids)}
    shape = (num_edges, len(step_values))

    action_grid = np.full(shape, np.nan, dtype=np.float32)
    speed_grid = np.full(shape, np.nan, dtype=np.float32)
    occupancy_grid = np.full(shape, np.nan, dtype=np.float32)

    for row in edge_rows:
        row_idx = edge_to_row[row["edge_id"]]
        col_idx = step_to_col[int(row["step"])]
        action_grid[row_idx, col_idx] = _safe_float(row["action_speed_kmh"])
        speed_grid[row_idx, col_idx] = _safe_float(row["measured_speed_kmh"])
        occupancy_grid[row_idx, col_idx] = _safe_float(row["occupancy"])

    _ensure_parent_dir(path)
    fig, axes = plt.subplots(1, 3, figsize=(18, 7), sharex=True, sharey=True)
    try:
        panels = [
            (action_grid, "viridis", "Applied VSL (km/h)"),
            (speed_grid, "RdYlGn", "Measured Speed (km/h)"),
            (occupancy_grid, "magma", "Occupancy (%)"),
        ]
        for ax, (grid, cmap, title) in zip(axes, panels):
            image = ax.imshow(
                np.ma.masked_invalid(grid),
                aspect="auto",
                origin="lower",
                cmap=cmap,
                interpolation="nearest",
                resample=False,
            )
            ax.set_title(f"{title_prefix} {title}")
            ax.set_xlabel("Decision Step")
            ax.set_ylabel("Controlled Edge")
            ax.set_yticks(np.arange(num_edges))
            ax.set_yticklabels(edge_ids)
            fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)
        fig.tight_layout()
        fig.savefig(path, dpi=160)
    finally:
        # Always release the figure, even if rendering fails, so repeated calls
        # during a long training run do not accumulate open figures.
        plt.close(fig)
    return True


def _write_summary(path: str, summary: Optional[Dict]) -> bool:
    """Write *summary* to *path* as indented UTF-8 JSON.

    NumPy scalars and arrays are converted to plain Python values.

    Returns:
        ``True`` if a file was written; ``False`` when *summary* is empty or
        ``None``.
    """
    if not summary:
        return False
    # Serialize fully before touching the file so an unserializable value
    # cannot leave a truncated summary.json behind.
    text = json.dumps(summary, ensure_ascii=False, indent=2, default=_json_default)
    _ensure_parent_dir(path)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)
    return True


def _save_episode_bundle(
    bundle_dir: str,
    episode: int,
    step_rows: Sequence[Dict],
    edge_rows: Sequence[Dict],
    edge_ids: Sequence[str],
    summary: Optional[Dict],
):
    """Write step/edge CSVs, the heatmap and the summary into *bundle_dir*.

    A bundle directory may be reused across episodes (the ``latest`` bundle
    is). Optional outputs that this episode does not produce are therefore
    deleted rather than left over from an earlier episode. This covers the
    heatmap when there is no edge data, and the summary when *summary* is
    empty.
    """
    step_csv_path = os.path.join(bundle_dir, "step_metrics.csv")
    edge_csv_path = os.path.join(bundle_dir, "edge_metrics.csv")
    heatmap_path = os.path.join(bundle_dir, "episode_heatmap.png")
    summary_path = os.path.join(bundle_dir, "summary.json")

    step_fields = list(step_rows[0].keys()) if step_rows else ["episode", "step"]
    edge_fields = list(edge_rows[0].keys()) if edge_rows else ["episode", "step", "edge_index"]
    _write_csv(step_csv_path, step_rows, fieldnames=step_fields)
    _write_csv(edge_csv_path, edge_rows, fieldnames=edge_fields)

    if not _plot_episode_heatmap(heatmap_path, edge_rows, edge_ids, title_prefix=f"Episode {episode}"):
        _remove_if_exists(heatmap_path)
    if not _write_summary(summary_path, summary):
        _remove_if_exists(summary_path)


def save_training_episode_artifacts(
    log_dir: str,
    episode: int,
    episode_metrics: Sequence[Dict],
    control_edges: Sequence[str],
    summary: Optional[Dict] = None,
    snapshot_interval: int = 50,
):
    """Persist step/edge CSVs, a heatmap and a summary for one training episode.

    Artifacts are always written to ``<log_dir>/episode_artifacts/latest``,
    replacing the previous episode's files. When ``snapshot_interval > 0``
    and *episode* is a multiple of it, an extra copy is written to
    ``<log_dir>/episode_artifacts/snapshots/episode_<NNNN>``.

    Args:
        log_dir: Root directory for training logs.
        episode: Episode index.
        episode_metrics: Per-step info dicts (see ``_normalize_step_rows``).
            Nothing is written when this is empty.
        control_edges: Real edge ids. They are used as edge labels only when
            at least as many are supplied as edges were observed; otherwise
            the ``edge_NN`` placeholders are kept.
        summary: Optional JSON-serializable episode summary. NumPy scalars and
            arrays are accepted.
        snapshot_interval: Episode period for snapshot copies; ``<= 0``
            disables snapshots.
    """
    if not episode_metrics:
        return

    step_rows, edge_rows, edge_ids = _normalize_step_rows(episode, episode_metrics)
    if control_edges and len(control_edges) >= len(edge_ids):
        edge_ids = list(control_edges[: len(edge_ids)])
        for row in edge_rows:
            row["edge_id"] = edge_ids[int(row["edge_index"])]

    artifacts_root = os.path.join(log_dir, "episode_artifacts")
    latest_dir = os.path.join(artifacts_root, "latest")
    _save_episode_bundle(latest_dir, episode, step_rows, edge_rows, edge_ids, summary)

    if snapshot_interval > 0 and episode % snapshot_interval == 0:
        snapshot_dir = os.path.join(artifacts_root, "snapshots", f"episode_{episode:04d}")
        _save_episode_bundle(snapshot_dir, episode, step_rows, edge_rows, edge_ids, summary)