统一热力图绘制
This commit is contained in:
parent
2f594b0eb0
commit
3d1782c348
|
|
@ -19,7 +19,6 @@ import yaml
|
||||||
|
|
||||||
matplotlib.use("Agg")
|
matplotlib.use("Agg")
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from matplotlib import colors
|
|
||||||
|
|
||||||
from agents.appo_agent import APPOAgent
|
from agents.appo_agent import APPOAgent
|
||||||
from agents.dcmappo_agent import DCMAPPOAgent
|
from agents.dcmappo_agent import DCMAPPOAgent
|
||||||
|
|
@ -39,15 +38,18 @@ from agents.td3_agent import TD3Agent
|
||||||
from envs.edge_vsl_env import SUMOEdgeVSLEnvironment
|
from envs.edge_vsl_env import SUMOEdgeVSLEnvironment
|
||||||
from envs.reward_system import REWARD_COMPONENT_COLUMNS, REWARD_COMPONENT_LABELS
|
from envs.reward_system import REWARD_COMPONENT_COLUMNS, REWARD_COMPONENT_LABELS
|
||||||
from utils.config import get_agent_config
|
from utils.config import get_agent_config
|
||||||
|
from utils.heatmap_plotting import (
|
||||||
|
build_action_panel,
|
||||||
|
build_occupancy_panel,
|
||||||
|
build_speed_panel,
|
||||||
|
save_heatmap_panels,
|
||||||
|
)
|
||||||
from utils.run_dirs import find_shared_config_path, resolve_checkpoint_root
|
from utils.run_dirs import find_shared_config_path, resolve_checkpoint_root
|
||||||
|
|
||||||
|
|
||||||
MODEL_ORDER = ["ppo", "gpro", "appo", "mappo", "tcamappo", "dcmappo", "dqn", "madqn", "ddqn", "qmix", "dcqmix", "ddpg", "sac", "td3", "sctd3"]
|
MODEL_ORDER = ["ppo", "gpro", "appo", "mappo", "tcamappo", "dcmappo", "dqn", "madqn", "ddqn", "qmix", "dcqmix", "ddpg", "sac", "td3", "sctd3"]
|
||||||
BASELINE_NAME = "no_control"
|
BASELINE_NAME = "no_control"
|
||||||
EVAL_ORDER = [BASELINE_NAME] + MODEL_ORDER
|
EVAL_ORDER = [BASELINE_NAME] + MODEL_ORDER
|
||||||
HEATMAP_SPEED_RANGE_KMH = (40.0, 110.0)
|
|
||||||
HEATMAP_OCCUPANCY_RANGE = (0.0, 35.0)
|
|
||||||
HEATMAP_ACTION_LEVELS_KMH = [40.0, 60.0, 80.0, 100.0, 110.0]
|
|
||||||
MODEL_LABELS = {
|
MODEL_LABELS = {
|
||||||
BASELINE_NAME: "NO_CONTROL",
|
BASELINE_NAME: "NO_CONTROL",
|
||||||
"ppo": "PPO",
|
"ppo": "PPO",
|
||||||
|
|
@ -917,19 +919,6 @@ def plot_summary_bars(summary_df: pd.DataFrame, output_dir: str):
|
||||||
def plot_model_heatmaps(edge_df: pd.DataFrame, detector_df: pd.DataFrame, output_dir: str):
|
def plot_model_heatmaps(edge_df: pd.DataFrame, detector_df: pd.DataFrame, output_dir: str):
|
||||||
heatmap_dir = os.path.join(output_dir, "heatmaps")
|
heatmap_dir = os.path.join(output_dir, "heatmaps")
|
||||||
os.makedirs(heatmap_dir, exist_ok=True)
|
os.makedirs(heatmap_dir, exist_ok=True)
|
||||||
speed_cmap = plt.get_cmap("RdYlGn").copy()
|
|
||||||
speed_cmap.set_bad(color="#d9d9d9")
|
|
||||||
occ_cmap = plt.get_cmap("magma").copy()
|
|
||||||
occ_cmap.set_bad(color="#d9d9d9")
|
|
||||||
action_cmap = plt.get_cmap("viridis", len(HEATMAP_ACTION_LEVELS_KMH)).copy()
|
|
||||||
action_cmap.set_bad(color="#d9d9d9")
|
|
||||||
action_boundaries = [HEATMAP_ACTION_LEVELS_KMH[0] - 10.0]
|
|
||||||
action_boundaries.extend(
|
|
||||||
(left + right) / 2.0
|
|
||||||
for left, right in zip(HEATMAP_ACTION_LEVELS_KMH[:-1], HEATMAP_ACTION_LEVELS_KMH[1:])
|
|
||||||
)
|
|
||||||
action_boundaries.append(HEATMAP_ACTION_LEVELS_KMH[-1] + 10.0)
|
|
||||||
action_norm = colors.BoundaryNorm(action_boundaries, action_cmap.N, clip=True)
|
|
||||||
|
|
||||||
for model_name in EVAL_ORDER:
|
for model_name in EVAL_ORDER:
|
||||||
detector_model_df = detector_df[detector_df["model"] == model_name]
|
detector_model_df = detector_df[detector_df["model"] == model_name]
|
||||||
|
|
@ -951,70 +940,39 @@ def plot_model_heatmaps(edge_df: pd.DataFrame, detector_df: pd.DataFrame, output
|
||||||
.sort_values("edge_index")
|
.sort_values("edge_index")
|
||||||
)
|
)
|
||||||
ordered_edge_ids = edge_order["edge_id"].tolist()
|
ordered_edge_ids = edge_order["edge_id"].tolist()
|
||||||
action_grid = edge_model_df.pivot(index="edge_id", columns="step", values="action_speed_kmh").reindex(ordered_edge_ids).values
|
action_plot_df = edge_model_df.copy()
|
||||||
|
if "action_applied" in action_plot_df.columns:
|
||||||
fig, axes = plt.subplots(1, 3, figsize=(21, 7), gridspec_kw={"width_ratios": [1.3, 0.8, 1.3]})
|
action_plot_df.loc[~action_plot_df["action_applied"].astype(bool), "action_speed_kmh"] = np.nan
|
||||||
|
action_grid = (
|
||||||
speed_im = axes[0].imshow(
|
action_plot_df.pivot(index="edge_id", columns="step", values="action_speed_kmh")
|
||||||
np.ma.masked_invalid(speed_grid),
|
.reindex(ordered_edge_ids)
|
||||||
aspect="auto",
|
.values
|
||||||
origin="lower",
|
|
||||||
cmap=speed_cmap,
|
|
||||||
vmin=HEATMAP_SPEED_RANGE_KMH[0],
|
|
||||||
vmax=HEATMAP_SPEED_RANGE_KMH[1],
|
|
||||||
)
|
|
||||||
axes[0].set_title(f"{MODEL_LABELS[model_name]} Measured Speed (km/h)")
|
|
||||||
axes[0].set_xlabel("Step")
|
|
||||||
axes[0].set_ylabel("Detector Cell (bottom=upstream, top=downstream)")
|
|
||||||
plt.colorbar(
|
|
||||||
speed_im,
|
|
||||||
ax=axes[0],
|
|
||||||
fraction=0.046,
|
|
||||||
pad=0.04,
|
|
||||||
ticks=HEATMAP_ACTION_LEVELS_KMH,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
action_im = axes[1].imshow(
|
panels = [
|
||||||
np.ma.masked_invalid(action_grid),
|
build_speed_panel(
|
||||||
aspect="auto",
|
speed_grid,
|
||||||
origin="lower",
|
ordered_cell_ids,
|
||||||
cmap=action_cmap,
|
f"{MODEL_LABELS[model_name]} Measured Speed (km/h)",
|
||||||
norm=action_norm,
|
"Detector Cell (bottom=upstream, top=downstream)",
|
||||||
|
),
|
||||||
|
build_action_panel(
|
||||||
|
action_grid,
|
||||||
|
ordered_edge_ids,
|
||||||
|
f"{MODEL_LABELS[model_name]} Applied VSL (km/h)",
|
||||||
|
),
|
||||||
|
build_occupancy_panel(
|
||||||
|
occ_grid,
|
||||||
|
ordered_cell_ids,
|
||||||
|
f"{MODEL_LABELS[model_name]} Occupancy (%)",
|
||||||
|
"Detector Cell (bottom=upstream, top=downstream)",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
save_heatmap_panels(
|
||||||
|
os.path.join(heatmap_dir, f"{model_name}_heatmaps.png"),
|
||||||
|
panels,
|
||||||
|
xlabel="Decision Step",
|
||||||
)
|
)
|
||||||
axes[1].set_title(f"{MODEL_LABELS[model_name]} Applied VSL (km/h)")
|
|
||||||
axes[1].set_xlabel("Step")
|
|
||||||
axes[1].set_ylabel("Controlled Edge")
|
|
||||||
plt.colorbar(
|
|
||||||
action_im,
|
|
||||||
ax=axes[1],
|
|
||||||
fraction=0.046,
|
|
||||||
pad=0.04,
|
|
||||||
ticks=HEATMAP_ACTION_LEVELS_KMH,
|
|
||||||
boundaries=action_boundaries,
|
|
||||||
)
|
|
||||||
|
|
||||||
occ_im = axes[2].imshow(
|
|
||||||
np.ma.masked_invalid(occ_grid),
|
|
||||||
aspect="auto",
|
|
||||||
origin="lower",
|
|
||||||
cmap=occ_cmap,
|
|
||||||
vmin=HEATMAP_OCCUPANCY_RANGE[0],
|
|
||||||
vmax=HEATMAP_OCCUPANCY_RANGE[1],
|
|
||||||
)
|
|
||||||
axes[2].set_title(f"{MODEL_LABELS[model_name]} Occupancy (%)")
|
|
||||||
axes[2].set_xlabel("Step")
|
|
||||||
axes[2].set_ylabel("Detector Cell (bottom=upstream, top=downstream)")
|
|
||||||
plt.colorbar(
|
|
||||||
occ_im,
|
|
||||||
ax=axes[2],
|
|
||||||
fraction=0.046,
|
|
||||||
pad=0.04,
|
|
||||||
ticks=np.arange(0.0, 36.0, 5.0),
|
|
||||||
)
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
plt.savefig(os.path.join(heatmap_dir, f"{model_name}_heatmaps.png"), dpi=160)
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
|
|
||||||
def _format_metric(value: float, fmt: str) -> str:
|
def _format_metric(value: float, fmt: str) -> str:
|
||||||
|
|
|
||||||
|
|
@ -9,9 +9,14 @@ import matplotlib
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
matplotlib.use("Agg")
|
matplotlib.use("Agg")
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
from envs.reward_system import REWARD_COMPONENT_COLUMNS
|
from envs.reward_system import REWARD_COMPONENT_COLUMNS
|
||||||
|
from utils.heatmap_plotting import (
|
||||||
|
build_action_panel,
|
||||||
|
build_density_panel,
|
||||||
|
build_speed_panel,
|
||||||
|
save_heatmap_panels,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _safe_float(value, default: float = float("nan")) -> float:
|
def _safe_float(value, default: float = float("nan")) -> float:
|
||||||
|
|
@ -50,7 +55,8 @@ def _normalize_step_rows(episode: int, episode_metrics: Sequence[Dict]) -> Tuple
|
||||||
action_speeds = list(info.get("edge_speeds_kmh", []))
|
action_speeds = list(info.get("edge_speeds_kmh", []))
|
||||||
measured_speeds_ms = list(info.get("edge_speeds_ms", []))
|
measured_speeds_ms = list(info.get("edge_speeds_ms", []))
|
||||||
occupancies = list(info.get("edge_occupancies", []))
|
occupancies = list(info.get("edge_occupancies", []))
|
||||||
edge_count = max(len(action_speeds), len(measured_speeds_ms), len(occupancies))
|
action_applied_mask = list(info.get("action_applied_mask", []))
|
||||||
|
edge_count = max(len(action_speeds), len(measured_speeds_ms), len(occupancies), len(action_applied_mask))
|
||||||
|
|
||||||
if not edge_ids:
|
if not edge_ids:
|
||||||
edge_ids = [f"edge_{idx:02d}" for idx in range(edge_count)]
|
edge_ids = [f"edge_{idx:02d}" for idx in range(edge_count)]
|
||||||
|
|
@ -63,6 +69,7 @@ def _normalize_step_rows(episode: int, episode_metrics: Sequence[Dict]) -> Tuple
|
||||||
"edge_index": edge_idx,
|
"edge_index": edge_idx,
|
||||||
"edge_id": edge_ids[edge_idx],
|
"edge_id": edge_ids[edge_idx],
|
||||||
"action_speed_kmh": _safe_float(action_speeds[edge_idx] if edge_idx < len(action_speeds) else np.nan),
|
"action_speed_kmh": _safe_float(action_speeds[edge_idx] if edge_idx < len(action_speeds) else np.nan),
|
||||||
|
"action_applied": bool(action_applied_mask[edge_idx]) if edge_idx < len(action_applied_mask) else True,
|
||||||
"measured_speed_kmh": _safe_float(
|
"measured_speed_kmh": _safe_float(
|
||||||
measured_speeds_ms[edge_idx] * 3.6 if edge_idx < len(measured_speeds_ms) else np.nan
|
measured_speeds_ms[edge_idx] * 3.6 if edge_idx < len(measured_speeds_ms) else np.nan
|
||||||
),
|
),
|
||||||
|
|
@ -213,64 +220,87 @@ def _build_detector_rows_from_xml(log_dir: str, episode: int) -> List[Dict]:
|
||||||
return detector_rows
|
return detector_rows
|
||||||
|
|
||||||
|
|
||||||
def _plot_episode_heatmap(path: str, detector_rows: Sequence[Dict], title_prefix: str):
|
def _plot_episode_heatmap(
|
||||||
if not detector_rows:
|
path: str,
|
||||||
|
edge_rows: Sequence[Dict],
|
||||||
|
edge_ids: Sequence[str],
|
||||||
|
detector_rows: Sequence[Dict],
|
||||||
|
title_prefix: str,
|
||||||
|
):
|
||||||
|
has_action = bool(edge_rows and edge_ids)
|
||||||
|
has_detector = bool(detector_rows)
|
||||||
|
if not has_action and not has_detector:
|
||||||
return
|
return
|
||||||
|
|
||||||
step_values = sorted({int(row["step"]) for row in detector_rows})
|
step_values = sorted(
|
||||||
ordered_cells = []
|
{
|
||||||
seen_cells = set()
|
int(row["step"])
|
||||||
for row in sorted(
|
for row in list(edge_rows) + list(detector_rows)
|
||||||
detector_rows,
|
if row.get("step") is not None
|
||||||
key=lambda item: (int(item["cell_order"]), str(item["cell_id"])),
|
}
|
||||||
):
|
)
|
||||||
cell_id = str(row["cell_id"])
|
if not step_values:
|
||||||
if cell_id in seen_cells:
|
return
|
||||||
continue
|
|
||||||
seen_cells.add(cell_id)
|
|
||||||
ordered_cells.append(cell_id)
|
|
||||||
|
|
||||||
num_cells = len(ordered_cells)
|
|
||||||
num_steps = len(step_values)
|
num_steps = len(step_values)
|
||||||
step_to_col = {step: idx for idx, step in enumerate(step_values)}
|
step_to_col = {step: idx for idx, step in enumerate(step_values)}
|
||||||
cell_to_row = {cell_id: idx for idx, cell_id in enumerate(ordered_cells)}
|
panels = []
|
||||||
|
if has_action:
|
||||||
|
ordered_edge_ids = list(edge_ids)
|
||||||
|
edge_to_row = {edge_id: idx for idx, edge_id in enumerate(ordered_edge_ids)}
|
||||||
|
action_grid = np.full((len(ordered_edge_ids), num_steps), np.nan, dtype=np.float32)
|
||||||
|
for row in edge_rows:
|
||||||
|
edge_id = str(row["edge_id"])
|
||||||
|
if edge_id not in edge_to_row:
|
||||||
|
continue
|
||||||
|
if not bool(row.get("action_applied", True)):
|
||||||
|
continue
|
||||||
|
row_idx = edge_to_row[edge_id]
|
||||||
|
col_idx = step_to_col[int(row["step"])]
|
||||||
|
action_grid[row_idx, col_idx] = _safe_float(row["action_speed_kmh"])
|
||||||
|
panels.append(build_action_panel(action_grid, ordered_edge_ids, f"{title_prefix} Applied VSL (km/h)"))
|
||||||
|
|
||||||
speed_grid = np.full((num_cells, num_steps), np.nan, dtype=np.float32)
|
if has_detector:
|
||||||
density_grid = np.full((num_cells, num_steps), np.nan, dtype=np.float32)
|
ordered_cells = []
|
||||||
|
seen_cells = set()
|
||||||
|
for row in sorted(
|
||||||
|
detector_rows,
|
||||||
|
key=lambda item: (int(item["cell_order"]), str(item["cell_id"])),
|
||||||
|
):
|
||||||
|
cell_id = str(row["cell_id"])
|
||||||
|
if cell_id in seen_cells:
|
||||||
|
continue
|
||||||
|
seen_cells.add(cell_id)
|
||||||
|
ordered_cells.append(cell_id)
|
||||||
|
|
||||||
for row in detector_rows:
|
num_cells = len(ordered_cells)
|
||||||
row_idx = cell_to_row[str(row["cell_id"])]
|
cell_to_row = {cell_id: idx for idx, cell_id in enumerate(ordered_cells)}
|
||||||
col_idx = step_to_col[int(row["step"])]
|
speed_grid = np.full((num_cells, num_steps), np.nan, dtype=np.float32)
|
||||||
speed_grid[row_idx, col_idx] = _safe_float(row["measured_speed_kmh"])
|
density_grid = np.full((num_cells, num_steps), np.nan, dtype=np.float32)
|
||||||
density_grid[row_idx, col_idx] = _safe_float(row["density_vehpkm"])
|
|
||||||
|
|
||||||
fig, axes = plt.subplots(1, 2, figsize=(18, 7), sharex=True, sharey=True)
|
for row in detector_rows:
|
||||||
plots = [
|
row_idx = cell_to_row[str(row["cell_id"])]
|
||||||
(speed_grid, "RdYlGn", "Measured Speed (km/h)"),
|
col_idx = step_to_col[int(row["step"])]
|
||||||
(density_grid, "YlOrRd", "Density (veh/km)"),
|
speed_grid[row_idx, col_idx] = _safe_float(row["measured_speed_kmh"])
|
||||||
]
|
density_grid[row_idx, col_idx] = _safe_float(row["density_vehpkm"])
|
||||||
|
|
||||||
for ax, (grid, cmap, title) in zip(axes, plots):
|
plots = [
|
||||||
image = ax.imshow(
|
build_speed_panel(
|
||||||
np.ma.masked_invalid(grid),
|
speed_grid,
|
||||||
aspect="auto",
|
ordered_cells,
|
||||||
origin="lower",
|
f"{title_prefix} Measured Speed (km/h)",
|
||||||
cmap=cmap,
|
"Detector Cell",
|
||||||
interpolation="nearest",
|
),
|
||||||
resample=False,
|
build_density_panel(
|
||||||
)
|
density_grid,
|
||||||
ax.set_title(f"{title_prefix} {title}")
|
ordered_cells,
|
||||||
ax.set_xlabel("Decision Step")
|
f"{title_prefix} Density (veh/km)",
|
||||||
ax.set_ylabel("Detector Cell")
|
"Detector Cell",
|
||||||
tick_step = max(num_cells // 12, 1)
|
),
|
||||||
tick_idx = np.arange(0, num_cells, tick_step)
|
]
|
||||||
ax.set_yticks(tick_idx)
|
panels = plots[:1] + panels + plots[1:]
|
||||||
ax.set_yticklabels([ordered_cells[idx] for idx in tick_idx], fontsize=8)
|
|
||||||
plt.colorbar(image, ax=ax, fraction=0.046, pad=0.04)
|
|
||||||
|
|
||||||
plt.tight_layout()
|
save_heatmap_panels(path, panels, xlabel="Decision Step")
|
||||||
plt.savefig(path, dpi=160)
|
|
||||||
plt.close(fig)
|
|
||||||
|
|
||||||
|
|
||||||
def _write_summary(path: str, summary: Optional[Dict]):
|
def _write_summary(path: str, summary: Optional[Dict]):
|
||||||
|
|
@ -303,7 +333,7 @@ def _save_episode_bundle(
|
||||||
detector_rows,
|
detector_rows,
|
||||||
fieldnames=list(detector_rows[0].keys()) if detector_rows else ["episode", "step", "cell_id"],
|
fieldnames=list(detector_rows[0].keys()) if detector_rows else ["episode", "step", "cell_id"],
|
||||||
)
|
)
|
||||||
_plot_episode_heatmap(heatmap_path, detector_rows, title_prefix=f"Episode {episode}")
|
_plot_episode_heatmap(heatmap_path, edge_rows, edge_ids, detector_rows, title_prefix=f"Episode {episode}")
|
||||||
_write_summary(summary_path, summary)
|
_write_summary(summary_path, summary)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,149 @@
|
||||||
|
"""Shared heatmap styling and rendering utilities."""
|
||||||
|
from typing import Dict, Optional, Sequence
|
||||||
|
|
||||||
|
import matplotlib
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib import colors
|
||||||
|
|
||||||
|
HEATMAP_SPEED_RANGE_KMH = (0.0, 110.0)
|
||||||
|
HEATMAP_OCCUPANCY_RANGE = (0.0, 35.0)
|
||||||
|
HEATMAP_ACTION_LEVELS_KMH = [40.0, 60.0, 80.0, 100.0, 110.0]
|
||||||
|
HEATMAP_OCCUPANCY_TICKS = np.arange(0.0, 36.0, 5.0)
|
||||||
|
HEATMAP_BAD_COLOR = "#d9d9d9"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_masked_cmap(name: str, levels: Optional[int] = None):
|
||||||
|
cmap = plt.get_cmap(name, levels).copy() if levels is not None else plt.get_cmap(name).copy()
|
||||||
|
cmap.set_bad(color=HEATMAP_BAD_COLOR)
|
||||||
|
return cmap
|
||||||
|
|
||||||
|
|
||||||
|
def get_action_boundaries():
|
||||||
|
boundaries = [HEATMAP_ACTION_LEVELS_KMH[0] - 10.0]
|
||||||
|
boundaries.extend(
|
||||||
|
(left + right) / 2.0
|
||||||
|
for left, right in zip(HEATMAP_ACTION_LEVELS_KMH[:-1], HEATMAP_ACTION_LEVELS_KMH[1:])
|
||||||
|
)
|
||||||
|
boundaries.append(HEATMAP_ACTION_LEVELS_KMH[-1] + 10.0)
|
||||||
|
return boundaries
|
||||||
|
|
||||||
|
|
||||||
|
def get_action_norm():
|
||||||
|
action_cmap = _get_masked_cmap("viridis", len(HEATMAP_ACTION_LEVELS_KMH))
|
||||||
|
return action_cmap, colors.BoundaryNorm(get_action_boundaries(), action_cmap.N, clip=True)
|
||||||
|
|
||||||
|
|
||||||
|
def build_action_panel(grid, row_labels: Sequence[str], title: str) -> Dict:
|
||||||
|
action_cmap, action_norm = get_action_norm()
|
||||||
|
return {
|
||||||
|
"grid": grid,
|
||||||
|
"row_labels": row_labels,
|
||||||
|
"title": title,
|
||||||
|
"ylabel": "Controlled Edge",
|
||||||
|
"cmap": action_cmap,
|
||||||
|
"norm": action_norm,
|
||||||
|
"width_ratio": 0.8,
|
||||||
|
"colorbar_kwargs": {
|
||||||
|
"ticks": HEATMAP_ACTION_LEVELS_KMH,
|
||||||
|
"boundaries": get_action_boundaries(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_speed_panel(grid, row_labels: Sequence[str], title: str, ylabel: str) -> Dict:
|
||||||
|
return {
|
||||||
|
"grid": grid,
|
||||||
|
"row_labels": row_labels,
|
||||||
|
"title": title,
|
||||||
|
"ylabel": ylabel,
|
||||||
|
"cmap": _get_masked_cmap("RdYlGn"),
|
||||||
|
"vmin": HEATMAP_SPEED_RANGE_KMH[0],
|
||||||
|
"vmax": HEATMAP_SPEED_RANGE_KMH[1],
|
||||||
|
"width_ratio": 1.3,
|
||||||
|
"colorbar_kwargs": {
|
||||||
|
"ticks": HEATMAP_ACTION_LEVELS_KMH,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_occupancy_panel(grid, row_labels: Sequence[str], title: str, ylabel: str) -> Dict:
|
||||||
|
return {
|
||||||
|
"grid": grid,
|
||||||
|
"row_labels": row_labels,
|
||||||
|
"title": title,
|
||||||
|
"ylabel": ylabel,
|
||||||
|
"cmap": _get_masked_cmap("magma"),
|
||||||
|
"vmin": HEATMAP_OCCUPANCY_RANGE[0],
|
||||||
|
"vmax": HEATMAP_OCCUPANCY_RANGE[1],
|
||||||
|
"width_ratio": 1.3,
|
||||||
|
"colorbar_kwargs": {
|
||||||
|
"ticks": HEATMAP_OCCUPANCY_TICKS,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_density_panel(grid, row_labels: Sequence[str], title: str, ylabel: str) -> Dict:
|
||||||
|
return {
|
||||||
|
"grid": grid,
|
||||||
|
"row_labels": row_labels,
|
||||||
|
"title": title,
|
||||||
|
"ylabel": ylabel,
|
||||||
|
"cmap": _get_masked_cmap("YlOrRd"),
|
||||||
|
"width_ratio": 1.3,
|
||||||
|
"colorbar_kwargs": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def save_heatmap_panels(
|
||||||
|
path: str,
|
||||||
|
panels: Sequence[Optional[Dict]],
|
||||||
|
xlabel: str = "Decision Step",
|
||||||
|
dpi: int = 160,
|
||||||
|
):
|
||||||
|
valid_panels = [panel for panel in panels if panel and panel.get("grid") is not None]
|
||||||
|
if not valid_panels:
|
||||||
|
return
|
||||||
|
|
||||||
|
width_ratios = [float(panel.get("width_ratio", 1.0)) for panel in valid_panels]
|
||||||
|
fig, axes = plt.subplots(
|
||||||
|
1,
|
||||||
|
len(valid_panels),
|
||||||
|
figsize=(7 * len(valid_panels), 7),
|
||||||
|
gridspec_kw={"width_ratios": width_ratios},
|
||||||
|
)
|
||||||
|
if len(valid_panels) == 1:
|
||||||
|
axes = [axes]
|
||||||
|
|
||||||
|
for ax, panel in zip(axes, valid_panels):
|
||||||
|
grid = np.asarray(panel["grid"], dtype=np.float32)
|
||||||
|
image = ax.imshow(
|
||||||
|
np.ma.masked_invalid(grid),
|
||||||
|
aspect="auto",
|
||||||
|
origin="lower",
|
||||||
|
cmap=panel["cmap"],
|
||||||
|
interpolation="nearest",
|
||||||
|
resample=False,
|
||||||
|
vmin=panel.get("vmin"),
|
||||||
|
vmax=panel.get("vmax"),
|
||||||
|
norm=panel.get("norm"),
|
||||||
|
)
|
||||||
|
ax.set_title(str(panel["title"]))
|
||||||
|
ax.set_xlabel(xlabel)
|
||||||
|
ax.set_ylabel(str(panel["ylabel"]))
|
||||||
|
|
||||||
|
row_labels = list(panel.get("row_labels", []))
|
||||||
|
if row_labels:
|
||||||
|
tick_step = max(len(row_labels) // 12, 1)
|
||||||
|
tick_idx = np.arange(0, len(row_labels), tick_step)
|
||||||
|
ax.set_yticks(tick_idx)
|
||||||
|
ax.set_yticklabels([row_labels[idx] for idx in tick_idx], fontsize=8)
|
||||||
|
|
||||||
|
colorbar_kwargs = dict(panel.get("colorbar_kwargs", {}))
|
||||||
|
plt.colorbar(image, ax=ax, fraction=0.046, pad=0.04, **colorbar_kwargs)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(path, dpi=dpi)
|
||||||
|
plt.close(fig)
|
||||||
Loading…
Reference in New Issue