# Source: ctm-dqn/utils/episode_artifacts.py (Python, 176 lines, 7.0 KiB)

"""Utilities for persisting per-episode training artifacts."""
import csv
import json
import os
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import matplotlib
import numpy as np
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from envs.reward_system import REWARD_COMPONENT_COLUMNS
def _safe_float(value, default: float = float("nan")) -> float:
    """Convert ``value`` to ``float``, falling back to ``default`` on failure.

    Args:
        value: Object to convert; anything ``float()`` accepts.
        default: Value returned when the conversion fails (NaN by default).

    Returns:
        ``float(value)``, or ``default`` when ``float()`` raises ``TypeError``,
        ``ValueError`` or ``OverflowError``.
    """
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: float() rejects integers too large for a double
        # (e.g. 10**400) instead of returning inf.
        return default
def _normalize_step_rows(episode: int, episode_metrics: Sequence[Dict]) -> Tuple[List[Dict], List[Dict], List[str]]:
    """Flatten per-step info dicts into step-level and edge-level rows.

    Args:
        episode: Episode number copied into every produced row.
        episode_metrics: One info dict per decision step. Each dict is read
            with ``.get``; the per-edge sequences are taken from the keys
            ``edge_speeds_kmh``, ``edge_speeds_ms`` and ``edge_occupancies``.

    Returns:
        A ``(step_rows, edge_rows, edge_ids)`` tuple:

        * ``step_rows`` - one dict per step with the scalar metrics plus every
          column named in ``REWARD_COMPONENT_COLUMNS``.
        * ``edge_rows`` - one dict per (step, edge) pair. Values missing from a
          shorter per-edge sequence are stored as NaN.
        * ``edge_ids`` - generated placeholder ids (``edge_00``, ``edge_01``,
          ...) covering the largest edge count seen in any step.

    Raises:
        TypeError, ValueError: If ``step`` or one of the integer counters holds
            a value that ``int()`` cannot convert.
    """
    ms_to_kmh = 3.6
    nan = float("nan")

    def as_list(values) -> list:
        # A key that is present but set to None is treated like a missing key.
        return [] if values is None else list(values)

    def value_at(values: list, index: int) -> float:
        return _safe_float(values[index]) if index < len(values) else nan

    step_rows: List[Dict] = []
    edge_rows: List[Dict] = []
    edge_ids: List[str] = []

    for step_idx, info in enumerate(episode_metrics, start=1):
        step_value = int(info.get("step", step_idx))
        step_row = {
            "episode": episode,
            "step": step_value,
            "sim_time": _safe_float(info.get("sim_time")),
            "reward": _safe_float(info.get("reward")),
            "throughput": _safe_float(info.get("throughput")),
            "arrived_count": int(info.get("arrived_count", 0)),
            "departed_count": int(info.get("departed_count", 0)),
            "mean_speed_kmh": _safe_float(info.get("mean_speed_kmh")),
            "speed_variance_norm": _safe_float(info.get("speed_variance_norm")),
            "mean_occupancy": _safe_float(info.get("mean_occupancy")),
            "density": _safe_float(info.get("density")),
            "num_vehicles": int(info.get("num_vehicles", 0)),
            "num_hard_brakes": int(info.get("num_hard_brakes", 0)),
        }
        for column in REWARD_COMPONENT_COLUMNS:
            step_row[column] = _safe_float(info.get(column))
        step_rows.append(step_row)

        action_speeds = as_list(info.get("edge_speeds_kmh"))
        measured_speeds_ms = as_list(info.get("edge_speeds_ms"))
        occupancies = as_list(info.get("edge_occupancies"))
        edge_count = max(len(action_speeds), len(measured_speeds_ms), len(occupancies))

        # Grow the id list whenever a step reports more edges than any earlier
        # step, so the lookup below can never run past its end.
        if edge_count > len(edge_ids):
            edge_ids.extend(f"edge_{idx:02d}" for idx in range(len(edge_ids), edge_count))

        for edge_idx in range(edge_count):
            edge_rows.append(
                {
                    "episode": episode,
                    "step": step_value,
                    "edge_index": edge_idx,
                    "edge_id": edge_ids[edge_idx],
                    "action_speed_kmh": value_at(action_speeds, edge_idx),
                    # Convert to float first so a None or non-numeric entry
                    # becomes NaN instead of failing on the multiplication.
                    "measured_speed_kmh": value_at(measured_speeds_ms, edge_idx) * ms_to_kmh,
                    "occupancy": value_at(occupancies, edge_idx),
                }
            )

    return step_rows, edge_rows, edge_ids
def _write_csv(path: str, rows: Iterable[Dict], fieldnames: Sequence[str]):
    """Write ``rows`` to ``path`` as a UTF-8 CSV file with a header line.

    Args:
        path: Destination file. Missing parent directories are created; a bare
            file name is written to the current working directory.
        rows: Dicts to write, one per CSV record.
        fieldnames: Column names, in output order, used for the header.

    Raises:
        ValueError: From ``csv.DictWriter`` when a row contains a key that is
            not listed in ``fieldnames``.
    """
    parent_dir = os.path.dirname(path)
    # dirname() is "" for a bare file name, and os.makedirs("") raises.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
def _plot_episode_heatmap(path: str, edge_rows: Sequence[Dict], edge_ids: Sequence[str], title_prefix: str):
    """Render a three-panel edge-by-step heatmap and save it to ``path``.

    The panels show, left to right, the ``action_speed_kmh``,
    ``measured_speed_kmh`` and ``occupancy`` values of ``edge_rows``. Each
    panel has one row per entry of ``edge_ids`` and one column per distinct
    ``step`` value; cells without data stay masked.

    Args:
        path: Output image file. Missing parent directories are created.
        edge_rows: Per-(step, edge) dicts carrying ``step``, ``edge_id`` and
            the three plotted values.
        edge_ids: Edge identifiers in y-axis order (bottom to top).
        title_prefix: Text prepended to every panel title.

    Returns:
        ``None``. Nothing is written when ``edge_rows`` or ``edge_ids`` is
        empty.

    Raises:
        KeyError: If a row's ``edge_id`` is not listed in ``edge_ids``.
    """
    if not edge_rows or not edge_ids:
        return

    step_values = sorted({int(row["step"]) for row in edge_rows})
    num_edges = len(edge_ids)
    num_steps = len(step_values)
    step_to_col = {step: idx for idx, step in enumerate(step_values)}
    edge_to_row = {edge_id: idx for idx, edge_id in enumerate(edge_ids)}

    # NaN-filled grids: (step, edge) pairs with no row stay NaN and are masked
    # out when drawn.
    action_grid = np.full((num_edges, num_steps), np.nan, dtype=np.float32)
    speed_grid = np.full((num_edges, num_steps), np.nan, dtype=np.float32)
    occupancy_grid = np.full((num_edges, num_steps), np.nan, dtype=np.float32)
    for row in edge_rows:
        row_idx = edge_to_row[row["edge_id"]]
        col_idx = step_to_col[int(row["step"])]
        action_grid[row_idx, col_idx] = _safe_float(row["action_speed_kmh"])
        speed_grid[row_idx, col_idx] = _safe_float(row["measured_speed_kmh"])
        occupancy_grid[row_idx, col_idx] = _safe_float(row["occupancy"])

    plots = [
        (action_grid, "viridis", "Applied VSL (km/h)"),
        (speed_grid, "RdYlGn", "Measured Speed (km/h)"),
        (occupancy_grid, "magma", "Occupancy (%)"),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(18, 7), sharex=True, sharey=True)
    try:
        for ax, (grid, cmap, title) in zip(axes, plots):
            image = ax.imshow(
                np.ma.masked_invalid(grid),
                aspect="auto",
                origin="lower",
                cmap=cmap,
                interpolation="nearest",
                resample=False,
            )
            ax.set_title(f"{title_prefix} {title}")
            ax.set_xlabel("Decision Step")
            ax.set_ylabel("Controlled Edge")
            ax.set_yticks(np.arange(num_edges))
            ax.set_yticklabels(edge_ids)
            fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)

        # Operate on the figure handle rather than pyplot's implicit
        # "current figure" so the calls cannot target another figure.
        fig.tight_layout()

        parent_dir = os.path.dirname(path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        fig.savefig(path, dpi=160)
    finally:
        # pyplot keeps every open figure alive; always release this one, even
        # when drawing or saving fails.
        plt.close(fig)
def _write_summary(path: str, summary: Optional[Dict]):
    """Write ``summary`` to ``path`` as indented UTF-8 JSON.

    Args:
        path: Destination file. Missing parent directories are created; a bare
            file name is written to the current working directory.
        summary: Mapping to serialize. ``None`` or an empty mapping writes
            nothing and leaves any existing file untouched.

    Raises:
        TypeError: From ``json.dump`` when ``summary`` contains a value the
            default JSON encoder cannot serialize.
    """
    if not summary:
        return
    parent_dir = os.path.dirname(path)
    # dirname() is "" for a bare file name, and os.makedirs("") raises.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
def _save_episode_bundle(bundle_dir: str, episode: int, step_rows: Sequence[Dict], edge_rows: Sequence[Dict], edge_ids: Sequence[str], summary: Optional[Dict]):
    """Write one episode's full artifact set into ``bundle_dir``.

    Produces ``step_metrics.csv``, ``edge_metrics.csv``,
    ``episode_heatmap.png`` and ``summary.json`` by delegating to the
    module's writer helpers.

    Args:
        bundle_dir: Directory that receives the files.
        episode: Episode number, used in the heatmap titles.
        step_rows: Step-level rows; the first row's keys define the CSV header.
        edge_rows: Edge-level rows; the first row's keys define the CSV header.
        edge_ids: Edge identifiers passed through to the heatmap.
        summary: Optional mapping passed through to the summary writer.
    """
    # With no rows there is no first row to take the header from, so a minimal
    # fixed header is used instead.
    if step_rows:
        step_fields = list(step_rows[0].keys())
    else:
        step_fields = ["episode", "step"]

    if edge_rows:
        edge_fields = list(edge_rows[0].keys())
    else:
        edge_fields = ["episode", "step", "edge_index"]

    _write_csv(os.path.join(bundle_dir, "step_metrics.csv"), step_rows, fieldnames=step_fields)
    _write_csv(os.path.join(bundle_dir, "edge_metrics.csv"), edge_rows, fieldnames=edge_fields)
    _plot_episode_heatmap(
        os.path.join(bundle_dir, "episode_heatmap.png"),
        edge_rows,
        edge_ids,
        title_prefix=f"Episode {episode}",
    )
    _write_summary(os.path.join(bundle_dir, "summary.json"), summary)
def save_training_episode_artifacts(
    log_dir: str,
    episode: int,
    episode_metrics: Sequence[Dict],
    control_edges: Sequence[str],
    summary: Optional[Dict] = None,
    snapshot_interval: int = 50,
):
    """Persist the artifacts of one training episode under ``log_dir``.

    The bundle is always written to ``<log_dir>/episode_artifacts/latest``
    (overwriting the previous episode's files) and is additionally written to
    ``<log_dir>/episode_artifacts/snapshots/episode_NNNN`` when ``episode`` is
    a multiple of ``snapshot_interval``.

    Args:
        log_dir: Root directory of the training run's logs.
        episode: Episode number.
        episode_metrics: Per-step info dicts; nothing is written when empty.
        control_edges: Edge names used to replace the generated placeholder
            ids, provided there are at least as many names as observed edges.
        summary: Optional mapping stored alongside the CSVs.
        snapshot_interval: Snapshot period in episodes; values ``<= 0``
            disable snapshots.
    """
    if not episode_metrics:
        return

    step_rows, edge_rows, edge_ids = _normalize_step_rows(episode, episode_metrics)

    # Swap the generated placeholder ids for the caller's edge names, but only
    # when every observed edge index can be mapped to a name.
    if control_edges and len(control_edges) >= len(edge_ids):
        observed_edges = len(edge_ids)
        edge_ids = list(control_edges[:observed_edges])
        for edge_row in edge_rows:
            edge_row["edge_id"] = edge_ids[int(edge_row["edge_index"])]

    artifacts_root = os.path.join(log_dir, "episode_artifacts")
    _save_episode_bundle(
        os.path.join(artifacts_root, "latest"),
        episode,
        step_rows,
        edge_rows,
        edge_ids,
        summary,
    )

    is_snapshot_episode = snapshot_interval > 0 and episode % snapshot_interval == 0
    if not is_snapshot_episode:
        return

    snapshot_dir = os.path.join(artifacts_root, "snapshots", f"episode_{episode:04d}")
    _save_episode_bundle(snapshot_dir, episode, step_rows, edge_rows, edge_ids, summary)