"""Utilities for persisting per-episode training artifacts.""" import csv import json import os from typing import Dict, Iterable, List, Optional, Sequence, Tuple import matplotlib import numpy as np matplotlib.use("Agg") import matplotlib.pyplot as plt from envs.reward_system import REWARD_COMPONENT_COLUMNS def _safe_float(value, default: float = float("nan")) -> float: try: return float(value) except (TypeError, ValueError): return default def _normalize_step_rows(episode: int, episode_metrics: Sequence[Dict]) -> Tuple[List[Dict], List[Dict], List[str]]: step_rows: List[Dict] = [] edge_rows: List[Dict] = [] edge_ids: List[str] = [] for step_idx, info in enumerate(episode_metrics, start=1): step_value = int(info.get("step", step_idx)) step_row = { "episode": episode, "step": step_value, "sim_time": _safe_float(info.get("sim_time")), "reward": _safe_float(info.get("reward")), "throughput": _safe_float(info.get("throughput")), "arrived_count": int(info.get("arrived_count", 0)), "departed_count": int(info.get("departed_count", 0)), "mean_speed_kmh": _safe_float(info.get("mean_speed_kmh")), "speed_variance_norm": _safe_float(info.get("speed_variance_norm")), "mean_occupancy": _safe_float(info.get("mean_occupancy")), "density": _safe_float(info.get("density")), "num_vehicles": int(info.get("num_vehicles", 0)), "num_hard_brakes": int(info.get("num_hard_brakes", 0)), } for column in REWARD_COMPONENT_COLUMNS: step_row[column] = _safe_float(info.get(column)) step_rows.append(step_row) action_speeds = list(info.get("edge_speeds_kmh", [])) measured_speeds_ms = list(info.get("edge_speeds_ms", [])) occupancies = list(info.get("edge_occupancies", [])) edge_count = max(len(action_speeds), len(measured_speeds_ms), len(occupancies)) if not edge_ids: edge_ids = [f"edge_{idx:02d}" for idx in range(edge_count)] for edge_idx in range(edge_count): edge_rows.append( { "episode": episode, "step": step_value, "edge_index": edge_idx, "edge_id": edge_ids[edge_idx], "action_speed_kmh": _safe_float(action_speeds[edge_idx] if edge_idx < len(action_speeds) else np.nan), "measured_speed_kmh": _safe_float( measured_speeds_ms[edge_idx] * 3.6 if edge_idx < len(measured_speeds_ms) else np.nan ), "occupancy": _safe_float(occupancies[edge_idx] if edge_idx < len(occupancies) else np.nan), } ) return step_rows, edge_rows, edge_ids def _write_csv(path: str, rows: Iterable[Dict], fieldnames: Sequence[str]): os.makedirs(os.path.dirname(path), exist_ok=True) with open(path, "w", newline="", encoding="utf-8") as f: writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writeheader() writer.writerows(rows) def _plot_episode_heatmap(path: str, edge_rows: Sequence[Dict], edge_ids: Sequence[str], title_prefix: str): if not edge_rows or not edge_ids: return step_values = sorted({int(row["step"]) for row in edge_rows}) num_edges = len(edge_ids) num_steps = len(step_values) step_to_col = {step: idx for idx, step in enumerate(step_values)} edge_to_row = {edge_id: idx for idx, edge_id in enumerate(edge_ids)} action_grid = np.full((num_edges, num_steps), np.nan, dtype=np.float32) speed_grid = np.full((num_edges, num_steps), np.nan, dtype=np.float32) occupancy_grid = np.full((num_edges, num_steps), np.nan, dtype=np.float32) for row in edge_rows: row_idx = edge_to_row[row["edge_id"]] col_idx = step_to_col[int(row["step"])] action_grid[row_idx, col_idx] = _safe_float(row["action_speed_kmh"]) speed_grid[row_idx, col_idx] = _safe_float(row["measured_speed_kmh"]) occupancy_grid[row_idx, col_idx] = _safe_float(row["occupancy"]) fig, axes = plt.subplots(1, 3, figsize=(18, 7), sharex=True, sharey=True) plots = [ (action_grid, "viridis", "Applied VSL (km/h)"), (speed_grid, "RdYlGn", "Measured Speed (km/h)"), (occupancy_grid, "magma", "Occupancy (%)"), ] for ax, (grid, cmap, title) in zip(axes, plots): image = ax.imshow( np.ma.masked_invalid(grid), aspect="auto", origin="lower", cmap=cmap, interpolation="nearest", resample=False, ) ax.set_title(f"{title_prefix} {title}") ax.set_xlabel("Decision Step") ax.set_ylabel("Controlled Edge") ax.set_yticks(np.arange(num_edges)) ax.set_yticklabels(edge_ids) plt.colorbar(image, ax=ax, fraction=0.046, pad=0.04) plt.tight_layout() plt.savefig(path, dpi=160) plt.close(fig) def _write_summary(path: str, summary: Optional[Dict]): if not summary: return os.makedirs(os.path.dirname(path), exist_ok=True) with open(path, "w", encoding="utf-8") as f: json.dump(summary, f, ensure_ascii=False, indent=2) def _save_episode_bundle(bundle_dir: str, episode: int, step_rows: Sequence[Dict], edge_rows: Sequence[Dict], edge_ids: Sequence[str], summary: Optional[Dict]): step_csv_path = os.path.join(bundle_dir, "step_metrics.csv") edge_csv_path = os.path.join(bundle_dir, "edge_metrics.csv") heatmap_path = os.path.join(bundle_dir, "episode_heatmap.png") summary_path = os.path.join(bundle_dir, "summary.json") _write_csv(step_csv_path, step_rows, fieldnames=list(step_rows[0].keys()) if step_rows else ["episode", "step"]) _write_csv(edge_csv_path, edge_rows, fieldnames=list(edge_rows[0].keys()) if edge_rows else ["episode", "step", "edge_index"]) _plot_episode_heatmap(heatmap_path, edge_rows, edge_ids, title_prefix=f"Episode {episode}") _write_summary(summary_path, summary) def save_training_episode_artifacts( log_dir: str, episode: int, episode_metrics: Sequence[Dict], control_edges: Sequence[str], summary: Optional[Dict] = None, snapshot_interval: int = 50, ): if not episode_metrics: return step_rows, edge_rows, edge_ids = _normalize_step_rows(episode, episode_metrics) if control_edges and len(control_edges) >= len(edge_ids): edge_ids = list(control_edges[: len(edge_ids)]) for row in edge_rows: row["edge_id"] = edge_ids[int(row["edge_index"])] artifacts_root = os.path.join(log_dir, "episode_artifacts") latest_dir = os.path.join(artifacts_root, "latest") _save_episode_bundle(latest_dir, episode, step_rows, edge_rows, edge_ids, summary) if snapshot_interval > 0 and episode % snapshot_interval == 0: snapshot_dir = os.path.join(artifacts_root, "snapshots", f"episode_{episode:04d}") _save_episode_bundle(snapshot_dir, episode, step_rows, edge_rows, edge_ids, summary)