150 lines
4.5 KiB
Python
150 lines
4.5 KiB
Python
"""Shared heatmap styling and rendering utilities."""
|
|
from typing import Dict, Optional, Sequence
|
|
|
|
import matplotlib
|
|
import numpy as np
|
|
|
|
matplotlib.use("Agg")
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib import colors
|
|
|
|
# Fixed colour scale (lower, upper) for speed heatmaps, in km/h.
HEATMAP_SPEED_RANGE_KMH = (0.0, 110.0)
# Fixed colour scale (lower, upper) for occupancy heatmaps
# (units are not established in this module).
HEATMAP_OCCUPANCY_RANGE = (0.0, 35.0)
# Discrete action values in km/h. The action panel gives each its own colour
# bin; both the action and speed panels use them as colorbar ticks.
HEATMAP_ACTION_LEVELS_KMH = [40.0, 60.0, 80.0, 100.0, 110.0]
# Occupancy colorbar ticks: eight evenly spaced values 0, 5, ..., 35.
HEATMAP_OCCUPANCY_TICKS = np.linspace(0.0, 35.0, 8)
# Fill colour for masked (NaN/inf) heatmap cells.
HEATMAP_BAD_COLOR = "#d9d9d9"
|
|
|
|
|
|
def _get_masked_cmap(name: str, levels: Optional[int] = None):
    """Return a private copy of colormap ``name`` with its "bad" colour set.

    When ``levels`` is given the colormap is requested with that many entries
    (``plt.get_cmap(name, levels)``); otherwise the colormap is fetched as-is.
    The result is copied before ``set_bad`` so the object handed back by
    ``plt.get_cmap`` is not modified. Masked cells drawn with the returned
    colormap appear in ``HEATMAP_BAD_COLOR``.
    """
    if levels is None:
        base = plt.get_cmap(name)
    else:
        base = plt.get_cmap(name, levels)
    masked = base.copy()
    masked.set_bad(color=HEATMAP_BAD_COLOR)
    return masked
|
|
|
|
|
|
def get_action_boundaries(levels: Optional[Sequence[float]] = None, padding: float = 10.0) -> list:
    """Return bin edges that give each discrete level its own colour bin.

    Interior edges sit at the midpoint between consecutive levels. The two
    outer edges lie ``padding`` below the first level and ``padding`` above
    the last one, so ``len(levels) + 1`` edges are returned. Levels should be
    sorted ascending: ``matplotlib.colors.BoundaryNorm`` expects monotonically
    increasing boundaries.

    Args:
        levels: Discrete level values. Defaults to ``HEATMAP_ACTION_LEVELS_KMH``.
        padding: Distance from the first/last level to the outer edges
            (default 10.0).

    Returns:
        A list of floats, e.g. ``[30.0, 50.0, 70.0, 90.0, 105.0, 120.0]`` for
        the default levels.

    Raises:
        IndexError: If ``levels`` is empty.
    """
    if levels is None:
        levels = HEATMAP_ACTION_LEVELS_KMH
    # Coerce to plain floats so tuples, int lists and numpy arrays all work.
    values = [float(level) for level in levels]
    midpoints = [(lower + upper) / 2.0 for lower, upper in zip(values[:-1], values[1:])]
    return [values[0] - padding, *midpoints, values[-1] + padding]
|
|
|
|
|
|
def get_action_norm():
    """Return ``(cmap, norm)`` for the discrete action heatmap.

    The colormap is ``viridis`` requested with one entry per value in
    ``HEATMAP_ACTION_LEVELS_KMH`` (masked cells drawn in ``HEATMAP_BAD_COLOR``).
    The ``BoundaryNorm`` maps each bin from ``get_action_boundaries()`` onto
    one of those colours; ``clip=True`` sends out-of-range values to the end
    bins.
    """
    level_count = len(HEATMAP_ACTION_LEVELS_KMH)
    cmap = _get_masked_cmap("viridis", level_count)
    edges = get_action_boundaries()
    norm = colors.BoundaryNorm(edges, cmap.N, clip=True)
    return cmap, norm
|
|
|
|
|
|
def build_action_panel(grid, row_labels: Sequence[str], title: str) -> Dict:
    """Build the panel spec for a discrete action heatmap.

    The spec is consumed by ``save_heatmap_panels``. Colours come from
    ``get_action_norm()``, so each action level has its own bin. The colorbar
    is ticked at ``HEATMAP_ACTION_LEVELS_KMH`` and drawn with the same bin
    edges. The y-axis is labelled "Controlled Edge", and the panel is narrower
    (width ratio 0.8) than the continuous panels.
    """
    cmap, norm = get_action_norm()
    colorbar_kwargs = {
        "ticks": HEATMAP_ACTION_LEVELS_KMH,
        "boundaries": get_action_boundaries(),
    }
    return dict(
        grid=grid,
        row_labels=row_labels,
        title=title,
        ylabel="Controlled Edge",
        cmap=cmap,
        norm=norm,
        width_ratio=0.8,
        colorbar_kwargs=colorbar_kwargs,
    )
|
|
|
|
|
|
def build_speed_panel(grid, row_labels: Sequence[str], title: str, ylabel: str) -> Dict:
    """Build the panel spec for a continuous speed heatmap.

    The spec is consumed by ``save_heatmap_panels``. It uses the ``RdYlGn``
    colormap on the fixed ``HEATMAP_SPEED_RANGE_KMH`` scale, with colorbar
    ticks placed at the action levels (``HEATMAP_ACTION_LEVELS_KMH``).
    """
    speed_min, speed_max = HEATMAP_SPEED_RANGE_KMH
    return dict(
        grid=grid,
        row_labels=row_labels,
        title=title,
        ylabel=ylabel,
        cmap=_get_masked_cmap("RdYlGn"),
        vmin=speed_min,
        vmax=speed_max,
        width_ratio=1.3,
        colorbar_kwargs={"ticks": HEATMAP_ACTION_LEVELS_KMH},
    )
|
|
|
|
|
|
def build_occupancy_panel(grid, row_labels: Sequence[str], title: str, ylabel: str) -> Dict:
    """Build the panel spec for a continuous occupancy heatmap.

    The spec is consumed by ``save_heatmap_panels``. It uses the ``magma``
    colormap on the fixed ``HEATMAP_OCCUPANCY_RANGE`` scale, with colorbar
    ticks at ``HEATMAP_OCCUPANCY_TICKS``.
    """
    occupancy_min, occupancy_max = HEATMAP_OCCUPANCY_RANGE
    panel = dict(grid=grid, row_labels=row_labels, title=title, ylabel=ylabel)
    panel.update(
        cmap=_get_masked_cmap("magma"),
        vmin=occupancy_min,
        vmax=occupancy_max,
        width_ratio=1.3,
        colorbar_kwargs={"ticks": HEATMAP_OCCUPANCY_TICKS},
    )
    return panel
|
|
|
|
|
|
def build_density_panel(grid, row_labels: Sequence[str], title: str, ylabel: str) -> Dict:
    """Build the panel spec for a continuous density heatmap.

    The spec is consumed by ``save_heatmap_panels`` and uses the ``YlOrRd``
    colormap. No ``vmin``/``vmax``/``norm`` is set, so ``imshow`` scales the
    colours to the data. The colorbar gets no extra options.
    """
    panel = {"grid": grid, "row_labels": row_labels, "title": title, "ylabel": ylabel}
    panel["cmap"] = _get_masked_cmap("YlOrRd")
    panel["width_ratio"] = 1.3
    panel["colorbar_kwargs"] = {}
    return panel
|
|
|
|
|
|
def save_heatmap_panels(
    path: str,
    panels: Sequence[Optional[Dict]],
    xlabel: str = "Decision Step",
    dpi: int = 160,
):
    """Draw the given heatmap panels side by side in one figure and save it.

    Panels that are falsy (``None``, ``{}``) or whose ``"grid"`` is ``None``
    are skipped. If none remain, the function returns without creating a
    figure or writing ``path``.

    Each panel is a dict like those returned by the ``build_*_panel`` helpers.
    Required keys:

    - ``"grid"``: array-like passed to ``imshow``; NaN/inf cells are masked.
    - ``"title"``, ``"ylabel"``, ``"cmap"``.

    Optional keys: ``"row_labels"``, ``"vmin"``/``"vmax"``, ``"norm"``,
    ``"width_ratio"`` (default 1.0) and ``"colorbar_kwargs"``.

    Args:
        path: Output file path passed to ``Figure.savefig``.
        panels: Panel dicts, drawn left to right; ``None`` entries are allowed.
        xlabel: X-axis label applied to every panel.
        dpi: Resolution used when saving.
    """
    valid_panels = [panel for panel in panels if panel and panel.get("grid") is not None]
    if not valid_panels:
        return

    n_panels = len(valid_panels)
    width_ratios = [float(panel.get("width_ratio", 1.0)) for panel in valid_panels]
    fig, axes = plt.subplots(
        1,
        n_panels,
        figsize=(7 * n_panels, 7),
        gridspec_kw={"width_ratios": width_ratios},
        # Always return a 2-D array of Axes so one panel needs no special case.
        squeeze=False,
    )
    try:
        for ax, panel in zip(axes[0], valid_panels):
            _draw_heatmap_panel(fig, ax, panel, xlabel)
        fig.tight_layout()
        fig.savefig(path, dpi=dpi)
    finally:
        # pyplot keeps every open figure alive until it is closed. Close it even
        # if drawing or saving raises, so failed calls do not leak figures.
        plt.close(fig)


def _draw_heatmap_panel(fig, ax, panel: Dict, xlabel: str) -> None:
    """Draw one panel's heatmap, labels, row ticks and colorbar onto ``ax``."""
    grid = np.asarray(panel["grid"], dtype=np.float32)
    image = ax.imshow(
        # Masked (NaN/inf) cells are rendered in the colormap's "bad" colour.
        np.ma.masked_invalid(grid),
        aspect="auto",
        origin="lower",
        cmap=panel["cmap"],
        interpolation="nearest",
        resample=False,
        vmin=panel.get("vmin"),
        vmax=panel.get("vmax"),
        norm=panel.get("norm"),
    )
    ax.set_title(str(panel["title"]))
    ax.set_xlabel(xlabel)
    ax.set_ylabel(str(panel["ylabel"]))
    _set_row_ticks(ax, panel.get("row_labels"), grid.shape[0])

    colorbar_kwargs = dict(panel.get("colorbar_kwargs") or {})
    fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04, **colorbar_kwargs)


def _set_row_ticks(ax, row_labels, n_rows: int) -> None:
    """Label every ``max(n // 12, 1)``-th grid row of ``ax`` from ``row_labels``."""
    labels = [] if row_labels is None else list(row_labels)
    # Only tick rows that exist in the grid. Axes.set_yticks expands the view
    # limits to show every tick, so extra labels would otherwise stretch the
    # image.
    n_labeled = min(len(labels), n_rows)
    if n_labeled == 0:
        return
    tick_step = max(n_labeled // 12, 1)
    tick_idx = np.arange(0, n_labeled, tick_step)
    ax.set_yticks(tick_idx)
    ax.set_yticklabels([labels[idx] for idx in tick_idx], fontsize=8)
|