ctm-dqn/utils/episode_artifacts.py

177 lines
7.1 KiB
Python

"""Utilities for persisting per-episode training artifacts."""
import csv
import json
import os
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import matplotlib
import numpy as np
matplotlib.use("Agg")
import matplotlib.pyplot as plt
_MS_TO_KMH = 3.6  # multiply a speed in m/s by this to get km/h


def _safe_float(value, default: float = float("nan")) -> float:
    """Convert ``value`` to ``float``, returning ``default`` (NaN) on failure.

    ``None`` and other non-numeric objects raise ``TypeError`` inside
    ``float()``; unparsable strings raise ``ValueError``. Both map to
    ``default``.
    """
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _safe_int(value, default: int = 0) -> int:
    """Convert ``value`` to ``int``, returning ``default`` on failure.

    Catches the same errors as ``_safe_float`` plus ``OverflowError``
    (raised e.g. by ``int(float("inf"))``).
    """
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return default


def _as_list(value) -> list:
    """Return ``value`` as a list, treating ``None`` as an empty sequence.

    An explicit ``None`` check is used instead of ``value or []`` because the
    latter raises for multi-element NumPy arrays.
    """
    return [] if value is None else list(value)


def _item_or_none(values: Sequence, idx: int):
    """Return ``values[idx]``, or ``None`` when ``idx`` is past the end."""
    return values[idx] if idx < len(values) else None


def _normalize_step_rows(
    episode: int, episode_metrics: Sequence[Dict]
) -> Tuple[List[Dict], List[Dict], List[str]]:
    """Flatten per-step info dicts into step-level and edge-level rows.

    Args:
        episode: Episode number copied into every produced row.
        episode_metrics: One info dict per decision step. Missing or
            non-numeric values become NaN in float columns and 0 in count
            columns. ``speed_std`` and ``edge_speeds_ms`` are scaled by 3.6
            into the ``*_kmh`` columns.

    Returns:
        ``(step_rows, edge_rows, edge_ids)``. ``edge_ids`` holds one generic
        ``edge_NN`` label per edge index reported by any step.
    """
    step_rows: List[Dict] = []
    edge_rows: List[Dict] = []
    edge_ids: List[str] = []
    for step_idx, info in enumerate(episode_metrics, start=1):
        # Fall back to the 1-based position when no usable "step" is reported.
        step_value = _safe_int(info.get("step"), step_idx)
        step_rows.append(
            {
                "episode": episode,
                "step": step_value,
                "sim_time": _safe_float(info.get("sim_time")),
                "reward": _safe_float(info.get("reward")),
                "throughput": _safe_float(info.get("throughput")),
                "arrived_count": _safe_int(info.get("arrived_count", 0)),
                "departed_count": _safe_int(info.get("departed_count", 0)),
                "mean_speed_kmh": _safe_float(info.get("mean_speed_kmh")),
                "speed_std_kmh": _safe_float(info.get("speed_std")) * _MS_TO_KMH,
                "mean_occupancy": _safe_float(info.get("mean_occupancy")),
                "density": _safe_float(info.get("density")),
                "num_vehicles": _safe_int(info.get("num_vehicles", 0)),
                "num_hard_brakes": _safe_int(info.get("num_hard_brakes", 0)),
                "r_flow": _safe_float(info.get("r_flow")),
                "r_var": _safe_float(info.get("r_var")),
                "r_brake": _safe_float(info.get("r_brake")),
                "r_penalty": _safe_float(info.get("r_penalty")),
            }
        )

        action_speeds = _as_list(info.get("edge_speeds_kmh"))
        measured_speeds_ms = _as_list(info.get("edge_speeds_ms"))
        occupancies = _as_list(info.get("edge_occupancies"))
        edge_count = max(len(action_speeds), len(measured_speeds_ms), len(occupancies))
        # Grow the label list whenever a step reports more edges than any
        # earlier step; a list fixed at the first step would raise IndexError.
        edge_ids.extend(f"edge_{idx:02d}" for idx in range(len(edge_ids), edge_count))

        for edge_idx in range(edge_count):
            edge_rows.append(
                {
                    "episode": episode,
                    "step": step_value,
                    "edge_index": edge_idx,
                    "edge_id": edge_ids[edge_idx],
                    "action_speed_kmh": _safe_float(_item_or_none(action_speeds, edge_idx)),
                    # Convert before scaling so None/strings become NaN instead of raising.
                    "measured_speed_kmh": _safe_float(_item_or_none(measured_speeds_ms, edge_idx))
                    * _MS_TO_KMH,
                    "occupancy": _safe_float(_item_or_none(occupancies, edge_idx)),
                }
            )
    return step_rows, edge_rows, edge_ids
def _write_csv(path: str, rows: Iterable[Dict], fieldnames: Sequence[str]):
    """Write ``rows`` to ``path`` as a UTF-8 CSV file with a header line.

    The parent directory is created if needed. ``csv.DictWriter`` raises
    ``ValueError`` for a row containing a key that is not in ``fieldnames``.
    Keys missing from a row are written as empty cells.
    """
    directory = os.path.dirname(path)
    # A bare filename has no directory component, and os.makedirs("") raises.
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
def _plot_episode_heatmap(path: str, edge_rows: Sequence[Dict], edge_ids: Sequence[str], title_prefix: str):
    """Render applied VSL, measured speed and occupancy as edge-by-step heatmaps.

    Writes one image containing three side-by-side panels. In each panel,
    rows follow ``edge_ids`` order and columns are the distinct ``step``
    values in ascending order. Cells without data are masked. Nothing is
    drawn when ``edge_rows`` or ``edge_ids`` is empty.

    Args:
        path: Output image path; its parent directory is created if needed.
        edge_rows: Rows with ``step``, ``edge_index`` (or ``edge_id``),
            ``action_speed_kmh``, ``measured_speed_kmh`` and ``occupancy``.
        edge_ids: Y-axis tick labels, one per grid row.
        title_prefix: Text prepended to every panel title.
    """
    if not edge_rows or not edge_ids:
        return

    step_values = sorted({int(row["step"]) for row in edge_rows})
    step_to_col = {step: col for col, step in enumerate(step_values)}
    id_to_row = {edge_id: idx for idx, edge_id in enumerate(edge_ids)}
    num_edges = len(edge_ids)
    grid_shape = (num_edges, len(step_values))
    action_grid = np.full(grid_shape, np.nan, dtype=np.float32)
    speed_grid = np.full(grid_shape, np.nan, dtype=np.float32)
    occupancy_grid = np.full(grid_shape, np.nan, dtype=np.float32)
    for row in edge_rows:
        # Prefer the positional index: a label lookup would send two edges
        # that share a label into the same grid row.
        if "edge_index" in row:
            row_idx = int(row["edge_index"])
        else:
            row_idx = id_to_row[row["edge_id"]]
        col_idx = step_to_col[int(row["step"])]
        action_grid[row_idx, col_idx] = _safe_float(row["action_speed_kmh"])
        speed_grid[row_idx, col_idx] = _safe_float(row["measured_speed_kmh"])
        occupancy_grid[row_idx, col_idx] = _safe_float(row["occupancy"])

    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    panels = [
        (action_grid, "viridis", "Applied VSL (km/h)"),
        (speed_grid, "RdYlGn", "Measured Speed (km/h)"),
        (occupancy_grid, "magma", "Occupancy (%)"),
    ]
    fig, axes = plt.subplots(1, 3, figsize=(18, 7), sharex=True, sharey=True)
    try:
        for ax, (grid, cmap, title) in zip(axes, panels):
            image = ax.imshow(
                np.ma.masked_invalid(grid),
                aspect="auto",
                origin="lower",
                cmap=cmap,
                interpolation="nearest",
                resample=False,
            )
            ax.set_title(f"{title_prefix} {title}")
            ax.set_xlabel("Decision Step")
            ax.set_ylabel("Controlled Edge")
            ax.set_yticks(np.arange(num_edges))
            ax.set_yticklabels(edge_ids)
            fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)
        fig.tight_layout()
        fig.savefig(path, dpi=160)
    finally:
        # Always release the figure; pyplot keeps every open figure alive,
        # so a failure here repeated every episode would otherwise leak memory.
        plt.close(fig)
def _write_summary(path: str, summary: Optional[Dict]):
    """Dump ``summary`` to ``path`` as pretty-printed UTF-8 JSON.

    Nothing is written when ``summary`` is ``None`` or empty. NumPy scalars
    and arrays are converted to plain Python values so that NumPy-computed
    metrics serialize instead of raising ``TypeError``. The JSON text is
    built before the file is opened, so a serialization error never leaves a
    truncated summary on disk.
    """
    if not summary:
        return

    def _to_builtin(value):
        # json calls this hook only for objects it cannot encode natively.
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    text = json.dumps(summary, ensure_ascii=False, indent=2, default=_to_builtin)
    directory = os.path.dirname(path)
    # A bare filename has no directory component, and os.makedirs("") raises.
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)
def _collect_fieldnames(rows: Sequence[Dict], fallback: Sequence[str]) -> List[str]:
    """Return the union of keys across ``rows`` in first-seen order, or ``fallback`` if empty."""
    names = list(dict.fromkeys(key for row in rows for key in row))
    return names if names else list(fallback)


def _save_episode_bundle(bundle_dir: str, episode: int, step_rows: Sequence[Dict], edge_rows: Sequence[Dict], edge_ids: Sequence[str], summary: Optional[Dict]):
    """Write every artifact for one episode into ``bundle_dir``.

    Files written: ``step_metrics.csv``, ``edge_metrics.csv``,
    ``episode_heatmap.png`` and ``summary.json``. The heatmap and summary
    helpers skip their file when given no data. CSV columns are the union of
    keys over all rows, so a later row with an extra key does not make
    ``csv.DictWriter`` raise ``ValueError``. For homogeneous rows this matches
    the first row's key order.
    """
    step_csv_path = os.path.join(bundle_dir, "step_metrics.csv")
    edge_csv_path = os.path.join(bundle_dir, "edge_metrics.csv")
    heatmap_path = os.path.join(bundle_dir, "episode_heatmap.png")
    summary_path = os.path.join(bundle_dir, "summary.json")
    _write_csv(step_csv_path, step_rows, fieldnames=_collect_fieldnames(step_rows, ["episode", "step"]))
    _write_csv(
        edge_csv_path,
        edge_rows,
        fieldnames=_collect_fieldnames(edge_rows, ["episode", "step", "edge_index"]),
    )
    _plot_episode_heatmap(heatmap_path, edge_rows, edge_ids, title_prefix=f"Episode {episode}")
    _write_summary(summary_path, summary)
def save_training_episode_artifacts(
    log_dir: str,
    episode: int,
    episode_metrics: Sequence[Dict],
    control_edges: Sequence[str],
    summary: Optional[Dict] = None,
    snapshot_interval: int = 50,
):
    """Persist the metrics, heatmap and optional summary of one training episode.

    The bundle always goes to ``<log_dir>/episode_artifacts/latest``, which is
    overwritten on each call. When ``snapshot_interval`` is positive and
    ``episode`` is divisible by it, the same bundle is also written to
    ``<log_dir>/episode_artifacts/snapshots/episode_NNNN``.

    Args:
        log_dir: Root directory for the artifacts.
        episode: Episode number used in rows, titles and snapshot names.
        episode_metrics: One info dict per decision step. Nothing is written
            when this is empty.
        control_edges: Edge names replacing the generic ``edge_NN`` labels.
            They are used only when there is a name for every measured edge.
        summary: Optional dict saved as ``summary.json``.
        snapshot_interval: Episode period for snapshots; ``<= 0`` disables them.
    """
    if not episode_metrics:
        return

    step_rows, edge_rows, edge_ids = _normalize_step_rows(episode, episode_metrics)

    if control_edges:
        enough_names = len(control_edges) >= len(edge_ids)
        if enough_names:
            edge_ids = list(control_edges[: len(edge_ids)])
            for edge_row in edge_rows:
                edge_row["edge_id"] = edge_ids[int(edge_row["edge_index"])]

    root = os.path.join(log_dir, "episode_artifacts")
    _save_episode_bundle(os.path.join(root, "latest"), episode, step_rows, edge_rows, edge_ids, summary)

    take_snapshot = snapshot_interval > 0 and episode % snapshot_interval == 0
    if take_snapshot:
        snapshot_name = f"episode_{episode:04d}"
        _save_episode_bundle(
            os.path.join(root, "snapshots", snapshot_name),
            episode,
            step_rows,
            edge_rows,
            edge_ids,
            summary,
        )