"""Plot density heatmaps from exported evaluation CSV files."""

import math
import os
from typing import Dict, Iterable, Optional

import matplotlib

# Select the non-interactive backend *before* pyplot is imported: matplotlib
# documents ``use()`` as something to call before importing pyplot, and older
# releases ignore a late call with only a warning.
matplotlib.use("Agg")

import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402,F401
import pandas as pd  # noqa: E402

OUTPUT_DIR = "evaluation_results"
OUTPUT_FILENAME = "density_heatmaps_adjusted.png"
MODEL_LABELS = ["PPO", "APPO", "DQN", "No Control"]
# Candidate CSV file stems per model label, tried in order by resolve_csv_path.
MODEL_FILE_ALIASES = {
    "PPO": ["PPO"],
    "APPO": ["APPO"],
    "DQN": ["DQN"],
    "No Control": ["No Control"],
}
# Fixed colour-scale limits (occupancy %) shared by every subplot so the
# heatmaps use one common scale and can be compared across models.
VMIN = 10
VMAX = 50
# Number of subplot columns; the row count is derived from the model count.
N_COLS = 2


def resolve_csv_path(model_label: str, output_dir: Optional[str] = None) -> str:
    """Return the path of the first existing density CSV for ``model_label``.

    Each stem in ``MODEL_FILE_ALIASES[model_label]`` is tried in order as
    ``<output_dir>/<stem>_density.csv``.

    Args:
        model_label: Key into ``MODEL_FILE_ALIASES``.
        output_dir: Directory to search; defaults to ``OUTPUT_DIR``.

    Returns:
        Path of the first candidate file that exists.

    Raises:
        KeyError: If ``model_label`` has no entry in ``MODEL_FILE_ALIASES``.
        FileNotFoundError: If none of the candidate files exists.
    """
    directory = OUTPUT_DIR if output_dir is None else output_dir
    for file_stem in MODEL_FILE_ALIASES[model_label]:
        candidate = os.path.join(directory, f"{file_stem}_density.csv")
        if os.path.isfile(candidate):
            return candidate
    raise FileNotFoundError(
        f"Missing density CSV for {model_label} under {directory}"
    )


def load_density_data(
    model_labels: Optional[Iterable[str]] = None,
    output_dir: Optional[str] = None,
) -> Dict[str, np.ndarray]:
    """Load each model's density CSV into an array keyed by model label.

    The CSV header row is consumed as column names by ``pd.read_csv`` and is
    not part of the returned array. Insertion order follows ``model_labels``
    (``MODEL_LABELS`` when omitted).
    """
    labels = MODEL_LABELS if model_labels is None else model_labels
    return {
        label: pd.read_csv(resolve_csv_path(label, output_dir)).to_numpy()
        for label in labels
    }


def plot_density_heatmaps(
    density_data: Dict[str, np.ndarray],
    output_path: str,
    vmin: float = VMIN,
    vmax: float = VMAX,
    n_cols: int = N_COLS,
) -> None:
    """Draw one heatmap per model in a grid and save the figure to disk.

    Each array is transposed before plotting, so its rows run along the
    x-axis (labelled "Time Step") and its columns along the y-axis (labelled
    "Detector Group Index"). All subplots share ``vmin``/``vmax``. Grid cells
    left over when the model count does not fill the last row are removed.
    The figure is closed even if saving fails.
    """
    labels = list(density_data)
    n_rows = max(1, math.ceil(len(labels) / n_cols))
    # squeeze=False keeps ``axes`` 2-D for any grid shape, so ravel() is safe.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(8 * n_cols, 6 * n_rows), squeeze=False
    )
    try:
        flat_axes = axes.ravel()
        for ax, model_label in zip(flat_axes, labels):
            im = ax.imshow(
                density_data[model_label].T,
                aspect="auto",
                cmap="YlOrRd",
                vmin=vmin,
                vmax=vmax,
                origin="lower",
            )
            ax.set_title(f"{model_label} Density Heatmap")
            ax.set_xlabel("Time Step")
            ax.set_ylabel("Detector Group Index")
            fig.colorbar(im, ax=ax, label="Occupancy (%)")
        for ax in flat_axes[len(labels):]:
            ax.remove()
        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        plt.close(fig)


def main() -> None:
    """Load every model's density CSV and write the combined heatmap PNG."""
    density_data = load_density_data()
    print(f"Density range: {VMIN:.2f}% - {VMAX:.2f}%")
    output_path = os.path.join(OUTPUT_DIR, OUTPUT_FILENAME)
    plot_density_heatmaps(density_data, output_path)
    print(f"Saved density heatmaps to {output_path}")


if __name__ == "__main__":
    main()