绘制训练过程热力图时使用传感器数据

This commit is contained in:
Zihan Ye 2026-04-16 11:01:53 +08:00
parent 1691f9b33c
commit c126e4a9f5
1 changed files with 181 additions and 22 deletions

View File

@ -3,6 +3,7 @@ import csv
import json import json
import os import os
from typing import Dict, Iterable, List, Optional, Sequence, Tuple from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import xml.etree.ElementTree as ET
import matplotlib import matplotlib
import numpy as np import numpy as np
@ -80,32 +81,173 @@ def _write_csv(path: str, rows: Iterable[Dict], fieldnames: Sequence[str]):
writer.writerows(rows) writer.writerows(rows)
def _plot_episode_heatmap(path: str, edge_rows: Sequence[Dict], edge_ids: Sequence[str], title_prefix: str): def _lane_to_edge_id(lane_id: str) -> str:
if not edge_rows or not edge_ids: if "_" not in lane_id:
return lane_id
return lane_id.rsplit("_", 1)[0]
def _parse_detector_layout(additional_path: str) -> List[Dict]:
if not os.path.isfile(additional_path):
return []
try:
root = ET.parse(additional_path).getroot()
except ET.ParseError:
return []
cells: List[Dict] = []
current_key = None
current_cell: Optional[Dict] = None
edge_pos_counts: Dict[str, int] = {}
for elem in root.findall("inductionLoop"):
det_id = elem.get("id")
lane_id = elem.get("lane")
pos_raw = elem.get("pos")
if not det_id or not lane_id or pos_raw is None:
continue
edge_id = _lane_to_edge_id(lane_id)
position_m = _safe_float(pos_raw)
key = (edge_id, round(position_m, 3))
if key != current_key:
pos_index = edge_pos_counts.get(edge_id, 0)
edge_pos_counts[edge_id] = pos_index + 1
current_cell = {
"cell_id": f"{edge_id}@{pos_index}",
"edge_id": edge_id,
"pos_index": pos_index,
"position_m": position_m,
"detector_ids": [],
}
cells.append(current_cell)
current_key = key
current_cell["detector_ids"].append(det_id)
return cells
def _build_detector_rows_from_xml(log_dir: str, episode: int) -> List[Dict]:
metrics_dir = os.path.join(log_dir, "sumo_metrics")
suffix = f"ep{episode:04d}"
additional_path = os.path.join(metrics_dir, f"runtime_metrics_il_{suffix}.add.xml")
metrics_path = os.path.join(metrics_dir, f"metrics_il_output_{suffix}.xml")
cells = _parse_detector_layout(additional_path)
if not cells or not os.path.isfile(metrics_path):
return []
try:
root = ET.parse(metrics_path).getroot()
except ET.ParseError:
return []
interval_map: Dict[Tuple[float, float], Dict[str, Dict[str, float]]] = {}
for elem in root.findall("interval"):
det_id = elem.get("id")
if not det_id:
continue
begin = _safe_float(elem.get("begin"))
end = _safe_float(elem.get("end"))
if not np.isfinite(begin) or not np.isfinite(end):
continue
interval_key = (begin, end)
interval_values = interval_map.setdefault(interval_key, {})
interval_values[det_id] = {
"speed_ms": _safe_float(elem.get("speed")),
"flow_vehph": _safe_float(elem.get("flow"), default=0.0),
"n_veh_contrib": _safe_float(elem.get("nVehContrib"), default=0.0),
"occupancy": _safe_float(elem.get("occupancy"), default=0.0),
}
detector_rows: List[Dict] = []
sorted_intervals = sorted(interval_map.keys())
for step_idx, interval_key in enumerate(sorted_intervals, start=1):
interval_values = interval_map[interval_key]
for cell_order, cell in enumerate(cells):
speed_weight_sum = 0.0
speed_weight_total = 0.0
density_sum = 0.0
for det_id in cell["detector_ids"]:
lane_metrics = interval_values.get(det_id)
if not lane_metrics:
continue
speed_ms = lane_metrics["speed_ms"]
flow_vehph = max(lane_metrics["flow_vehph"], 0.0)
n_veh_contrib = max(lane_metrics["n_veh_contrib"], 0.0)
if np.isfinite(speed_ms) and speed_ms > 0.0:
speed_weight = n_veh_contrib if n_veh_contrib > 0.0 else flow_vehph
if speed_weight <= 0.0:
speed_weight = 1.0
speed_weight_sum += speed_ms * 3.6 * speed_weight
speed_weight_total += speed_weight
density_sum += flow_vehph / max(speed_ms * 3.6, 1e-6)
measured_speed_kmh = (
speed_weight_sum / speed_weight_total
if speed_weight_total > 0.0
else np.nan
)
detector_rows.append(
{
"episode": episode,
"step": step_idx,
"cell_order": cell_order,
"cell_id": cell["cell_id"],
"edge_id": cell["edge_id"],
"pos_index": int(cell["pos_index"]),
"position_m": float(cell["position_m"]),
"measured_speed_kmh": float(measured_speed_kmh) if np.isfinite(measured_speed_kmh) else np.nan,
"density_vehpkm": float(density_sum),
}
)
return detector_rows
def _plot_episode_heatmap(path: str, detector_rows: Sequence[Dict], title_prefix: str):
if not detector_rows:
return return
step_values = sorted({int(row["step"]) for row in edge_rows}) step_values = sorted({int(row["step"]) for row in detector_rows})
num_edges = len(edge_ids) ordered_cells = []
seen_cells = set()
for row in sorted(
detector_rows,
key=lambda item: (int(item["cell_order"]), str(item["cell_id"])),
):
cell_id = str(row["cell_id"])
if cell_id in seen_cells:
continue
seen_cells.add(cell_id)
ordered_cells.append(cell_id)
num_cells = len(ordered_cells)
num_steps = len(step_values) num_steps = len(step_values)
step_to_col = {step: idx for idx, step in enumerate(step_values)} step_to_col = {step: idx for idx, step in enumerate(step_values)}
edge_to_row = {edge_id: idx for idx, edge_id in enumerate(edge_ids)} cell_to_row = {cell_id: idx for idx, cell_id in enumerate(ordered_cells)}
action_grid = np.full((num_edges, num_steps), np.nan, dtype=np.float32) speed_grid = np.full((num_cells, num_steps), np.nan, dtype=np.float32)
speed_grid = np.full((num_edges, num_steps), np.nan, dtype=np.float32) density_grid = np.full((num_cells, num_steps), np.nan, dtype=np.float32)
occupancy_grid = np.full((num_edges, num_steps), np.nan, dtype=np.float32)
for row in edge_rows: for row in detector_rows:
row_idx = edge_to_row[row["edge_id"]] row_idx = cell_to_row[str(row["cell_id"])]
col_idx = step_to_col[int(row["step"])] col_idx = step_to_col[int(row["step"])]
action_grid[row_idx, col_idx] = _safe_float(row["action_speed_kmh"])
speed_grid[row_idx, col_idx] = _safe_float(row["measured_speed_kmh"]) speed_grid[row_idx, col_idx] = _safe_float(row["measured_speed_kmh"])
occupancy_grid[row_idx, col_idx] = _safe_float(row["occupancy"]) density_grid[row_idx, col_idx] = _safe_float(row["density_vehpkm"])
fig, axes = plt.subplots(1, 3, figsize=(18, 7), sharex=True, sharey=True) fig, axes = plt.subplots(1, 2, figsize=(18, 7), sharex=True, sharey=True)
plots = [ plots = [
(action_grid, "viridis", "Applied VSL (km/h)"),
(speed_grid, "RdYlGn", "Measured Speed (km/h)"), (speed_grid, "RdYlGn", "Measured Speed (km/h)"),
(occupancy_grid, "magma", "Occupancy (%)"), (density_grid, "YlOrRd", "Density (veh/km)"),
] ]
for ax, (grid, cmap, title) in zip(axes, plots): for ax, (grid, cmap, title) in zip(axes, plots):
@ -119,9 +261,11 @@ def _plot_episode_heatmap(path: str, edge_rows: Sequence[Dict], edge_ids: Sequen
) )
ax.set_title(f"{title_prefix} {title}") ax.set_title(f"{title_prefix} {title}")
ax.set_xlabel("Decision Step") ax.set_xlabel("Decision Step")
ax.set_ylabel("Controlled Edge") ax.set_ylabel("Detector Cell")
ax.set_yticks(np.arange(num_edges)) tick_step = max(num_cells // 12, 1)
ax.set_yticklabels(edge_ids) tick_idx = np.arange(0, num_cells, tick_step)
ax.set_yticks(tick_idx)
ax.set_yticklabels([ordered_cells[idx] for idx in tick_idx], fontsize=8)
plt.colorbar(image, ax=ax, fraction=0.046, pad=0.04) plt.colorbar(image, ax=ax, fraction=0.046, pad=0.04)
plt.tight_layout() plt.tight_layout()
@ -137,15 +281,29 @@ def _write_summary(path: str, summary: Optional[Dict]):
json.dump(summary, f, ensure_ascii=False, indent=2) json.dump(summary, f, ensure_ascii=False, indent=2)
def _save_episode_bundle(bundle_dir: str, episode: int, step_rows: Sequence[Dict], edge_rows: Sequence[Dict], edge_ids: Sequence[str], summary: Optional[Dict]): def _save_episode_bundle(
bundle_dir: str,
episode: int,
step_rows: Sequence[Dict],
edge_rows: Sequence[Dict],
edge_ids: Sequence[str],
detector_rows: Sequence[Dict],
summary: Optional[Dict],
):
step_csv_path = os.path.join(bundle_dir, "step_metrics.csv") step_csv_path = os.path.join(bundle_dir, "step_metrics.csv")
edge_csv_path = os.path.join(bundle_dir, "edge_metrics.csv") edge_csv_path = os.path.join(bundle_dir, "edge_metrics.csv")
detector_csv_path = os.path.join(bundle_dir, "detector_metrics.csv")
heatmap_path = os.path.join(bundle_dir, "episode_heatmap.png") heatmap_path = os.path.join(bundle_dir, "episode_heatmap.png")
summary_path = os.path.join(bundle_dir, "summary.json") summary_path = os.path.join(bundle_dir, "summary.json")
_write_csv(step_csv_path, step_rows, fieldnames=list(step_rows[0].keys()) if step_rows else ["episode", "step"]) _write_csv(step_csv_path, step_rows, fieldnames=list(step_rows[0].keys()) if step_rows else ["episode", "step"])
_write_csv(edge_csv_path, edge_rows, fieldnames=list(edge_rows[0].keys()) if edge_rows else ["episode", "step", "edge_index"]) _write_csv(edge_csv_path, edge_rows, fieldnames=list(edge_rows[0].keys()) if edge_rows else ["episode", "step", "edge_index"])
_plot_episode_heatmap(heatmap_path, edge_rows, edge_ids, title_prefix=f"Episode {episode}") _write_csv(
detector_csv_path,
detector_rows,
fieldnames=list(detector_rows[0].keys()) if detector_rows else ["episode", "step", "cell_id"],
)
_plot_episode_heatmap(heatmap_path, detector_rows, title_prefix=f"Episode {episode}")
_write_summary(summary_path, summary) _write_summary(summary_path, summary)
@ -161,6 +319,7 @@ def save_training_episode_artifacts(
return return
step_rows, edge_rows, edge_ids = _normalize_step_rows(episode, episode_metrics) step_rows, edge_rows, edge_ids = _normalize_step_rows(episode, episode_metrics)
detector_rows = _build_detector_rows_from_xml(log_dir, episode)
if control_edges and len(control_edges) >= len(edge_ids): if control_edges and len(control_edges) >= len(edge_ids):
edge_ids = list(control_edges[: len(edge_ids)]) edge_ids = list(control_edges[: len(edge_ids)])
for row in edge_rows: for row in edge_rows:
@ -168,8 +327,8 @@ def save_training_episode_artifacts(
artifacts_root = os.path.join(log_dir, "episode_artifacts") artifacts_root = os.path.join(log_dir, "episode_artifacts")
latest_dir = os.path.join(artifacts_root, "latest") latest_dir = os.path.join(artifacts_root, "latest")
_save_episode_bundle(latest_dir, episode, step_rows, edge_rows, edge_ids, summary) _save_episode_bundle(latest_dir, episode, step_rows, edge_rows, edge_ids, detector_rows, summary)
if snapshot_interval > 0 and episode % snapshot_interval == 0: if snapshot_interval > 0 and episode % snapshot_interval == 0:
snapshot_dir = os.path.join(artifacts_root, "snapshots", f"episode_{episode:04d}") snapshot_dir = os.path.join(artifacts_root, "snapshots", f"episode_{episode:04d}")
_save_episode_bundle(snapshot_dir, episode, step_rows, edge_rows, edge_ids, summary) _save_episode_bundle(snapshot_dir, episode, step_rows, edge_rows, edge_ids, detector_rows, summary)