176 lines
7.0 KiB
Python
176 lines
7.0 KiB
Python
"""Utilities for persisting per-episode training artifacts."""
|
|
import csv
|
|
import json
|
|
import os
|
|
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
|
|
|
|
import matplotlib
|
|
import numpy as np
|
|
|
|
matplotlib.use("Agg")
|
|
import matplotlib.pyplot as plt
|
|
|
|
from envs.reward_system import REWARD_COMPONENT_COLUMNS
|
|
|
|
|
|
def _safe_float(value, default: float = float("nan")) -> float:
    """Coerce *value* to ``float``.

    Returns *default* (NaN unless overridden) when ``float(value)`` raises
    ``TypeError`` (e.g. ``None``) or ``ValueError`` (e.g. a non-numeric string).
    """
    try:
        converted = float(value)
    except (TypeError, ValueError):
        converted = default
    return converted
|
|
|
|
|
|
def _normalize_step_rows(episode: int, episode_metrics: Sequence[Dict]) -> Tuple[List[Dict], List[Dict], List[str]]:
    """Flatten per-step info dicts into step-level and edge-level rows.

    Args:
        episode: Episode number stamped on every emitted row.
        episode_metrics: One info dict per decision step. The scalar keys read
            are the ones listed in ``step_row`` below, plus every column in
            ``REWARD_COMPONENT_COLUMNS``. The optional per-edge lists read are
            ``edge_speeds_kmh``, ``edge_speeds_ms`` and ``edge_occupancies``.

    Returns:
        ``(step_rows, edge_rows, edge_ids)``:

        - ``step_rows``: one dict per step.
        - ``edge_rows``: one dict per (step, edge) pair.
        - ``edge_ids``: placeholder names ``edge_00``, ``edge_01``, ...,
          covering the largest edge count reported by any step.
    """

    def _padded(values: List, idx: int):
        # The three per-edge lists may differ in length; missing entries become NaN.
        return values[idx] if idx < len(values) else np.nan

    step_rows: List[Dict] = []
    edge_rows: List[Dict] = []
    edge_ids: List[str] = []

    for step_idx, info in enumerate(episode_metrics, start=1):
        # Fall back to the 1-based position when the info dict has no "step".
        step_value = int(info.get("step", step_idx))
        step_row = {
            "episode": episode,
            "step": step_value,
            "sim_time": _safe_float(info.get("sim_time")),
            "reward": _safe_float(info.get("reward")),
            "throughput": _safe_float(info.get("throughput")),
            "arrived_count": int(info.get("arrived_count", 0)),
            "departed_count": int(info.get("departed_count", 0)),
            "mean_speed_kmh": _safe_float(info.get("mean_speed_kmh")),
            "speed_variance_norm": _safe_float(info.get("speed_variance_norm")),
            "mean_occupancy": _safe_float(info.get("mean_occupancy")),
            "density": _safe_float(info.get("density")),
            "num_vehicles": int(info.get("num_vehicles", 0)),
            "num_hard_brakes": int(info.get("num_hard_brakes", 0)),
        }
        for column in REWARD_COMPONENT_COLUMNS:
            step_row[column] = _safe_float(info.get(column))
        step_rows.append(step_row)

        action_speeds = list(info.get("edge_speeds_kmh", []))
        measured_speeds_ms = list(info.get("edge_speeds_ms", []))
        occupancies = list(info.get("edge_occupancies", []))
        edge_count = max(len(action_speeds), len(measured_speeds_ms), len(occupancies))

        # Grow the placeholder names whenever a step reports more edges than
        # any earlier step. Sizing them from the first step alone made
        # edge_ids[edge_idx] raise IndexError for such later steps.
        if edge_count > len(edge_ids):
            edge_ids.extend(f"edge_{idx:02d}" for idx in range(len(edge_ids), edge_count))

        for edge_idx in range(edge_count):
            edge_rows.append(
                {
                    "episode": episode,
                    "step": step_value,
                    "edge_index": edge_idx,
                    "edge_id": edge_ids[edge_idx],
                    "action_speed_kmh": _safe_float(_padded(action_speeds, edge_idx)),
                    # Coerce before the m/s -> km/h conversion so that a None or
                    # non-numeric entry becomes NaN instead of raising TypeError.
                    "measured_speed_kmh": _safe_float(_padded(measured_speeds_ms, edge_idx)) * 3.6,
                    "occupancy": _safe_float(_padded(occupancies, edge_idx)),
                }
            )

    return step_rows, edge_rows, edge_ids
|
|
|
|
|
|
def _write_csv(path: str, rows: Iterable[Dict], fieldnames: Sequence[str]):
    """Write *rows* to *path* as a UTF-8 CSV file with a header line.

    The parent directory is created when needed. An existing file at *path*
    is overwritten.

    Args:
        path: Destination file path. May be a bare filename.
        rows: Dicts keyed by column name. ``csv.DictWriter`` writes an empty
            cell for a missing key and raises ``ValueError`` for a key that
            is not in *fieldnames*.
        fieldnames: Column order for the header and for every row.
    """
    parent_dir = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so skip it for bare filenames.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    # newline="" lets the csv module control line endings itself.
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
|
|
|
|
|
|
def _plot_episode_heatmap(path: str, edge_rows: Sequence[Dict], edge_ids: Sequence[str], title_prefix: str):
    """Save a three-panel (edge x step) heatmap of one episode's edge metrics.

    The panels show ``action_speed_kmh``, ``measured_speed_kmh`` and
    ``occupancy`` from *edge_rows*. Image rows follow the order of *edge_ids*;
    because of ``origin="lower"`` the first id is drawn at the bottom. Image
    columns are the sorted distinct ``step`` values. Cells with no data stay
    NaN and are drawn masked.

    Args:
        path: Output image path. This function does not create its directory.
        edge_rows: Dicts with ``step``, ``edge_id`` and the three metric keys.
            An ``edge_id`` missing from *edge_ids* raises ``KeyError``.
        edge_ids: Edge labels, also used as y-axis tick labels.
        title_prefix: Text prepended to each panel title.

    Does nothing when *edge_rows* or *edge_ids* is empty.
    """
    if not edge_rows or not edge_ids:
        return

    step_values = sorted({int(row["step"]) for row in edge_rows})
    num_edges = len(edge_ids)
    num_steps = len(step_values)
    step_to_col = {step: idx for idx, step in enumerate(step_values)}
    edge_to_row = {edge_id: idx for idx, edge_id in enumerate(edge_ids)}

    # NaN marks (edge, step) cells for which no row was provided.
    action_grid = np.full((num_edges, num_steps), np.nan, dtype=np.float32)
    speed_grid = np.full((num_edges, num_steps), np.nan, dtype=np.float32)
    occupancy_grid = np.full((num_edges, num_steps), np.nan, dtype=np.float32)

    for row in edge_rows:
        row_idx = edge_to_row[row["edge_id"]]
        col_idx = step_to_col[int(row["step"])]
        action_grid[row_idx, col_idx] = _safe_float(row["action_speed_kmh"])
        speed_grid[row_idx, col_idx] = _safe_float(row["measured_speed_kmh"])
        occupancy_grid[row_idx, col_idx] = _safe_float(row["occupancy"])

    plots = [
        (action_grid, "viridis", "Applied VSL (km/h)"),
        (speed_grid, "RdYlGn", "Measured Speed (km/h)"),
        (occupancy_grid, "magma", "Occupancy (%)"),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(18, 7), sharex=True, sharey=True)
    try:
        for ax, (grid, cmap, title) in zip(axes, plots):
            image = ax.imshow(
                np.ma.masked_invalid(grid),
                aspect="auto",
                origin="lower",
                cmap=cmap,
                interpolation="nearest",
                resample=False,
            )
            ax.set_title(f"{title_prefix} {title}")
            ax.set_xlabel("Decision Step")
            ax.set_ylabel("Controlled Edge")
            ax.set_yticks(np.arange(num_edges))
            ax.set_yticklabels(edge_ids)
            fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)

        # Operate on this figure explicitly rather than pyplot's "current" figure.
        fig.tight_layout()
        fig.savefig(path, dpi=160)
    finally:
        # Release the figure even if drawing or saving fails, so that repeated
        # per-episode calls do not accumulate open figures.
        plt.close(fig)
|
|
|
|
|
|
def _write_summary(path: str, summary: Optional[Dict]):
    """Dump *summary* to *path* as indented UTF-8 JSON.

    Nothing is written when *summary* is ``None`` or empty. The parent
    directory is created when needed. Non-ASCII text is written as-is
    (``ensure_ascii=False``) rather than as ``\\u`` escapes.

    Args:
        path: Destination file path. May be a bare filename.
        summary: JSON-serializable mapping. ``json.dump`` raises ``TypeError``
            for values it cannot serialize.
    """
    if not summary:
        return
    parent_dir = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so skip it for bare filenames.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
|
|
|
|
|
|
def _save_episode_bundle(bundle_dir: str, episode: int, step_rows: Sequence[Dict], edge_rows: Sequence[Dict], edge_ids: Sequence[str], summary: Optional[Dict]):
    """Write one episode's artifacts into *bundle_dir*.

    Files produced:
      - ``step_metrics.csv``: always written; header only if *step_rows* is empty.
      - ``edge_metrics.csv``: always written; header only if *edge_rows* is empty.
      - ``episode_heatmap.png``: written only when there are edge rows and ids.
      - ``summary.json``: written only when *summary* is non-empty.
    """
    step_csv_path = os.path.join(bundle_dir, "step_metrics.csv")
    edge_csv_path = os.path.join(bundle_dir, "edge_metrics.csv")
    heatmap_path = os.path.join(bundle_dir, "episode_heatmap.png")
    summary_path = os.path.join(bundle_dir, "summary.json")

    # A directory such as "latest" is reused across episodes, and the heatmap
    # and summary are only conditionally regenerated. Remove copies left by an
    # earlier episode so the bundle never mixes files from different episodes.
    for stale_path in (heatmap_path, summary_path):
        try:
            os.remove(stale_path)
        except FileNotFoundError:
            pass  # Nothing stale to remove.

    # Column order comes from the first row; fixed fallbacks keep a header
    # line in the file when there are no rows.
    step_fields = list(step_rows[0].keys()) if step_rows else ["episode", "step"]
    edge_fields = list(edge_rows[0].keys()) if edge_rows else ["episode", "step", "edge_index"]

    # Write the CSVs first: the CSV writer creates bundle_dir, and the heatmap
    # plotter saves into it without creating it.
    _write_csv(step_csv_path, step_rows, fieldnames=step_fields)
    _write_csv(edge_csv_path, edge_rows, fieldnames=edge_fields)
    _plot_episode_heatmap(heatmap_path, edge_rows, edge_ids, title_prefix=f"Episode {episode}")
    _write_summary(summary_path, summary)
|
|
|
|
|
|
def save_training_episode_artifacts(
    log_dir: str,
    episode: int,
    episode_metrics: Sequence[Dict],
    control_edges: Sequence[str],
    summary: Optional[Dict] = None,
    snapshot_interval: int = 50,
):
    """Persist one training episode's metrics, heatmap and summary under *log_dir*.

    Every call writes a bundle to ``<log_dir>/episode_artifacts/latest``. When
    *snapshot_interval* is positive and *episode* is a multiple of it, the same
    bundle is also written to
    ``<log_dir>/episode_artifacts/snapshots/episode_<NNNN>``.

    Args:
        log_dir: Root logging directory.
        episode: Episode number, used in row data, titles and the snapshot name.
        episode_metrics: One info dict per decision step. Nothing is written
            when this is empty.
        control_edges: Real edge names. They replace the generated
            ``edge_NN`` placeholders only when at least as many names are
            given as edges were observed.
        summary: Optional mapping saved as ``summary.json``.
        snapshot_interval: Episode period for extra snapshot bundles;
            ``<= 0`` disables snapshots.
    """
    if not episode_metrics:
        return

    step_rows, edge_rows, edge_ids = _normalize_step_rows(episode, episode_metrics)

    # Only relabel when every observed edge can receive a real name.
    can_label_all_edges = bool(control_edges) and len(control_edges) >= len(edge_ids)
    if can_label_all_edges:
        edge_ids = list(control_edges[: len(edge_ids)])
        for edge_row in edge_rows:
            edge_row["edge_id"] = edge_ids[int(edge_row["edge_index"])]

    artifacts_root = os.path.join(log_dir, "episode_artifacts")
    bundle_payload = (step_rows, edge_rows, edge_ids, summary)

    _save_episode_bundle(os.path.join(artifacts_root, "latest"), episode, *bundle_payload)

    is_snapshot_episode = snapshot_interval > 0 and episode % snapshot_interval == 0
    if is_snapshot_episode:
        snapshot_dir = os.path.join(artifacts_root, "snapshots", f"episode_{episode:04d}")
        _save_episode_bundle(snapshot_dir, episode, *bundle_payload)
|