# ctm-dqn/utils/heatmap_plotting.py
#
# 155 lines
# 4.5 KiB
# Python

"""Shared heatmap styling and rendering utilities."""
from typing import Dict, Optional, Sequence
import matplotlib
import numpy as np
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib import colors
# (vmin, vmax) color-scale limits used by build_speed_panel, in km/h.
HEATMAP_SPEED_RANGE_KMH = (0.0, 110.0)
# (vmin, vmax) color-scale limits used by build_occupancy_panel.
# NOTE(review): the unit is not stated in this file (presumably percent) — confirm.
HEATMAP_OCCUPANCY_RANGE = (0.0, 35.0)
# Discrete levels in km/h. get_action_boundaries() derives the color-bin edges
# from them, and they double as colorbar ticks for the action and speed panels.
HEATMAP_ACTION_LEVELS_KMH = [40.0, 60.0, 80.0, 100.0, 110.0]
# Colorbar tick positions for the occupancy panel: 0, 5, ..., 35.
HEATMAP_OCCUPANCY_TICKS = np.arange(0.0, 36.0, 5.0)
# Light gray assigned as the colormap "bad" color, i.e. drawn for masked cells.
HEATMAP_BAD_COLOR = "#d9d9d9"
def _get_masked_cmap(name: str, levels: Optional[int] = None):
    """Return a copy of a named colormap that paints masked cells in the shared gray.

    Args:
        name: Matplotlib colormap name, forwarded to ``plt.get_cmap``.
        levels: When given, forwarded to ``plt.get_cmap`` as its second
            argument to request that many color entries; when ``None`` the
            colormap is fetched with its default size.

    Returns:
        A copied colormap whose "bad" color is ``HEATMAP_BAD_COLOR``. Working
        on a copy keeps the ``set_bad`` call from touching the object returned
        by ``plt.get_cmap``.
    """
    if levels is None:
        base = plt.get_cmap(name)
    else:
        base = plt.get_cmap(name, levels)
    masked = base.copy()
    masked.set_bad(color=HEATMAP_BAD_COLOR)
    return masked
def get_action_boundaries():
    """Return the bin edges that surround each level in ``HEATMAP_ACTION_LEVELS_KMH``.

    The outer edges sit 10.0 below the first level and 10.0 above the last;
    each interior edge is the midpoint between two consecutive levels. The
    result therefore has one more entry than there are levels.

    Returns:
        A list of float bin edges in the same order as the levels.
    """
    levels = HEATMAP_ACTION_LEVELS_KMH
    lower_edge = levels[0] - 10.0
    midpoints = [(lo + hi) / 2.0 for lo, hi in zip(levels[:-1], levels[1:])]
    upper_edge = levels[-1] + 10.0
    return [lower_edge, *midpoints, upper_edge]
def get_action_norm():
    """Build the discrete colormap and matching norm used for action heatmaps.

    Returns:
        A ``(cmap, norm)`` tuple: a masked "viridis" colormap with one color
        per entry of ``HEATMAP_ACTION_LEVELS_KMH``, and a clipping
        ``BoundaryNorm`` built from ``get_action_boundaries()``.
    """
    level_count = len(HEATMAP_ACTION_LEVELS_KMH)
    cmap = _get_masked_cmap("viridis", level_count)
    norm = colors.BoundaryNorm(get_action_boundaries(), cmap.N, clip=True)
    return cmap, norm
def build_action_panel(
    grid,
    row_labels: Sequence[str],
    title: str,
    ylabel: str = "Corridor Segment",
) -> Dict:
    """Describe an action heatmap panel for ``save_heatmap_panels``.

    Args:
        grid: Values to draw; stored as-is under ``"grid"``.
        row_labels: Labels for the rows (y axis) of the grid.
        title: Title shown above the panel.
        ylabel: Label for the y axis.

    Returns:
        A panel dict using the discrete action colormap and norm from
        ``get_action_norm()``, a relative width of 0.8, and a colorbar with
        ticks at the action levels and the action bin edges as boundaries.
    """
    cmap, norm = get_action_norm()
    colorbar_kwargs = {
        "ticks": HEATMAP_ACTION_LEVELS_KMH,
        "boundaries": get_action_boundaries(),
    }
    return dict(
        grid=grid,
        row_labels=row_labels,
        title=title,
        ylabel=ylabel,
        cmap=cmap,
        norm=norm,
        width_ratio=0.8,
        colorbar_kwargs=colorbar_kwargs,
    )
def build_speed_panel(grid, row_labels: Sequence[str], title: str, ylabel: str) -> Dict:
    """Describe a speed heatmap panel for ``save_heatmap_panels``.

    Args:
        grid: Values to draw; stored as-is under ``"grid"``.
        row_labels: Labels for the rows (y axis) of the grid.
        title: Title shown above the panel.
        ylabel: Label for the y axis.

    Returns:
        A panel dict using a masked "RdYlGn" colormap scaled to
        ``HEATMAP_SPEED_RANGE_KMH``, a relative width of 1.3, and colorbar
        ticks placed at ``HEATMAP_ACTION_LEVELS_KMH``.
    """
    speed_min = HEATMAP_SPEED_RANGE_KMH[0]
    speed_max = HEATMAP_SPEED_RANGE_KMH[1]
    return dict(
        grid=grid,
        row_labels=row_labels,
        title=title,
        ylabel=ylabel,
        cmap=_get_masked_cmap("RdYlGn"),
        vmin=speed_min,
        vmax=speed_max,
        width_ratio=1.3,
        colorbar_kwargs={"ticks": HEATMAP_ACTION_LEVELS_KMH},
    )
def build_occupancy_panel(grid, row_labels: Sequence[str], title: str, ylabel: str) -> Dict:
    """Describe an occupancy heatmap panel for ``save_heatmap_panels``.

    Args:
        grid: Values to draw; stored as-is under ``"grid"``.
        row_labels: Labels for the rows (y axis) of the grid.
        title: Title shown above the panel.
        ylabel: Label for the y axis.

    Returns:
        A panel dict using a masked "magma" colormap scaled to
        ``HEATMAP_OCCUPANCY_RANGE``, a relative width of 1.3, and colorbar
        ticks placed at ``HEATMAP_OCCUPANCY_TICKS``.
    """
    occupancy_min = HEATMAP_OCCUPANCY_RANGE[0]
    occupancy_max = HEATMAP_OCCUPANCY_RANGE[1]
    return dict(
        grid=grid,
        row_labels=row_labels,
        title=title,
        ylabel=ylabel,
        cmap=_get_masked_cmap("magma"),
        vmin=occupancy_min,
        vmax=occupancy_max,
        width_ratio=1.3,
        colorbar_kwargs={"ticks": HEATMAP_OCCUPANCY_TICKS},
    )
def build_density_panel(grid, row_labels: Sequence[str], title: str, ylabel: str) -> Dict:
    """Describe a density heatmap panel for ``save_heatmap_panels``.

    No ``vmin``/``vmax``/``norm`` is set here, so the color scale is left to
    whatever the renderer picks for the data.

    Args:
        grid: Values to draw; stored as-is under ``"grid"``.
        row_labels: Labels for the rows (y axis) of the grid.
        title: Title shown above the panel.
        ylabel: Label for the y axis.

    Returns:
        A panel dict using a masked "YlOrRd" colormap, a relative width of
        1.3, and no extra colorbar options.
    """
    return dict(
        grid=grid,
        row_labels=row_labels,
        title=title,
        ylabel=ylabel,
        cmap=_get_masked_cmap("YlOrRd"),
        width_ratio=1.3,
        colorbar_kwargs={},
    )
def save_heatmap_panels(
    path: str,
    panels: Sequence[Optional[Dict]],
    xlabel: str = "Decision Step",
    dpi: int = 160,
):
    """Render panel descriptions side by side and save them as one image.

    Each panel is a dict. Required keys: ``"grid"``, ``"cmap"``, ``"title"``
    and ``"ylabel"``. Optional keys: ``"row_labels"``, ``"vmin"``, ``"vmax"``,
    ``"norm"``, ``"width_ratio"`` (default 1.0) and ``"colorbar_kwargs"``.
    Entries that are falsy or whose ``"grid"`` is ``None`` are skipped; if
    nothing remains, the function returns without writing a file.

    Args:
        path: Destination passed to ``Figure.savefig``.
        panels: Panel dicts (or ``None`` placeholders), drawn left to right.
        xlabel: Label applied to the x axis of every panel.
        dpi: Resolution passed to ``Figure.savefig``.

    Returns:
        None.
    """
    valid_panels = [panel for panel in panels if panel and panel.get("grid") is not None]
    if not valid_panels:
        return

    width_ratios = [float(panel.get("width_ratio", 1.0)) for panel in valid_panels]
    # squeeze=False always yields a 2-D axes array, so a single panel needs no
    # special-casing.
    fig, axes = plt.subplots(
        1,
        len(valid_panels),
        figsize=(7 * len(valid_panels), 7),
        gridspec_kw={"width_ratios": width_ratios},
        squeeze=False,
    )
    try:
        for ax, panel in zip(axes[0], valid_panels):
            grid = np.asarray(panel["grid"], dtype=np.float32)
            # NaN/inf cells are masked so the colormap's "bad" color is used.
            # NOTE(review): Matplotlib rejects vmin/vmax combined with a norm
            # instance; the builders in this module set one or the other.
            image = ax.imshow(
                np.ma.masked_invalid(grid),
                aspect="auto",
                origin="lower",
                cmap=panel["cmap"],
                interpolation="nearest",
                resample=False,
                vmin=panel.get("vmin"),
                vmax=panel.get("vmax"),
                norm=panel.get("norm"),
            )
            ax.set_title(str(panel["title"]))
            ax.set_xlabel(xlabel)
            ax.set_ylabel(str(panel["ylabel"]))

            # Treat an explicit None like a missing key instead of crashing.
            row_labels = panel.get("row_labels")
            row_labels = [] if row_labels is None else list(row_labels)
            if row_labels:
                # Thin the ticks out to roughly a dozen so labels stay legible.
                tick_step = max(len(row_labels) // 12, 1)
                tick_idx = np.arange(0, len(row_labels), tick_step)
                ax.set_yticks(tick_idx)
                ax.set_yticklabels([row_labels[idx] for idx in tick_idx], fontsize=8)

            colorbar_kwargs = dict(panel.get("colorbar_kwargs") or {})
            fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04, **colorbar_kwargs)

        # Operate on this figure explicitly rather than pyplot's current one.
        fig.tight_layout()
        fig.savefig(path, dpi=dpi)
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)