365 lines
13 KiB
Python
365 lines
13 KiB
Python
"""Utilities for persisting per-episode training artifacts."""
|
|
import csv
|
|
import json
|
|
import os
|
|
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
|
|
import xml.etree.ElementTree as ET
|
|
|
|
import matplotlib
|
|
import numpy as np
|
|
|
|
matplotlib.use("Agg")
|
|
|
|
from envs.reward_system import REWARD_COMPONENT_COLUMNS
|
|
from utils.heatmap_plotting import (
|
|
build_action_panel,
|
|
build_density_panel,
|
|
build_speed_panel,
|
|
save_heatmap_panels,
|
|
)
|
|
|
|
|
|
def _safe_float(value, default: float = float("nan")) -> float:
|
|
try:
|
|
return float(value)
|
|
except (TypeError, ValueError):
|
|
return default
|
|
|
|
|
|
def _normalize_step_rows(episode: int, episode_metrics: Sequence[Dict]) -> Tuple[List[Dict], List[Dict], List[str]]:
|
|
step_rows: List[Dict] = []
|
|
edge_rows: List[Dict] = []
|
|
edge_ids: List[str] = []
|
|
|
|
for step_idx, info in enumerate(episode_metrics, start=1):
|
|
step_value = int(info.get("step", step_idx))
|
|
step_row = {
|
|
"episode": episode,
|
|
"step": step_value,
|
|
"sim_time": _safe_float(info.get("sim_time")),
|
|
"reward": _safe_float(info.get("reward")),
|
|
"throughput": _safe_float(info.get("throughput")),
|
|
"arrived_count": int(info.get("arrived_count", 0)),
|
|
"departed_count": int(info.get("departed_count", 0)),
|
|
"mean_speed_kmh": _safe_float(info.get("mean_speed_kmh")),
|
|
"speed_variance_norm": _safe_float(info.get("speed_variance_norm")),
|
|
"mean_occupancy": _safe_float(info.get("mean_occupancy")),
|
|
"density": _safe_float(info.get("density")),
|
|
"num_vehicles": int(info.get("num_vehicles", 0)),
|
|
"num_stops": _safe_float(info.get("num_stops", 0.0), default=0.0),
|
|
}
|
|
for column in REWARD_COMPONENT_COLUMNS:
|
|
step_row[column] = _safe_float(info.get(column))
|
|
step_rows.append(step_row)
|
|
|
|
action_speeds = list(info.get("edge_speeds_kmh", []))
|
|
measured_speeds_ms = list(info.get("edge_speeds_ms", []))
|
|
occupancies = list(info.get("edge_occupancies", []))
|
|
action_applied_mask = list(info.get("action_applied_mask", []))
|
|
edge_count = max(len(action_speeds), len(measured_speeds_ms), len(occupancies), len(action_applied_mask))
|
|
|
|
if not edge_ids:
|
|
edge_ids = [f"edge_{idx:02d}" for idx in range(edge_count)]
|
|
|
|
for edge_idx in range(edge_count):
|
|
edge_rows.append(
|
|
{
|
|
"episode": episode,
|
|
"step": step_value,
|
|
"edge_index": edge_idx,
|
|
"edge_id": edge_ids[edge_idx],
|
|
"action_speed_kmh": _safe_float(action_speeds[edge_idx] if edge_idx < len(action_speeds) else np.nan),
|
|
"action_applied": bool(action_applied_mask[edge_idx]) if edge_idx < len(action_applied_mask) else True,
|
|
"measured_speed_kmh": _safe_float(
|
|
measured_speeds_ms[edge_idx] * 3.6 if edge_idx < len(measured_speeds_ms) else np.nan
|
|
),
|
|
"occupancy": _safe_float(occupancies[edge_idx] if edge_idx < len(occupancies) else np.nan),
|
|
}
|
|
)
|
|
|
|
return step_rows, edge_rows, edge_ids
|
|
|
|
|
|
def _write_csv(path: str, rows: Iterable[Dict], fieldnames: Sequence[str]):
|
|
os.makedirs(os.path.dirname(path), exist_ok=True)
|
|
with open(path, "w", newline="", encoding="utf-8") as f:
|
|
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
|
writer.writeheader()
|
|
writer.writerows(rows)
|
|
|
|
|
|
def _lane_to_edge_id(lane_id: str) -> str:
|
|
if "_" not in lane_id:
|
|
return lane_id
|
|
return lane_id.rsplit("_", 1)[0]
|
|
|
|
|
|
def _parse_detector_layout(additional_path: str) -> List[Dict]:
|
|
if not os.path.isfile(additional_path):
|
|
return []
|
|
|
|
try:
|
|
root = ET.parse(additional_path).getroot()
|
|
except ET.ParseError:
|
|
return []
|
|
|
|
cells: List[Dict] = []
|
|
current_key = None
|
|
current_cell: Optional[Dict] = None
|
|
edge_pos_counts: Dict[str, int] = {}
|
|
|
|
for elem in root.findall("inductionLoop"):
|
|
det_id = elem.get("id")
|
|
lane_id = elem.get("lane")
|
|
pos_raw = elem.get("pos")
|
|
if not det_id or not lane_id or pos_raw is None:
|
|
continue
|
|
|
|
edge_id = _lane_to_edge_id(lane_id)
|
|
position_m = _safe_float(pos_raw)
|
|
key = (edge_id, round(position_m, 3))
|
|
|
|
if key != current_key:
|
|
pos_index = edge_pos_counts.get(edge_id, 0)
|
|
edge_pos_counts[edge_id] = pos_index + 1
|
|
current_cell = {
|
|
"cell_id": f"{edge_id}@{pos_index}",
|
|
"edge_id": edge_id,
|
|
"pos_index": pos_index,
|
|
"position_m": position_m,
|
|
"detector_ids": [],
|
|
}
|
|
cells.append(current_cell)
|
|
current_key = key
|
|
|
|
current_cell["detector_ids"].append(det_id)
|
|
|
|
return cells
|
|
|
|
|
|
def _build_detector_rows_from_xml(log_dir: str, episode: int) -> List[Dict]:
    """Aggregate per-detector interval output into per-cell rows.

    Reads two files from ``<log_dir>/sumo_metrics``: the detector layout
    (``runtime_metrics_il_epNNNN.add.xml``) and the detector output
    (``metrics_il_output_epNNNN.xml``).

    Args:
        log_dir: Directory that contains the ``sumo_metrics`` folder.
        episode: Episode number; it selects the ``epNNNN`` file suffix and is
            stamped on every row.

    Returns:
        One row per (interval, cell). ``step`` is the 1-based rank of the
        interval after sorting by ``(begin, end)``. The list is empty when
        the layout yields no cells, or when the output file is missing or is
        not well-formed XML.
    """
    sumo_dir = os.path.join(log_dir, "sumo_metrics")
    tag = f"ep{episode:04d}"
    layout_file = os.path.join(sumo_dir, f"runtime_metrics_il_{tag}.add.xml")
    output_file = os.path.join(sumo_dir, f"metrics_il_output_{tag}.xml")

    cells = _parse_detector_layout(layout_file)
    if not (cells and os.path.isfile(output_file)):
        return []

    try:
        tree = ET.parse(output_file)
    except ET.ParseError:
        return []

    # (begin, end) -> detector id -> readings for that interval.
    readings_by_interval: Dict[Tuple[float, float], Dict[str, Dict[str, float]]] = {}
    for node in tree.getroot().findall("interval"):
        detector = node.get("id")
        if not detector:
            continue

        begin = _safe_float(node.get("begin"))
        end = _safe_float(node.get("end"))
        # Intervals without usable bounds cannot be ordered; drop them.
        if not (np.isfinite(begin) and np.isfinite(end)):
            continue

        readings_by_interval.setdefault((begin, end), {})[detector] = {
            "speed_ms": _safe_float(node.get("speed")),
            "flow_vehph": _safe_float(node.get("flow"), default=0.0),
            "n_veh_contrib": _safe_float(node.get("nVehContrib"), default=0.0),
            "occupancy": _safe_float(node.get("occupancy"), default=0.0),
        }

    rows: List[Dict] = []
    for step_number, window in enumerate(sorted(readings_by_interval), start=1):
        window_readings = readings_by_interval[window]
        for order, cell in enumerate(cells):
            weighted_speed = 0.0
            weight_total = 0.0
            density = 0.0

            for detector in cell["detector_ids"]:
                reading = window_readings.get(detector)
                if not reading:
                    continue

                lane_speed = reading["speed_ms"]
                # Only lanes with a finite, positive speed contribute.
                if not (np.isfinite(lane_speed) and lane_speed > 0.0):
                    continue

                lane_flow = max(reading["flow_vehph"], 0.0)
                lane_count = max(reading["n_veh_contrib"], 0.0)
                lane_speed_kmh = lane_speed * 3.6

                # Weight by vehicle count, else by flow, else equally.
                weight = lane_count if lane_count > 0.0 else lane_flow
                if weight <= 0.0:
                    weight = 1.0

                weighted_speed += lane_speed_kmh * weight
                weight_total += weight
                density += lane_flow / max(lane_speed_kmh, 1e-6)

            if weight_total > 0.0:
                cell_speed_kmh = weighted_speed / weight_total
            else:
                cell_speed_kmh = np.nan

            rows.append(
                {
                    "episode": episode,
                    "step": step_number,
                    "cell_order": order,
                    "cell_id": cell["cell_id"],
                    "edge_id": cell["edge_id"],
                    "pos_index": int(cell["pos_index"]),
                    "position_m": float(cell["position_m"]),
                    "measured_speed_kmh": float(cell_speed_kmh) if np.isfinite(cell_speed_kmh) else np.nan,
                    "density_vehpkm": float(density),
                }
            )

    return rows
|
|
|
|
|
|
def _plot_episode_heatmap(
|
|
path: str,
|
|
edge_rows: Sequence[Dict],
|
|
edge_ids: Sequence[str],
|
|
detector_rows: Sequence[Dict],
|
|
title_prefix: str,
|
|
):
|
|
has_action = bool(edge_rows and edge_ids)
|
|
has_detector = bool(detector_rows)
|
|
if not has_action and not has_detector:
|
|
return
|
|
|
|
step_values = sorted(
|
|
{
|
|
int(row["step"])
|
|
for row in list(edge_rows) + list(detector_rows)
|
|
if row.get("step") is not None
|
|
}
|
|
)
|
|
if not step_values:
|
|
return
|
|
|
|
num_steps = len(step_values)
|
|
step_to_col = {step: idx for idx, step in enumerate(step_values)}
|
|
panels = []
|
|
if has_action:
|
|
ordered_edge_ids = list(edge_ids)
|
|
edge_to_row = {edge_id: idx for idx, edge_id in enumerate(ordered_edge_ids)}
|
|
action_grid = np.full((len(ordered_edge_ids), num_steps), np.nan, dtype=np.float32)
|
|
for row in edge_rows:
|
|
edge_id = str(row["edge_id"])
|
|
if edge_id not in edge_to_row:
|
|
continue
|
|
if not bool(row.get("action_applied", True)):
|
|
continue
|
|
row_idx = edge_to_row[edge_id]
|
|
col_idx = step_to_col[int(row["step"])]
|
|
action_grid[row_idx, col_idx] = _safe_float(row["action_speed_kmh"])
|
|
panels.append(build_action_panel(action_grid, ordered_edge_ids, f"{title_prefix} Applied VSL (km/h)"))
|
|
|
|
if has_detector:
|
|
ordered_cells = []
|
|
seen_cells = set()
|
|
for row in sorted(
|
|
detector_rows,
|
|
key=lambda item: (int(item["cell_order"]), str(item["cell_id"])),
|
|
):
|
|
cell_id = str(row["cell_id"])
|
|
if cell_id in seen_cells:
|
|
continue
|
|
seen_cells.add(cell_id)
|
|
ordered_cells.append(cell_id)
|
|
|
|
num_cells = len(ordered_cells)
|
|
cell_to_row = {cell_id: idx for idx, cell_id in enumerate(ordered_cells)}
|
|
speed_grid = np.full((num_cells, num_steps), np.nan, dtype=np.float32)
|
|
density_grid = np.full((num_cells, num_steps), np.nan, dtype=np.float32)
|
|
|
|
for row in detector_rows:
|
|
row_idx = cell_to_row[str(row["cell_id"])]
|
|
col_idx = step_to_col[int(row["step"])]
|
|
speed_grid[row_idx, col_idx] = _safe_float(row["measured_speed_kmh"])
|
|
density_grid[row_idx, col_idx] = _safe_float(row["density_vehpkm"])
|
|
|
|
plots = [
|
|
build_speed_panel(
|
|
speed_grid,
|
|
ordered_cells,
|
|
f"{title_prefix} Measured Speed (km/h)",
|
|
"Detector Cell",
|
|
),
|
|
build_density_panel(
|
|
density_grid,
|
|
ordered_cells,
|
|
f"{title_prefix} Density (veh/km)",
|
|
"Detector Cell",
|
|
),
|
|
]
|
|
panels = plots[:1] + panels + plots[1:]
|
|
|
|
save_heatmap_panels(path, panels, xlabel="Decision Step")
|
|
|
|
|
|
def _write_summary(path: str, summary: Optional[Dict]):
|
|
if not summary:
|
|
return
|
|
os.makedirs(os.path.dirname(path), exist_ok=True)
|
|
with open(path, "w", encoding="utf-8") as f:
|
|
json.dump(summary, f, ensure_ascii=False, indent=2)
|
|
|
|
|
|
def _save_episode_bundle(
    bundle_dir: str,
    episode: int,
    step_rows: Sequence[Dict],
    edge_rows: Sequence[Dict],
    edge_ids: Sequence[str],
    detector_rows: Sequence[Dict],
    summary: Optional[Dict],
):
    """Write one episode's CSV tables, heatmap and summary into *bundle_dir*.

    Files are written in this order: ``step_metrics.csv``,
    ``edge_metrics.csv``, ``detector_metrics.csv``, ``episode_heatmap.png``,
    ``summary.json``. Each CSV takes its columns from the keys of its first
    row; an empty table gets a short fixed header instead.

    Args:
        bundle_dir: Directory that receives the files.
        episode: Episode number, used in the heatmap titles.
        step_rows: Rows for ``step_metrics.csv``.
        edge_rows: Rows for ``edge_metrics.csv`` and the action panel.
        edge_ids: Edge ids passed to the heatmap.
        detector_rows: Rows for ``detector_metrics.csv`` and the detector panels.
        summary: Optional mapping for ``summary.json``.
    """
    # (file name, rows, header used when the table has no rows)
    tables = (
        ("step_metrics.csv", step_rows, ["episode", "step"]),
        ("edge_metrics.csv", edge_rows, ["episode", "step", "edge_index"]),
        ("detector_metrics.csv", detector_rows, ["episode", "step", "cell_id"]),
    )
    for file_name, rows, fallback_columns in tables:
        columns = list(rows[0].keys()) if rows else fallback_columns
        _write_csv(os.path.join(bundle_dir, file_name), rows, fieldnames=columns)

    _plot_episode_heatmap(
        os.path.join(bundle_dir, "episode_heatmap.png"),
        edge_rows,
        edge_ids,
        detector_rows,
        title_prefix=f"Episode {episode}",
    )
    _write_summary(os.path.join(bundle_dir, "summary.json"), summary)
|
|
|
|
|
|
def save_training_episode_artifacts(
    log_dir: str,
    episode: int,
    episode_metrics: Sequence[Dict],
    control_edges: Sequence[str],
    summary: Optional[Dict] = None,
    snapshot_interval: int = 50,
):
    """Persist one episode's artifacts under ``<log_dir>/episode_artifacts``.

    The bundle is always written to ``latest``. It is written a second time
    to ``snapshots/episode_NNNN`` when ``snapshot_interval`` is positive and
    divides ``episode``. Nothing happens when ``episode_metrics`` is empty.

    Args:
        log_dir: Root log directory.
        episode: Episode number.
        episode_metrics: One info dict per step.
        control_edges: Real edge ids. They replace the placeholder ids when
            there are at least as many of them as edges in the metrics.
        summary: Optional mapping stored as ``summary.json``.
        snapshot_interval: Snapshot period in episodes; values <= 0 disable
            snapshots.
    """
    if not episode_metrics:
        return

    step_rows, edge_rows, edge_ids = _normalize_step_rows(episode, episode_metrics)
    detector_rows = _build_detector_rows_from_xml(log_dir, episode)

    # Swap the placeholder ids for the real ones when enough were supplied.
    if control_edges and len(control_edges) >= len(edge_ids):
        edge_ids = list(control_edges[: len(edge_ids)])
        for edge_row in edge_rows:
            edge_row["edge_id"] = edge_ids[int(edge_row["edge_index"])]

    artifacts_root = os.path.join(log_dir, "episode_artifacts")
    target_dirs = [os.path.join(artifacts_root, "latest")]
    if snapshot_interval > 0 and episode % snapshot_interval == 0:
        target_dirs.append(os.path.join(artifacts_root, "snapshots", f"episode_{episode:04d}"))

    for target_dir in target_dirs:
        _save_episode_bundle(target_dir, episode, step_rows, edge_rows, edge_ids, detector_rows, summary)
|