From 70de65d97315c81d6a45a311385cc883730c16a8 Mon Sep 17 00:00:00 2001 From: Maple-YZ Date: Fri, 17 Apr 2026 07:30:24 +0800 Subject: [PATCH] =?UTF-8?q?=E9=99=90=E9=80=9F=E7=AD=96=E7=95=A5=E7=83=AD?= =?UTF-8?q?=E5=8A=9B=E5=9B=BE=E7=BB=98=E5=88=B6=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/evaluate_models.py | 7 ++++--- utils/episode_artifacts.py | 23 ++++++++++------------- utils/heatmap_plotting.py | 9 +++++++-- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/scripts/evaluate_models.py b/scripts/evaluate_models.py index 36e89ab..0fa71a4 100644 --- a/scripts/evaluate_models.py +++ b/scripts/evaluate_models.py @@ -935,12 +935,12 @@ def plot_model_heatmaps(edge_df: pd.DataFrame, detector_df: pd.DataFrame, output occ_grid = detector_model_df.pivot(index="cell_id", columns="step", values="occupancy").reindex(ordered_cell_ids).values edge_order = ( - edge_model_df[edge_model_df["action_applied"].astype(bool)][["edge_index", "edge_id"]] + edge_model_df[["edge_index", "edge_id"]] .drop_duplicates() .sort_values("edge_index") ) ordered_edge_ids = edge_order["edge_id"].tolist() - action_plot_df = edge_model_df[edge_model_df["action_applied"].astype(bool)].copy() + action_plot_df = edge_model_df.copy() action_grid = ( action_plot_df.pivot(index="edge_id", columns="step", values="action_speed_kmh") .reindex(ordered_edge_ids) @@ -960,7 +960,8 @@ def plot_model_heatmaps(edge_df: pd.DataFrame, detector_df: pd.DataFrame, output build_action_panel( action_grid, ordered_edge_ids, - f"{MODEL_LABELS[model_name]} Applied VSL (km/h)", + f"{MODEL_LABELS[model_name]} Segment Speed Limit (km/h)", + ylabel="Corridor Segment", ) ) panels.append( diff --git a/utils/episode_artifacts.py b/utils/episode_artifacts.py index fc71346..1f0a2c4 100644 --- a/utils/episode_artifacts.py +++ b/utils/episode_artifacts.py @@ -279,34 +279,31 @@ def _plot_episode_heatmap( step_to_col = {step: idx for idx, step in enumerate(step_values)} panels = [] if has_action: - active_edge_ids = { - str(row["edge_id"]) - for row in edge_rows - if bool(row.get("action_applied", True)) - } - ordered_edge_ids = [edge_id for edge_id in edge_ids if edge_id in active_edge_ids] - if not ordered_edge_ids and active_edge_ids: + ordered_edge_ids = list(edge_ids) + if not ordered_edge_ids: ordered_edge_ids = [ str(row["edge_id"]) for row in sorted(edge_rows, key=lambda item: int(item["edge_index"])) - if bool(row.get("action_applied", True)) ] ordered_edge_ids = list(dict.fromkeys(ordered_edge_ids)) - if not ordered_edge_ids: - ordered_edge_ids = [] edge_to_row = {edge_id: idx for idx, edge_id in enumerate(ordered_edge_ids)} action_grid = np.full((len(ordered_edge_ids), num_steps), np.nan, dtype=np.float32) for row in edge_rows: edge_id = str(row["edge_id"]) if edge_id not in edge_to_row: continue - if not bool(row.get("action_applied", True)): - continue row_idx = edge_to_row[edge_id] col_idx = step_to_col[int(row["step"])] action_grid[row_idx, col_idx] = _safe_float(row["action_speed_kmh"]) if ordered_edge_ids: - panels.append(build_action_panel(action_grid, ordered_edge_ids, f"{title_prefix} Applied VSL (km/h)")) + panels.append( + build_action_panel( + action_grid, + ordered_edge_ids, + f"{title_prefix} Segment Speed Limit (km/h)", + ylabel="Corridor Segment", + ) + ) if has_detector: ordered_cells = [] diff --git a/utils/heatmap_plotting.py b/utils/heatmap_plotting.py index 0ffd238..090b5ae 100644 --- a/utils/heatmap_plotting.py +++ b/utils/heatmap_plotting.py @@ -36,13 +36,18 @@ def get_action_norm(): return action_cmap, colors.BoundaryNorm(get_action_boundaries(), action_cmap.N, clip=True) -def build_action_panel(grid, row_labels: Sequence[str], title: str) -> Dict: +def build_action_panel( + grid, + row_labels: Sequence[str], + title: str, + ylabel: str = "Corridor Segment", +) -> Dict: action_cmap, action_norm = get_action_norm() return { "grid": grid, "row_labels": row_labels, "title": title, - "ylabel": "Controlled Edge", + "ylabel": ylabel, "cmap": action_cmap, "norm": action_norm, "width_ratio": 0.8,