更新热力图绘制
This commit is contained in:
parent
05ef01d93f
commit
288f67ae50
|
|
@ -17,6 +17,7 @@ sumo:
|
|||
|
||||
environment:
|
||||
control_interval: 60
|
||||
warmup_time: 900
|
||||
control_segment_length_m: 1000
|
||||
detector_spacing_m: 100
|
||||
detector_start_offset_m: 50
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ class SUMOEdgeVSLEnvironment:
|
|||
|
||||
self.control_interval = env_cfg["control_interval"]
|
||||
self.steps_per_action = int(self.control_interval / self.step_length)
|
||||
self.warmup_time = 900
|
||||
self.warmup_time = int(env_cfg.get("warmup_time", 900))
|
||||
self.episode_length = int(
|
||||
(self.end_time - self.begin_time - self.warmup_time) / self.control_interval
|
||||
)
|
||||
|
|
@ -301,7 +301,7 @@ class SUMOEdgeVSLEnvironment:
|
|||
self._interval_mainline_travel_times = []
|
||||
self._start_sumo(seed=seed)
|
||||
|
||||
warmup_steps = int(900 / self.control_interval)
|
||||
warmup_steps = int(self.warmup_time / self.control_interval)
|
||||
for _ in range(warmup_steps):
|
||||
for _ in range(self.steps_per_action):
|
||||
traci.simulationStep()
|
||||
|
|
|
|||
|
|
@ -935,14 +935,12 @@ def plot_model_heatmaps(edge_df: pd.DataFrame, detector_df: pd.DataFrame, output
|
|||
occ_grid = detector_model_df.pivot(index="cell_id", columns="step", values="occupancy").reindex(ordered_cell_ids).values
|
||||
|
||||
edge_order = (
|
||||
edge_model_df[["edge_index", "edge_id"]]
|
||||
edge_model_df[edge_model_df["action_applied"].astype(bool)][["edge_index", "edge_id"]]
|
||||
.drop_duplicates()
|
||||
.sort_values("edge_index")
|
||||
)
|
||||
ordered_edge_ids = edge_order["edge_id"].tolist()
|
||||
action_plot_df = edge_model_df.copy()
|
||||
if "action_applied" in action_plot_df.columns:
|
||||
action_plot_df.loc[~action_plot_df["action_applied"].astype(bool), "action_speed_kmh"] = np.nan
|
||||
action_plot_df = edge_model_df[edge_model_df["action_applied"].astype(bool)].copy()
|
||||
action_grid = (
|
||||
action_plot_df.pivot(index="edge_id", columns="step", values="action_speed_kmh")
|
||||
.reindex(ordered_edge_ids)
|
||||
|
|
@ -956,18 +954,23 @@ def plot_model_heatmaps(edge_df: pd.DataFrame, detector_df: pd.DataFrame, output
|
|||
f"{MODEL_LABELS[model_name]} Measured Speed (km/h)",
|
||||
"Detector Cell (bottom=upstream, top=downstream)",
|
||||
),
|
||||
]
|
||||
if ordered_edge_ids:
|
||||
panels.append(
|
||||
build_action_panel(
|
||||
action_grid,
|
||||
ordered_edge_ids,
|
||||
f"{MODEL_LABELS[model_name]} Applied VSL (km/h)",
|
||||
),
|
||||
)
|
||||
)
|
||||
panels.append(
|
||||
build_occupancy_panel(
|
||||
occ_grid,
|
||||
ordered_cell_ids,
|
||||
f"{MODEL_LABELS[model_name]} Occupancy (%)",
|
||||
"Detector Cell (bottom=upstream, top=downstream)",
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
save_heatmap_panels(
|
||||
os.path.join(heatmap_dir, f"{model_name}_heatmaps.png"),
|
||||
panels,
|
||||
|
|
|
|||
|
|
@ -220,6 +220,39 @@ def _build_detector_rows_from_xml(log_dir: str, episode: int) -> List[Dict]:
|
|||
return detector_rows
|
||||
|
||||
|
||||
def _align_detector_rows_to_decision_steps(
|
||||
detector_rows: Sequence[Dict],
|
||||
step_rows: Sequence[Dict],
|
||||
) -> List[Dict]:
|
||||
if not detector_rows or not step_rows:
|
||||
return list(detector_rows)
|
||||
|
||||
detector_step_values = sorted({int(row["step"]) for row in detector_rows})
|
||||
decision_step_values = sorted({int(row["step"]) for row in step_rows})
|
||||
if not detector_step_values or not decision_step_values:
|
||||
return list(detector_rows)
|
||||
|
||||
excess_steps = len(detector_step_values) - len(decision_step_values)
|
||||
if excess_steps <= 0:
|
||||
return list(detector_rows)
|
||||
|
||||
kept_detector_steps = detector_step_values[excess_steps:]
|
||||
step_mapping = {
|
||||
old_step: new_step
|
||||
for old_step, new_step in zip(kept_detector_steps, decision_step_values)
|
||||
}
|
||||
|
||||
aligned_rows: List[Dict] = []
|
||||
for row in detector_rows:
|
||||
old_step = int(row["step"])
|
||||
if old_step not in step_mapping:
|
||||
continue
|
||||
aligned_row = dict(row)
|
||||
aligned_row["step"] = step_mapping[old_step]
|
||||
aligned_rows.append(aligned_row)
|
||||
return aligned_rows
|
||||
|
||||
|
||||
def _plot_episode_heatmap(
|
||||
path: str,
|
||||
edge_rows: Sequence[Dict],
|
||||
|
|
@ -246,7 +279,21 @@ def _plot_episode_heatmap(
|
|||
step_to_col = {step: idx for idx, step in enumerate(step_values)}
|
||||
panels = []
|
||||
if has_action:
|
||||
ordered_edge_ids = list(edge_ids)
|
||||
active_edge_ids = {
|
||||
str(row["edge_id"])
|
||||
for row in edge_rows
|
||||
if bool(row.get("action_applied", True))
|
||||
}
|
||||
ordered_edge_ids = [edge_id for edge_id in edge_ids if edge_id in active_edge_ids]
|
||||
if not ordered_edge_ids and active_edge_ids:
|
||||
ordered_edge_ids = [
|
||||
str(row["edge_id"])
|
||||
for row in sorted(edge_rows, key=lambda item: int(item["edge_index"]))
|
||||
if bool(row.get("action_applied", True))
|
||||
]
|
||||
ordered_edge_ids = list(dict.fromkeys(ordered_edge_ids))
|
||||
if not ordered_edge_ids:
|
||||
ordered_edge_ids = []
|
||||
edge_to_row = {edge_id: idx for idx, edge_id in enumerate(ordered_edge_ids)}
|
||||
action_grid = np.full((len(ordered_edge_ids), num_steps), np.nan, dtype=np.float32)
|
||||
for row in edge_rows:
|
||||
|
|
@ -258,6 +305,7 @@ def _plot_episode_heatmap(
|
|||
row_idx = edge_to_row[edge_id]
|
||||
col_idx = step_to_col[int(row["step"])]
|
||||
action_grid[row_idx, col_idx] = _safe_float(row["action_speed_kmh"])
|
||||
if ordered_edge_ids:
|
||||
panels.append(build_action_panel(action_grid, ordered_edge_ids, f"{title_prefix} Applied VSL (km/h)"))
|
||||
|
||||
if has_detector:
|
||||
|
|
@ -350,6 +398,7 @@ def save_training_episode_artifacts(
|
|||
|
||||
step_rows, edge_rows, edge_ids = _normalize_step_rows(episode, episode_metrics)
|
||||
detector_rows = _build_detector_rows_from_xml(log_dir, episode)
|
||||
detector_rows = _align_detector_rows_to_decision_steps(detector_rows, step_rows)
|
||||
if control_edges and len(control_edges) >= len(edge_ids):
|
||||
edge_ids = list(control_edges[: len(edge_ids)])
|
||||
for row in edge_rows:
|
||||
|
|
|
|||
Loading…
Reference in New Issue