限速策略热力图绘制更新

This commit is contained in:
Zihan Ye 2026-04-17 07:30:24 +08:00
parent 4edf0e5310
commit 70de65d973
3 changed files with 21 additions and 18 deletions

View File

@ -935,12 +935,12 @@ def plot_model_heatmaps(edge_df: pd.DataFrame, detector_df: pd.DataFrame, output
occ_grid = detector_model_df.pivot(index="cell_id", columns="step", values="occupancy").reindex(ordered_cell_ids).values
edge_order = (
edge_model_df[edge_model_df["action_applied"].astype(bool)][["edge_index", "edge_id"]]
edge_model_df[["edge_index", "edge_id"]]
.drop_duplicates()
.sort_values("edge_index")
)
ordered_edge_ids = edge_order["edge_id"].tolist()
action_plot_df = edge_model_df[edge_model_df["action_applied"].astype(bool)].copy()
action_plot_df = edge_model_df.copy()
action_grid = (
action_plot_df.pivot(index="edge_id", columns="step", values="action_speed_kmh")
.reindex(ordered_edge_ids)
@ -960,7 +960,8 @@ def plot_model_heatmaps(edge_df: pd.DataFrame, detector_df: pd.DataFrame, output
build_action_panel(
action_grid,
ordered_edge_ids,
f"{MODEL_LABELS[model_name]} Applied VSL (km/h)",
f"{MODEL_LABELS[model_name]} Segment Speed Limit (km/h)",
ylabel="Corridor Segment",
)
)
panels.append(

View File

@ -279,34 +279,31 @@ def _plot_episode_heatmap(
step_to_col = {step: idx for idx, step in enumerate(step_values)}
panels = []
if has_action:
active_edge_ids = {
str(row["edge_id"])
for row in edge_rows
if bool(row.get("action_applied", True))
}
ordered_edge_ids = [edge_id for edge_id in edge_ids if edge_id in active_edge_ids]
if not ordered_edge_ids and active_edge_ids:
ordered_edge_ids = list(edge_ids)
if not ordered_edge_ids:
ordered_edge_ids = [
str(row["edge_id"])
for row in sorted(edge_rows, key=lambda item: int(item["edge_index"]))
if bool(row.get("action_applied", True))
]
ordered_edge_ids = list(dict.fromkeys(ordered_edge_ids))
if not ordered_edge_ids:
ordered_edge_ids = []
edge_to_row = {edge_id: idx for idx, edge_id in enumerate(ordered_edge_ids)}
action_grid = np.full((len(ordered_edge_ids), num_steps), np.nan, dtype=np.float32)
for row in edge_rows:
edge_id = str(row["edge_id"])
if edge_id not in edge_to_row:
continue
if not bool(row.get("action_applied", True)):
continue
row_idx = edge_to_row[edge_id]
col_idx = step_to_col[int(row["step"])]
action_grid[row_idx, col_idx] = _safe_float(row["action_speed_kmh"])
if ordered_edge_ids:
panels.append(build_action_panel(action_grid, ordered_edge_ids, f"{title_prefix} Applied VSL (km/h)"))
panels.append(
build_action_panel(
action_grid,
ordered_edge_ids,
f"{title_prefix} Segment Speed Limit (km/h)",
ylabel="Corridor Segment",
)
)
if has_detector:
ordered_cells = []

View File

@ -36,13 +36,18 @@ def get_action_norm():
return action_cmap, colors.BoundaryNorm(get_action_boundaries(), action_cmap.N, clip=True)
def build_action_panel(grid, row_labels: Sequence[str], title: str) -> Dict:
def build_action_panel(
grid,
row_labels: Sequence[str],
title: str,
ylabel: str = "Corridor Segment",
) -> Dict:
action_cmap, action_norm = get_action_norm()
return {
"grid": grid,
"row_labels": row_labels,
"title": title,
"ylabel": "Controlled Edge",
"ylabel": ylabel,
"cmap": action_cmap,
"norm": action_norm,
"width_ratio": 0.8,