ctm-dqn/utils/episode_artifacts.py

424 lines
15 KiB
Python

"""Utilities for persisting per-episode training artifacts."""
import csv
import json
import os
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import xml.etree.ElementTree as ET
import matplotlib
import numpy as np
matplotlib.use("Agg")
from envs.reward_system import REWARD_COMPONENT_COLUMNS
from utils.heatmap_plotting import (
build_action_panel,
build_occupancy_panel,
build_speed_panel,
save_heatmap_panels,
)
def _safe_float(value, default: float = float("nan")) -> float:
try:
return float(value)
except (TypeError, ValueError):
return default
def _normalize_step_rows(episode: int, episode_metrics: Sequence[Dict]) -> Tuple[List[Dict], List[Dict], List[str]]:
    """Flatten per-step info mappings into step-level and edge-level rows.

    Args:
        episode: Episode number stamped on every produced row.
        episode_metrics: One info mapping per step, in step order. A missing
            ``"step"`` key falls back to the 1-based position in the sequence.

    Returns:
        A ``(step_rows, edge_rows, edge_ids)`` tuple. ``step_rows`` holds one
        dict per step (scalar metrics plus every ``REWARD_COMPONENT_COLUMNS``
        entry), ``edge_rows`` holds one dict per step and edge index, and
        ``edge_ids`` holds placeholder labels (``edge_00``, ``edge_01``, ...)
        for every edge index seen in any step.
    """

    def _as_list(values) -> list:
        # A key that is present but explicitly ``None`` is treated like a missing key.
        return [] if values is None else list(values)

    def _value_at(values: list, index: int):
        # Per-edge lists may be shorter than the widest one; pad with NaN.
        return values[index] if index < len(values) else np.nan

    step_rows: List[Dict] = []
    edge_rows: List[Dict] = []
    edge_ids: List[str] = []

    for step_idx, info in enumerate(episode_metrics, start=1):
        step_value = int(info.get("step", step_idx))
        step_row = {
            "episode": episode,
            "step": step_value,
            "sim_time": _safe_float(info.get("sim_time")),
            "reward": _safe_float(info.get("reward")),
            "throughput": _safe_float(info.get("throughput")),
            "arrived_count": int(info.get("arrived_count", 0)),
            "departed_count": int(info.get("departed_count", 0)),
            "mean_speed_kmh": _safe_float(info.get("mean_speed_kmh")),
            "speed_variance_norm": _safe_float(info.get("speed_variance_norm")),
            "mean_occupancy": _safe_float(info.get("mean_occupancy")),
            "density": _safe_float(info.get("density")),
            "num_vehicles": int(info.get("num_vehicles", 0)),
            "num_stops": _safe_float(info.get("num_stops", 0.0), default=0.0),
        }
        for column in REWARD_COMPONENT_COLUMNS:
            step_row[column] = _safe_float(info.get(column))
        step_rows.append(step_row)

        action_speeds = _as_list(info.get("edge_speeds_kmh"))
        measured_speeds_ms = _as_list(info.get("edge_speeds_ms"))
        occupancies = _as_list(info.get("edge_occupancies"))
        action_applied_mask = _as_list(info.get("action_applied_mask"))
        edge_count = max(
            len(action_speeds),
            len(measured_speeds_ms),
            len(occupancies),
            len(action_applied_mask),
        )

        # Grow the label list whenever a step reports more edges than any
        # earlier step, so the lookup below can never run past its end.
        if edge_count > len(edge_ids):
            edge_ids.extend(f"edge_{idx:02d}" for idx in range(len(edge_ids), edge_count))

        for edge_idx in range(edge_count):
            edge_rows.append(
                {
                    "episode": episode,
                    "step": step_value,
                    "edge_index": edge_idx,
                    "edge_id": edge_ids[edge_idx],
                    "action_speed_kmh": _safe_float(_value_at(action_speeds, edge_idx)),
                    # An edge without a mask entry is reported as applied.
                    "action_applied": (
                        bool(action_applied_mask[edge_idx])
                        if edge_idx < len(action_applied_mask)
                        else True
                    ),
                    # Sanitise first, then convert m/s -> km/h, so a None or
                    # non-numeric entry becomes NaN instead of raising.
                    "measured_speed_kmh": _safe_float(_value_at(measured_speeds_ms, edge_idx)) * 3.6,
                    "occupancy": _safe_float(_value_at(occupancies, edge_idx)),
                }
            )
    return step_rows, edge_rows, edge_ids
def _write_csv(path: str, rows: Iterable[Dict], fieldnames: Sequence[str]):
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(rows)
def _lane_to_edge_id(lane_id: str) -> str:
if "_" not in lane_id:
return lane_id
return lane_id.rsplit("_", 1)[0]
def _parse_detector_layout(additional_path: str) -> List[Dict]:
    """Read the detector layout from an additional-file XML.

    Each ``inductionLoop`` element directly under the root contributes its
    ``id`` to a cell. Consecutive elements that share the same edge (derived
    from the ``lane`` attribute) and the same ``pos`` rounded to three
    decimals belong to one cell.

    Args:
        additional_path: Path of the XML file describing the detectors.

    Returns:
        Cells in document order, each a dict with ``cell_id``, ``edge_id``,
        ``pos_index`` (running count per edge), ``position_m`` and
        ``detector_ids``. Empty when the file is missing or not parseable.
    """
    if not os.path.isfile(additional_path):
        return []
    try:
        tree = ET.parse(additional_path)
    except ET.ParseError:
        return []

    layout: List[Dict] = []
    next_index_by_edge: Dict[str, int] = {}
    previous_key = None
    for loop in tree.getroot().findall("inductionLoop"):
        detector_id = loop.get("id")
        lane = loop.get("lane")
        raw_position = loop.get("pos")
        if raw_position is None or not (detector_id and lane):
            continue

        edge = _lane_to_edge_id(lane)
        position = _safe_float(raw_position)
        group_key = (edge, round(position, 3))
        if group_key != previous_key:
            # A new (edge, position) pair starts a new cell on that edge.
            index_on_edge = next_index_by_edge.get(edge, 0)
            next_index_by_edge[edge] = index_on_edge + 1
            layout.append(
                {
                    "cell_id": f"{edge}@{index_on_edge}",
                    "edge_id": edge,
                    "pos_index": index_on_edge,
                    "position_m": position,
                    "detector_ids": [],
                }
            )
            previous_key = group_key
        # The most recently opened cell is always the one being filled.
        layout[-1]["detector_ids"].append(detector_id)
    return layout
def _aggregate_detector_cell(
    detector_ids: Sequence[str],
    interval_values: Dict[str, Dict[str, float]],
) -> Tuple[float, float, float]:
    """Combine the readings of one cell's detectors for a single interval.

    Args:
        detector_ids: Detector ids that make up the cell.
        interval_values: Readings for the interval, keyed by detector id.

    Returns:
        ``(measured_speed_kmh, occupancy, density)``. Speed is the weighted
        mean over detectors with a finite positive speed, occupancy is the
        plain mean over detectors with a finite occupancy, and density is the
        sum of ``flow / speed_kmh`` over the detectors that contributed a
        speed. Speed and occupancy are NaN when no detector contributed.
    """

    def _non_negative(value: float) -> float:
        # Negative or non-finite readings must not distort weights or sums.
        return value if np.isfinite(value) and value > 0.0 else 0.0

    speed_weight_sum = 0.0
    speed_weight_total = 0.0
    density_sum = 0.0
    occupancy_sum = 0.0
    occupancy_count = 0

    for det_id in detector_ids:
        lane_metrics = interval_values.get(det_id)
        if not lane_metrics:
            continue

        occupancy = max(lane_metrics["occupancy"], 0.0)
        if np.isfinite(occupancy):
            occupancy_sum += occupancy
            occupancy_count += 1

        speed_ms = lane_metrics["speed_ms"]
        if not (np.isfinite(speed_ms) and speed_ms > 0.0):
            # No usable speed sample (presumably SUMO's "no vehicle" sentinel
            # or a missing attribute - TODO confirm): skip speed and density
            # so the sums stay finite.
            continue

        flow_vehph = _non_negative(lane_metrics["flow_vehph"])
        n_veh_contrib = _non_negative(lane_metrics["n_veh_contrib"])
        speed_kmh = speed_ms * 3.6

        # Prefer the contributing vehicle count as weight, then the flow, and
        # fall back to an unweighted sample when neither is positive.
        speed_weight = n_veh_contrib if n_veh_contrib > 0.0 else flow_vehph
        if speed_weight <= 0.0:
            speed_weight = 1.0
        speed_weight_sum += speed_kmh * speed_weight
        speed_weight_total += speed_weight
        density_sum += flow_vehph / max(speed_kmh, 1e-6)

    measured_speed_kmh = speed_weight_sum / speed_weight_total if speed_weight_total > 0.0 else np.nan
    occupancy_mean = occupancy_sum / occupancy_count if occupancy_count > 0 else np.nan
    return measured_speed_kmh, occupancy_mean, density_sum


def _build_detector_rows_from_xml(log_dir: str, episode: int) -> List[Dict]:
    """Build per-interval, per-cell detector rows from an episode's XML output.

    Reads ``runtime_metrics_il_epNNNN.add.xml`` (detector layout) and
    ``metrics_il_output_epNNNN.xml`` (interval measurements) from
    ``<log_dir>/sumo_metrics``.

    Args:
        log_dir: Directory that contains the ``sumo_metrics`` folder.
        episode: Episode number used to build the ``epNNNN`` file suffix.

    Returns:
        One dict per (interval, cell). ``step`` is the 1-based rank of the
        interval after sorting by ``(begin, end)``. Empty when the layout is
        empty or either file is missing or not parseable.
    """
    metrics_dir = os.path.join(log_dir, "sumo_metrics")
    suffix = f"ep{episode:04d}"
    additional_path = os.path.join(metrics_dir, f"runtime_metrics_il_{suffix}.add.xml")
    metrics_path = os.path.join(metrics_dir, f"metrics_il_output_{suffix}.xml")

    cells = _parse_detector_layout(additional_path)
    if not cells or not os.path.isfile(metrics_path):
        return []
    try:
        root = ET.parse(metrics_path).getroot()
    except ET.ParseError:
        return []

    # (begin, end) -> detector id -> readings for that interval.
    interval_map: Dict[Tuple[float, float], Dict[str, Dict[str, float]]] = {}
    for elem in root.findall("interval"):
        det_id = elem.get("id")
        if not det_id:
            continue
        begin = _safe_float(elem.get("begin"))
        end = _safe_float(elem.get("end"))
        if not np.isfinite(begin) or not np.isfinite(end):
            continue
        interval_map.setdefault((begin, end), {})[det_id] = {
            "speed_ms": _safe_float(elem.get("speed")),
            "flow_vehph": _safe_float(elem.get("flow"), default=0.0),
            "n_veh_contrib": _safe_float(elem.get("nVehContrib"), default=0.0),
            "occupancy": _safe_float(elem.get("occupancy"), default=0.0),
        }

    detector_rows: List[Dict] = []
    for step_idx, interval_key in enumerate(sorted(interval_map), start=1):
        interval_values = interval_map[interval_key]
        for cell_order, cell in enumerate(cells):
            measured_speed_kmh, occupancy_pct, density = _aggregate_detector_cell(
                cell["detector_ids"], interval_values
            )
            detector_rows.append(
                {
                    "episode": episode,
                    "step": step_idx,
                    "cell_order": cell_order,
                    "cell_id": cell["cell_id"],
                    "edge_id": cell["edge_id"],
                    "pos_index": int(cell["pos_index"]),
                    "position_m": float(cell["position_m"]),
                    "measured_speed_kmh": float(measured_speed_kmh) if np.isfinite(measured_speed_kmh) else np.nan,
                    "occupancy": float(occupancy_pct) if np.isfinite(occupancy_pct) else np.nan,
                    "density_vehpkm": float(density),
                }
            )
    return detector_rows
def _align_detector_rows_to_decision_steps(
detector_rows: Sequence[Dict],
step_rows: Sequence[Dict],
) -> List[Dict]:
if not detector_rows or not step_rows:
return list(detector_rows)
detector_step_values = sorted({int(row["step"]) for row in detector_rows})
decision_step_values = sorted({int(row["step"]) for row in step_rows})
if not detector_step_values or not decision_step_values:
return list(detector_rows)
excess_steps = len(detector_step_values) - len(decision_step_values)
if excess_steps <= 0:
return list(detector_rows)
kept_detector_steps = detector_step_values[excess_steps:]
step_mapping = {
old_step: new_step
for old_step, new_step in zip(kept_detector_steps, decision_step_values)
}
aligned_rows: List[Dict] = []
for row in detector_rows:
old_step = int(row["step"])
if old_step not in step_mapping:
continue
aligned_row = dict(row)
aligned_row["step"] = step_mapping[old_step]
aligned_rows.append(aligned_row)
return aligned_rows
def _plot_episode_heatmap(
    path: str,
    edge_rows: Sequence[Dict],
    edge_ids: Sequence[str],
    detector_rows: Sequence[Dict],
    title_prefix: str,
):
    """Render the episode heatmap image from edge and detector rows.

    Up to three panels are stacked: measured speed per detector cell, the
    action speed per edge, and occupancy per detector cell. Columns are the
    distinct ``step`` values found in the rows; cells without data stay NaN.

    Args:
        path: Output image path passed to ``save_heatmap_panels``.
        edge_rows: Per-step, per-edge rows (``edge_id``, ``step``,
            ``action_speed_kmh``).
        edge_ids: Edge labels in display order; the action panel is drawn
            only when both ``edge_rows`` and ``edge_ids`` are non-empty.
        detector_rows: Per-step, per-cell rows (``cell_id``, ``cell_order``,
            ``step``, ``measured_speed_kmh``, ``occupancy``).
        title_prefix: Text prepended to every panel title.
    """
    has_action = bool(edge_rows and edge_ids)
    has_detector = bool(detector_rows)
    if not has_action and not has_detector:
        return

    step_values = sorted(
        {
            int(row["step"])
            for row in list(edge_rows) + list(detector_rows)
            if row.get("step") is not None
        }
    )
    if not step_values:
        return
    num_steps = len(step_values)
    step_to_col = {step: idx for idx, step in enumerate(step_values)}

    panels = []
    if has_action:
        # has_action guarantees edge_ids is non-empty; de-duplicate so every
        # label maps to exactly one grid row.
        ordered_edge_ids = list(dict.fromkeys(edge_ids))
        edge_to_row = {edge_id: idx for idx, edge_id in enumerate(ordered_edge_ids)}
        action_grid = np.full((len(ordered_edge_ids), num_steps), np.nan, dtype=np.float32)
        for row in edge_rows:
            # Rows without a step were left out of the column index above.
            if row.get("step") is None:
                continue
            row_idx = edge_to_row.get(str(row["edge_id"]))
            if row_idx is None:
                continue
            action_grid[row_idx, step_to_col[int(row["step"])]] = _safe_float(row["action_speed_kmh"])
        panels.append(
            build_action_panel(
                action_grid,
                ordered_edge_ids,
                f"{title_prefix} Segment Speed Limit (km/h)",
                ylabel="Corridor Segment",
            )
        )

    if has_detector:
        # Unique cell ids ordered by (cell_order, cell_id).
        ordered_cells = list(
            dict.fromkeys(
                str(row["cell_id"])
                for row in sorted(
                    detector_rows,
                    key=lambda item: (int(item["cell_order"]), str(item["cell_id"])),
                )
            )
        )
        num_cells = len(ordered_cells)
        cell_to_row = {cell_id: idx for idx, cell_id in enumerate(ordered_cells)}
        speed_grid = np.full((num_cells, num_steps), np.nan, dtype=np.float32)
        occupancy_grid = np.full((num_cells, num_steps), np.nan, dtype=np.float32)
        for row in detector_rows:
            if row.get("step") is None:
                continue
            row_idx = cell_to_row[str(row["cell_id"])]
            col_idx = step_to_col[int(row["step"])]
            speed_grid[row_idx, col_idx] = _safe_float(row["measured_speed_kmh"])
            occupancy_grid[row_idx, col_idx] = _safe_float(row["occupancy"])
        speed_panel = build_speed_panel(
            speed_grid,
            ordered_cells,
            f"{title_prefix} Measured Speed (km/h)",
            "Detector Cell",
        )
        occupancy_panel = build_occupancy_panel(
            occupancy_grid,
            ordered_cells,
            f"{title_prefix} Occupancy (%)",
            "Detector Cell",
        )
        # Measured speed goes on top, the action panel (if any) in the
        # middle, and occupancy at the bottom.
        panels = [speed_panel] + panels + [occupancy_panel]

    save_heatmap_panels(path, panels, xlabel="Decision Step")
def _write_summary(path: str, summary: Optional[Dict]):
if not summary:
return
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "w", encoding="utf-8") as f:
json.dump(summary, f, ensure_ascii=False, indent=2)
def _save_episode_bundle(
    bundle_dir: str,
    episode: int,
    step_rows: Sequence[Dict],
    edge_rows: Sequence[Dict],
    edge_ids: Sequence[str],
    detector_rows: Sequence[Dict],
    summary: Optional[Dict],
):
    """Write every artifact of one episode into ``bundle_dir``.

    Produces ``step_metrics.csv``, ``edge_metrics.csv`` and
    ``detector_metrics.csv``, then ``episode_heatmap.png`` and
    ``summary.json`` via the plotting and summary helpers. CSV headers come
    from the first row of each table, or from a short default header when the
    table is empty.
    """
    # (file name, rows, header used when there are no rows), in write order.
    csv_tables = (
        ("step_metrics.csv", step_rows, ["episode", "step"]),
        ("edge_metrics.csv", edge_rows, ["episode", "step", "edge_index"]),
        ("detector_metrics.csv", detector_rows, ["episode", "step", "cell_id"]),
    )
    for file_name, table_rows, empty_header in csv_tables:
        header = list(table_rows[0].keys()) if table_rows else empty_header
        _write_csv(os.path.join(bundle_dir, file_name), table_rows, fieldnames=header)

    _plot_episode_heatmap(
        os.path.join(bundle_dir, "episode_heatmap.png"),
        edge_rows,
        edge_ids,
        detector_rows,
        title_prefix=f"Episode {episode}",
    )
    _write_summary(os.path.join(bundle_dir, "summary.json"), summary)
def save_training_episode_artifacts(
    log_dir: str,
    episode: int,
    episode_metrics: Sequence[Dict],
    control_edges: Sequence[str],
    summary: Optional[Dict] = None,
    snapshot_interval: int = 50,
):
    """Persist the artifacts of one training episode under ``log_dir``.

    The bundle is always written to ``episode_artifacts/latest``. It is also
    written to ``episode_artifacts/snapshots/episode_NNNN`` when
    ``snapshot_interval`` is positive and divides ``episode``.

    Args:
        log_dir: Root directory for outputs (and for the detector XML input).
        episode: Episode number.
        episode_metrics: Per-step info mappings; nothing happens when empty.
        control_edges: Edge labels; used in place of the placeholder ids
            when there are at least as many of them as edge columns.
        summary: Optional mapping stored as ``summary.json``.
        snapshot_interval: Snapshot period in episodes; ``0`` or negative
            disables snapshots.
    """
    if not episode_metrics:
        return

    step_rows, edge_rows, edge_ids = _normalize_step_rows(episode, episode_metrics)
    detector_rows = _align_detector_rows_to_decision_steps(
        _build_detector_rows_from_xml(log_dir, episode),
        step_rows,
    )

    edge_total = len(edge_ids)
    if control_edges and len(control_edges) >= edge_total:
        # Swap the placeholder labels for the caller's edge names.
        edge_ids = list(control_edges[:edge_total])
        for edge_row in edge_rows:
            edge_row["edge_id"] = edge_ids[int(edge_row["edge_index"])]

    def _persist(bundle_dir: str) -> None:
        _save_episode_bundle(bundle_dir, episode, step_rows, edge_rows, edge_ids, detector_rows, summary)

    artifacts_root = os.path.join(log_dir, "episode_artifacts")
    _persist(os.path.join(artifacts_root, "latest"))

    is_snapshot_episode = snapshot_interval > 0 and episode % snapshot_interval == 0
    if is_snapshot_episode:
        _persist(os.path.join(artifacts_root, "snapshots", f"episode_{episode:04d}"))