"""Shared heatmap styling and rendering utilities.

The ``build_*_panel`` helpers return plain panel dicts describing one heatmap
(grid, labels, colormap, colour scaling, colorbar options).
``save_heatmap_panels`` renders a row of such panels side by side and writes
the figure to disk.
"""

from typing import Dict, Optional, Sequence

import matplotlib
import numpy as np

# Select the non-interactive backend before pyplot is imported so rendering
# works without a display.
matplotlib.use("Agg")

import matplotlib.pyplot as plt  # noqa: E402
from matplotlib import colors  # noqa: E402

HEATMAP_SPEED_RANGE_KMH = (0.0, 110.0)
HEATMAP_OCCUPANCY_RANGE = (0.0, 35.0)
HEATMAP_ACTION_LEVELS_KMH = [40.0, 60.0, 80.0, 100.0, 110.0]
HEATMAP_OCCUPANCY_TICKS = np.arange(0.0, 36.0, 5.0)
# Colour used for masked (NaN/inf) cells in every heatmap.
HEATMAP_BAD_COLOR = "#d9d9d9"


def _get_masked_cmap(name: str, levels: Optional[int] = None):
    """Return a private copy of colormap *name* that paints masked cells grey.

    Args:
        name: Registered matplotlib colormap name.
        levels: If given, resample the colormap to this many discrete colours.
            ``None`` keeps the colormap's default resolution.
    """
    # get_cmap(name, None) is equivalent to get_cmap(name). Copy before
    # set_bad so the globally registered colormap is never modified.
    cmap = plt.get_cmap(name, levels).copy()
    cmap.set_bad(color=HEATMAP_BAD_COLOR)
    return cmap


def get_action_boundaries():
    """Return colour-bin edges for the discrete action levels.

    Each interior edge lies midway between two consecutive entries of
    ``HEATMAP_ACTION_LEVELS_KMH``. The outer edges are padded 10.0 below the
    first level and 10.0 above the last, so every level sits inside its own
    bin. The result has ``len(HEATMAP_ACTION_LEVELS_KMH) + 1`` entries.
    """
    levels = HEATMAP_ACTION_LEVELS_KMH
    midpoints = [(low + high) / 2.0 for low, high in zip(levels[:-1], levels[1:])]
    return [levels[0] - 10.0, *midpoints, levels[-1] + 10.0]


def get_action_norm():
    """Return ``(cmap, norm)`` for discrete action heatmaps.

    The colormap has one ``viridis`` colour per action level. The
    ``BoundaryNorm`` maps each bin from ``get_action_boundaries()`` to one
    colour. ``clip=True`` sends out-of-range values to the end colours.
    """
    action_cmap = _get_masked_cmap("viridis", len(HEATMAP_ACTION_LEVELS_KMH))
    action_norm = colors.BoundaryNorm(get_action_boundaries(), action_cmap.N, clip=True)
    return action_cmap, action_norm


def build_action_panel(
    grid,
    row_labels: Sequence[str],
    title: str,
    ylabel: str = "Corridor Segment",
) -> Dict:
    """Describe a discrete action-level heatmap panel.

    Colours come from ``get_action_norm``. The colorbar shows one tick per
    action level, and the colorbar bins match the norm's boundaries.
    """
    action_cmap, action_norm = get_action_norm()
    return {
        "grid": grid,
        "row_labels": row_labels,
        "title": title,
        "ylabel": ylabel,
        "cmap": action_cmap,
        "norm": action_norm,
        "width_ratio": 0.8,
        "colorbar_kwargs": {
            # Copy so callers editing the panel cannot mutate the module constant.
            "ticks": list(HEATMAP_ACTION_LEVELS_KMH),
            "boundaries": get_action_boundaries(),
        },
    }


def build_speed_panel(grid, row_labels: Sequence[str], title: str, ylabel: str) -> Dict:
    """Describe a continuous speed heatmap panel.

    The ``RdYlGn`` colormap is scaled to ``HEATMAP_SPEED_RANGE_KMH``.
    Colorbar ticks sit at the action levels.
    """
    return {
        "grid": grid,
        "row_labels": row_labels,
        "title": title,
        "ylabel": ylabel,
        "cmap": _get_masked_cmap("RdYlGn"),
        "vmin": HEATMAP_SPEED_RANGE_KMH[0],
        "vmax": HEATMAP_SPEED_RANGE_KMH[1],
        "width_ratio": 1.3,
        "colorbar_kwargs": {
            # Copy so callers editing the panel cannot mutate the module constant.
            "ticks": list(HEATMAP_ACTION_LEVELS_KMH),
        },
    }


def build_occupancy_panel(grid, row_labels: Sequence[str], title: str, ylabel: str) -> Dict:
    """Describe an occupancy heatmap panel.

    The ``magma`` colormap is scaled to ``HEATMAP_OCCUPANCY_RANGE``, with
    colorbar ticks taken from ``HEATMAP_OCCUPANCY_TICKS``.
    """
    return {
        "grid": grid,
        "row_labels": row_labels,
        "title": title,
        "ylabel": ylabel,
        "cmap": _get_masked_cmap("magma"),
        "vmin": HEATMAP_OCCUPANCY_RANGE[0],
        "vmax": HEATMAP_OCCUPANCY_RANGE[1],
        "width_ratio": 1.3,
        "colorbar_kwargs": {
            # Copy so callers editing the panel cannot mutate the module constant.
            "ticks": HEATMAP_OCCUPANCY_TICKS.copy(),
        },
    }


def build_density_panel(grid, row_labels: Sequence[str], title: str, ylabel: str) -> Dict:
    """Describe a density heatmap panel.

    Uses the ``YlOrRd`` colormap. No vmin/vmax is set, so matplotlib
    autoscales to the data.
    """
    return {
        "grid": grid,
        "row_labels": row_labels,
        "title": title,
        "ylabel": ylabel,
        "cmap": _get_masked_cmap("YlOrRd"),
        "width_ratio": 1.3,
        "colorbar_kwargs": {},
    }


def _draw_heatmap_panel(fig, ax, panel: Dict, xlabel: str) -> None:
    """Draw one panel dict onto *ax* and attach its colorbar to *fig*."""
    grid = np.asarray(panel["grid"], dtype=np.float32)
    image = ax.imshow(
        # NaN/inf cells are masked and painted with the colormap's "bad" colour.
        np.ma.masked_invalid(grid),
        aspect="auto",
        origin="lower",
        cmap=panel["cmap"],
        interpolation="nearest",
        resample=False,
        vmin=panel.get("vmin"),
        vmax=panel.get("vmax"),
        norm=panel.get("norm"),
    )
    ax.set_title(str(panel["title"]))
    ax.set_xlabel(xlabel)
    ax.set_ylabel(str(panel["ylabel"]))

    # Treat an explicit None like a missing key. Avoid `or` here because
    # row_labels may be a numpy array, whose truth value is ambiguous.
    raw_labels = panel.get("row_labels")
    row_labels = [] if raw_labels is None else list(raw_labels)
    if row_labels:
        # Thin the labels so that roughly a dozen at most are shown.
        tick_step = max(len(row_labels) // 12, 1)
        tick_idx = np.arange(0, len(row_labels), tick_step)
        ax.set_yticks(tick_idx)
        ax.set_yticklabels([row_labels[idx] for idx in tick_idx], fontsize=8)

    raw_cbar_kwargs = panel.get("colorbar_kwargs")
    colorbar_kwargs = {} if raw_cbar_kwargs is None else dict(raw_cbar_kwargs)
    fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04, **colorbar_kwargs)


def save_heatmap_panels(
    path: str,
    panels: Sequence[Optional[Dict]],
    xlabel: str = "Decision Step",
    dpi: int = 160,
) -> None:
    """Render panels side by side in one figure and save it to *path*.

    Args:
        path: Output file path passed to ``Figure.savefig``.
        panels: Panel dicts (e.g. from the ``build_*_panel`` helpers). Entries
            that are ``None``/empty or whose ``"grid"`` is ``None`` are skipped.
        xlabel: X-axis label applied to every panel.
        dpi: Output resolution.

    Nothing is written when no panel is drawable. The figure is always closed,
    even if drawing or saving raises, so repeated calls do not accumulate
    open pyplot figures.
    """
    valid_panels = [panel for panel in panels if panel and panel.get("grid") is not None]
    if not valid_panels:
        return

    width_ratios = [float(panel.get("width_ratio", 1.0)) for panel in valid_panels]
    fig, axes = plt.subplots(
        1,
        len(valid_panels),
        figsize=(7 * len(valid_panels), 7),
        gridspec_kw={"width_ratios": width_ratios},
        # Always return a 2-D array of axes, even for a single panel.
        squeeze=False,
    )
    try:
        for ax, panel in zip(axes[0], valid_panels):
            _draw_heatmap_panel(fig, ax, panel, xlabel)
        fig.tight_layout()
        fig.savefig(path, dpi=dpi)
    finally:
        plt.close(fig)