62 lines
1.6 KiB
Python
62 lines
1.6 KiB
Python
"""Plot density heatmaps from exported evaluation CSV files."""
|
|
import os
|
|
|
|
import matplotlib
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
# Select the non-interactive "Agg" backend; this script only writes the figure
# to a PNG file and never opens a window.
# NOTE(review): this call runs after `matplotlib.pyplot` has been imported.
# Current Matplotlib releases can still switch backends at this point, but
# confirm that for the version in use.
matplotlib.use("Agg")
|
|
|
|
# Directory that is searched for the exported "<stem>_density.csv" files and
# that also receives the rendered heatmap image.
OUTPUT_DIR = "evaluation_results"

# Candidate file stems per display label, tried in the listed order when
# looking for "<stem>_density.csv". Each label currently maps only to itself.
MODEL_FILE_ALIASES = {
    "PPO": ["PPO"],
    "APPO": ["APPO"],
    "DQN": ["DQN"],
    "No Control": ["No Control"],
}

# Display labels in plotting order; taken from the alias table's insertion
# order so the two can never drift apart.
MODEL_LABELS = list(MODEL_FILE_ALIASES)
|
|
|
|
|
|
def resolve_csv_path(model_label: str) -> str:
    """Return the path of the first existing density CSV for *model_label*.

    Each file stem registered for the label in ``MODEL_FILE_ALIASES`` is
    turned into ``<OUTPUT_DIR>/<stem>_density.csv`` and checked in order;
    the first path that is an existing regular file wins.

    Args:
        model_label: Key into ``MODEL_FILE_ALIASES``.

    Returns:
        The path of the first candidate file that exists.

    Raises:
        KeyError: If ``model_label`` is not a key of ``MODEL_FILE_ALIASES``.
        FileNotFoundError: If none of the candidate files exists.
    """
    # Lazy generator: later candidates are only built and probed when the
    # earlier ones are missing.
    candidate_paths = (
        os.path.join(OUTPUT_DIR, f"{stem}_density.csv")
        for stem in MODEL_FILE_ALIASES[model_label]
    )
    match = next((path for path in candidate_paths if os.path.isfile(path)), None)
    if match is None:
        raise FileNotFoundError(
            f"Missing density CSV for {model_label} under {OUTPUT_DIR}"
        )
    return match
|
|
|
|
|
|
# Load each model's exported density table as a plain 2-D array, keyed by label.
density_data = {
    model_label: pd.read_csv(resolve_csv_path(model_label)).values
    for model_label in MODEL_LABELS
}

# Fixed colour-scale limits shared by every panel so the heatmaps can be
# compared directly (the values are reported as percentages below).
vmin, vmax = 10, 50

print(f"Density range: {vmin:.2f}% - {vmax:.2f}%")

# One panel per model on a 2x2 grid, flattened for simple positional access.
fig, axes_grid = plt.subplots(nrows=2, ncols=2, figsize=(16, 12))
axes = axes_grid.flatten()

# Rendering options that are identical for all panels.
heatmap_options = {
    "aspect": "auto",
    "cmap": "YlOrRd",
    "vmin": vmin,
    "vmax": vmax,
    "origin": "lower",
}

for idx, model_label in enumerate(MODEL_LABELS):
    ax = axes[idx]
    # Transposed so CSV rows run along the x axis and CSV columns along y.
    # Presumably rows are time steps and columns are detector groups, as the
    # axis labels suggest; verify against the exporter.
    im = ax.imshow(density_data[model_label].T, **heatmap_options)
    ax.set_title(f"{model_label} Density Heatmap")
    ax.set_xlabel("Time Step")
    ax.set_ylabel("Detector Group Index")
    fig.colorbar(im, ax=ax, label="Occupancy (%)")

fig.tight_layout()
fig.savefig(os.path.join(OUTPUT_DIR, "density_heatmaps_adjusted.png"), dpi=150)
plt.close(fig)

print(f"Saved density heatmaps to {OUTPUT_DIR}/density_heatmaps_adjusted.png")
|