"""Plot VSL action heatmaps from exported evaluation CSV files.""" import os import matplotlib import matplotlib.pyplot as plt import numpy as np import pandas as pd matplotlib.use("Agg") OUTPUT_DIR = "evaluation_results" MODEL_LABELS = ["PPO", "APPO", "DQN", "No Control"] MODEL_FILE_ALIASES = { "PPO": ["PPO"], "APPO": ["APPO"], "DQN": ["DQN"], "No Control": ["No Control"], } SPEED_MAP = [40, 60, 80, 100, 120] def resolve_csv_path(model_label: str) -> str: for file_stem in MODEL_FILE_ALIASES[model_label]: candidate = os.path.join(OUTPUT_DIR, f"{file_stem}_action.csv") if os.path.isfile(candidate): return candidate raise FileNotFoundError(f"Missing action CSV for {model_label} under {OUTPUT_DIR}") fig, axes = plt.subplots(2, 2, figsize=(16, 12)) axes = axes.flatten() for idx, model_label in enumerate(MODEL_LABELS): csv_path = resolve_csv_path(model_label) df = pd.read_csv(csv_path) actions = df.values vsl_speeds = np.vectorize(lambda action_idx: SPEED_MAP[int(action_idx)])(actions) im = axes[idx].imshow( vsl_speeds.T, aspect="auto", cmap="RdYlGn", vmin=40, vmax=120, origin="lower", ) axes[idx].set_title(f"{model_label} VSL Decision") axes[idx].set_xlabel("Time Step") axes[idx].set_ylabel("Control Edge Index") cbar = plt.colorbar(im, ax=axes[idx], label="VSL (km/h)") cbar.set_ticks([40, 60, 80, 100, 120]) plt.tight_layout() plt.savefig(os.path.join(OUTPUT_DIR, "action_heatmaps.png"), dpi=150) print("Saved action heatmaps.")