# ctm-dqn/scripts/plot_density_from_csv.py
#
# 62 lines
# 1.6 KiB
# Python

"""Plot density heatmaps from exported evaluation CSV files."""
import os

import matplotlib
import numpy as np
import pandas as pd

# Select the non-interactive Agg backend *before* pyplot is imported: some
# older Matplotlib releases ignore (with a warning) a backend change made after
# pyplot has been loaded. This script only saves figures to disk.
matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402  (must follow matplotlib.use)

# Directory searched for the exported "<stem>_density.csv" files; the combined
# heatmap PNG is written to the same directory.
OUTPUT_DIR = "evaluation_results"

# Models to plot, in panel order (row-major over the flattened subplot grid).
MODEL_LABELS = ["PPO", "APPO", "DQN", "No Control"]

# Candidate CSV file stems for each model label, tried in order. Add entries
# to a list to accept alternative export names for the same model.
MODEL_FILE_ALIASES = {
    "PPO": ["PPO"],
    "APPO": ["APPO"],
    "DQN": ["DQN"],
    "No Control": ["No Control"],
}
def resolve_csv_path(
    model_label: str,
    output_dir: "str | None" = None,
    aliases: "dict[str, list[str]] | None" = None,
) -> str:
    """Return the path of the exported density CSV for ``model_label``.

    Each file stem listed for the label in ``aliases`` is tried in order, and
    the first existing ``<output_dir>/<stem>_density.csv`` is returned. A label
    with no alias entry falls back to using the label itself as the only stem,
    instead of failing with a context-free ``KeyError``.

    Args:
        model_label: Display label of the model (e.g. an entry of MODEL_LABELS).
        output_dir: Directory to search; defaults to the module-level
            ``OUTPUT_DIR``.
        aliases: Mapping of label -> candidate file stems; defaults to the
            module-level ``MODEL_FILE_ALIASES``.

    Returns:
        Path of the first candidate that exists as a regular file.

    Raises:
        FileNotFoundError: If none of the candidate files exist.
    """
    # Resolve defaults at call time so the module-level configuration is read
    # when the function runs, not frozen at definition time.
    if output_dir is None:
        output_dir = OUTPUT_DIR
    if aliases is None:
        aliases = MODEL_FILE_ALIASES

    for file_stem in aliases.get(model_label, [model_label]):
        candidate = os.path.join(output_dir, f"{file_stem}_density.csv")
        # isfile (not exists) so a directory with a matching name is skipped.
        if os.path.isfile(candidate):
            return candidate
    raise FileNotFoundError(f"Missing density CSV for {model_label} under {output_dir}")
# Load every model's density matrix up front, so a missing CSV fails before
# any plotting work starts.
# NOTE(review): the transpose below assumes CSV rows are time steps and
# columns are detector groups. If the exporter wrote the DataFrame index as a
# column, that index would also be plotted as a detector group; confirm.
density_data = {
    model_label: pd.read_csv(resolve_csv_path(model_label)).to_numpy()
    for model_label in MODEL_LABELS
}

# Fixed colour limits shared by every panel, so the heatmaps are directly
# comparable across models.
vmin = 10
vmax = 50
print(f"Density range: {vmin:.2f}% - {vmax:.2f}%")

# Two panels per row; the row count follows the number of models. For the
# default four models this is the original 2x2 grid at figsize (16, 12).
ncols = 2
nrows = (len(MODEL_LABELS) + ncols - 1) // ncols
fig, axes = plt.subplots(nrows, ncols, figsize=(8 * ncols, 6 * nrows))
axes = axes.flatten()

for ax, model_label in zip(axes, MODEL_LABELS):
    im = ax.imshow(
        density_data[model_label].T,
        aspect="auto",
        cmap="YlOrRd",
        vmin=vmin,
        vmax=vmax,
        origin="lower",
    )
    ax.set_title(f"{model_label} Density Heatmap")
    ax.set_xlabel("Time Step")
    ax.set_ylabel("Detector Group Index")
    fig.colorbar(im, ax=ax, label="Occupancy (%)")

# Hide grid cells left empty when the number of models is odd.
for ax in axes[len(MODEL_LABELS):]:
    ax.set_visible(False)

fig.tight_layout()
output_path = os.path.join(OUTPUT_DIR, "density_heatmaps_adjusted.png")
fig.savefig(output_path, dpi=150)
# Close this specific figure rather than whatever pyplot considers current.
plt.close(fig)
print(f"Saved density heatmaps to {output_path}")