绘制训练过程热力图时使用传感器数据

This commit is contained in:
Zihan Ye 2026-04-16 11:01:53 +08:00
parent 1691f9b33c
commit c126e4a9f5
1 changed files with 181 additions and 22 deletions

View File

@ -3,6 +3,7 @@ import csv
import json
import os
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import xml.etree.ElementTree as ET
import matplotlib
import numpy as np
@ -80,32 +81,173 @@ def _write_csv(path: str, rows: Iterable[Dict], fieldnames: Sequence[str]):
writer.writerows(rows)
def _plot_episode_heatmap(path: str, edge_rows: Sequence[Dict], edge_ids: Sequence[str], title_prefix: str):
if not edge_rows or not edge_ids:
def _lane_to_edge_id(lane_id: str) -> str:
if "_" not in lane_id:
return lane_id
return lane_id.rsplit("_", 1)[0]
def _parse_detector_layout(additional_path: str) -> List[Dict]:
if not os.path.isfile(additional_path):
return []
try:
root = ET.parse(additional_path).getroot()
except ET.ParseError:
return []
cells: List[Dict] = []
current_key = None
current_cell: Optional[Dict] = None
edge_pos_counts: Dict[str, int] = {}
for elem in root.findall("inductionLoop"):
det_id = elem.get("id")
lane_id = elem.get("lane")
pos_raw = elem.get("pos")
if not det_id or not lane_id or pos_raw is None:
continue
edge_id = _lane_to_edge_id(lane_id)
position_m = _safe_float(pos_raw)
key = (edge_id, round(position_m, 3))
if key != current_key:
pos_index = edge_pos_counts.get(edge_id, 0)
edge_pos_counts[edge_id] = pos_index + 1
current_cell = {
"cell_id": f"{edge_id}@{pos_index}",
"edge_id": edge_id,
"pos_index": pos_index,
"position_m": position_m,
"detector_ids": [],
}
cells.append(current_cell)
current_key = key
current_cell["detector_ids"].append(det_id)
return cells
def _build_detector_rows_from_xml(log_dir: str, episode: int) -> List[Dict]:
metrics_dir = os.path.join(log_dir, "sumo_metrics")
suffix = f"ep{episode:04d}"
additional_path = os.path.join(metrics_dir, f"runtime_metrics_il_{suffix}.add.xml")
metrics_path = os.path.join(metrics_dir, f"metrics_il_output_{suffix}.xml")
cells = _parse_detector_layout(additional_path)
if not cells or not os.path.isfile(metrics_path):
return []
try:
root = ET.parse(metrics_path).getroot()
except ET.ParseError:
return []
interval_map: Dict[Tuple[float, float], Dict[str, Dict[str, float]]] = {}
for elem in root.findall("interval"):
det_id = elem.get("id")
if not det_id:
continue
begin = _safe_float(elem.get("begin"))
end = _safe_float(elem.get("end"))
if not np.isfinite(begin) or not np.isfinite(end):
continue
interval_key = (begin, end)
interval_values = interval_map.setdefault(interval_key, {})
interval_values[det_id] = {
"speed_ms": _safe_float(elem.get("speed")),
"flow_vehph": _safe_float(elem.get("flow"), default=0.0),
"n_veh_contrib": _safe_float(elem.get("nVehContrib"), default=0.0),
"occupancy": _safe_float(elem.get("occupancy"), default=0.0),
}
detector_rows: List[Dict] = []
sorted_intervals = sorted(interval_map.keys())
for step_idx, interval_key in enumerate(sorted_intervals, start=1):
interval_values = interval_map[interval_key]
for cell_order, cell in enumerate(cells):
speed_weight_sum = 0.0
speed_weight_total = 0.0
density_sum = 0.0
for det_id in cell["detector_ids"]:
lane_metrics = interval_values.get(det_id)
if not lane_metrics:
continue
speed_ms = lane_metrics["speed_ms"]
flow_vehph = max(lane_metrics["flow_vehph"], 0.0)
n_veh_contrib = max(lane_metrics["n_veh_contrib"], 0.0)
if np.isfinite(speed_ms) and speed_ms > 0.0:
speed_weight = n_veh_contrib if n_veh_contrib > 0.0 else flow_vehph
if speed_weight <= 0.0:
speed_weight = 1.0
speed_weight_sum += speed_ms * 3.6 * speed_weight
speed_weight_total += speed_weight
density_sum += flow_vehph / max(speed_ms * 3.6, 1e-6)
measured_speed_kmh = (
speed_weight_sum / speed_weight_total
if speed_weight_total > 0.0
else np.nan
)
detector_rows.append(
{
"episode": episode,
"step": step_idx,
"cell_order": cell_order,
"cell_id": cell["cell_id"],
"edge_id": cell["edge_id"],
"pos_index": int(cell["pos_index"]),
"position_m": float(cell["position_m"]),
"measured_speed_kmh": float(measured_speed_kmh) if np.isfinite(measured_speed_kmh) else np.nan,
"density_vehpkm": float(density_sum),
}
)
return detector_rows
def _plot_episode_heatmap(path: str, detector_rows: Sequence[Dict], title_prefix: str):
if not detector_rows:
return
step_values = sorted({int(row["step"]) for row in edge_rows})
num_edges = len(edge_ids)
step_values = sorted({int(row["step"]) for row in detector_rows})
ordered_cells = []
seen_cells = set()
for row in sorted(
detector_rows,
key=lambda item: (int(item["cell_order"]), str(item["cell_id"])),
):
cell_id = str(row["cell_id"])
if cell_id in seen_cells:
continue
seen_cells.add(cell_id)
ordered_cells.append(cell_id)
num_cells = len(ordered_cells)
num_steps = len(step_values)
step_to_col = {step: idx for idx, step in enumerate(step_values)}
edge_to_row = {edge_id: idx for idx, edge_id in enumerate(edge_ids)}
cell_to_row = {cell_id: idx for idx, cell_id in enumerate(ordered_cells)}
action_grid = np.full((num_edges, num_steps), np.nan, dtype=np.float32)
speed_grid = np.full((num_edges, num_steps), np.nan, dtype=np.float32)
occupancy_grid = np.full((num_edges, num_steps), np.nan, dtype=np.float32)
speed_grid = np.full((num_cells, num_steps), np.nan, dtype=np.float32)
density_grid = np.full((num_cells, num_steps), np.nan, dtype=np.float32)
for row in edge_rows:
row_idx = edge_to_row[row["edge_id"]]
for row in detector_rows:
row_idx = cell_to_row[str(row["cell_id"])]
col_idx = step_to_col[int(row["step"])]
action_grid[row_idx, col_idx] = _safe_float(row["action_speed_kmh"])
speed_grid[row_idx, col_idx] = _safe_float(row["measured_speed_kmh"])
occupancy_grid[row_idx, col_idx] = _safe_float(row["occupancy"])
density_grid[row_idx, col_idx] = _safe_float(row["density_vehpkm"])
fig, axes = plt.subplots(1, 3, figsize=(18, 7), sharex=True, sharey=True)
fig, axes = plt.subplots(1, 2, figsize=(18, 7), sharex=True, sharey=True)
plots = [
(action_grid, "viridis", "Applied VSL (km/h)"),
(speed_grid, "RdYlGn", "Measured Speed (km/h)"),
(occupancy_grid, "magma", "Occupancy (%)"),
(density_grid, "YlOrRd", "Density (veh/km)"),
]
for ax, (grid, cmap, title) in zip(axes, plots):
@ -119,9 +261,11 @@ def _plot_episode_heatmap(path: str, edge_rows: Sequence[Dict], edge_ids: Sequen
)
ax.set_title(f"{title_prefix} {title}")
ax.set_xlabel("Decision Step")
ax.set_ylabel("Controlled Edge")
ax.set_yticks(np.arange(num_edges))
ax.set_yticklabels(edge_ids)
ax.set_ylabel("Detector Cell")
tick_step = max(num_cells // 12, 1)
tick_idx = np.arange(0, num_cells, tick_step)
ax.set_yticks(tick_idx)
ax.set_yticklabels([ordered_cells[idx] for idx in tick_idx], fontsize=8)
plt.colorbar(image, ax=ax, fraction=0.046, pad=0.04)
plt.tight_layout()
@ -137,15 +281,29 @@ def _write_summary(path: str, summary: Optional[Dict]):
json.dump(summary, f, ensure_ascii=False, indent=2)
def _save_episode_bundle(bundle_dir: str, episode: int, step_rows: Sequence[Dict], edge_rows: Sequence[Dict], edge_ids: Sequence[str], summary: Optional[Dict]):
def _save_episode_bundle(
bundle_dir: str,
episode: int,
step_rows: Sequence[Dict],
edge_rows: Sequence[Dict],
edge_ids: Sequence[str],
detector_rows: Sequence[Dict],
summary: Optional[Dict],
):
step_csv_path = os.path.join(bundle_dir, "step_metrics.csv")
edge_csv_path = os.path.join(bundle_dir, "edge_metrics.csv")
detector_csv_path = os.path.join(bundle_dir, "detector_metrics.csv")
heatmap_path = os.path.join(bundle_dir, "episode_heatmap.png")
summary_path = os.path.join(bundle_dir, "summary.json")
_write_csv(step_csv_path, step_rows, fieldnames=list(step_rows[0].keys()) if step_rows else ["episode", "step"])
_write_csv(edge_csv_path, edge_rows, fieldnames=list(edge_rows[0].keys()) if edge_rows else ["episode", "step", "edge_index"])
_plot_episode_heatmap(heatmap_path, edge_rows, edge_ids, title_prefix=f"Episode {episode}")
_write_csv(
detector_csv_path,
detector_rows,
fieldnames=list(detector_rows[0].keys()) if detector_rows else ["episode", "step", "cell_id"],
)
_plot_episode_heatmap(heatmap_path, detector_rows, title_prefix=f"Episode {episode}")
_write_summary(summary_path, summary)
@ -161,6 +319,7 @@ def save_training_episode_artifacts(
return
step_rows, edge_rows, edge_ids = _normalize_step_rows(episode, episode_metrics)
detector_rows = _build_detector_rows_from_xml(log_dir, episode)
if control_edges and len(control_edges) >= len(edge_ids):
edge_ids = list(control_edges[: len(edge_ids)])
for row in edge_rows:
@ -168,8 +327,8 @@ def save_training_episode_artifacts(
artifacts_root = os.path.join(log_dir, "episode_artifacts")
latest_dir = os.path.join(artifacts_root, "latest")
_save_episode_bundle(latest_dir, episode, step_rows, edge_rows, edge_ids, summary)
_save_episode_bundle(latest_dir, episode, step_rows, edge_rows, edge_ids, detector_rows, summary)
if snapshot_interval > 0 and episode % snapshot_interval == 0:
snapshot_dir = os.path.join(artifacts_root, "snapshots", f"episode_{episode:04d}")
_save_episode_bundle(snapshot_dir, episode, step_rows, edge_rows, edge_ids, summary)
_save_episode_bundle(snapshot_dir, episode, step_rows, edge_rows, edge_ids, detector_rows, summary)