# Source: ctm-dqn/scripts/plot_actions.py (56 lines, 1.6 KiB, Python)

"""Plot VSL action heatmaps from exported evaluation CSV files."""
import os
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Select the non-interactive Agg backend: the visible code only renders a PNG
# via savefig and never shows a window, so no display is required.
# NOTE(review): this runs after `matplotlib.pyplot` is already imported. Recent
# Matplotlib versions switch backends at this point, but older releases
# expected use() to be called before importing pyplot. Confirm against the
# Matplotlib version this project pins.
matplotlib.use("Agg")
# Directory holding the per-model "<stem>_action.csv" exports; the combined
# heatmap PNG is written here as well.
OUTPUT_DIR = "evaluation_results"

# Figure panels, drawn in this order (row-major).
MODEL_LABELS = ["PPO", "APPO", "DQN", "No Control"]

# Candidate CSV file stems per label, tried in order by resolve_csv_path.
# Every model currently exports under its own label; replace an entry's list
# if a model's CSV uses a different stem.
MODEL_FILE_ALIASES = {label: [label] for label in MODEL_LABELS}

# Discrete action index -> variable speed limit in km/h
# (index 0 is 40 km/h, index 4 is 120 km/h).
SPEED_MAP = [40, 60, 80, 100, 120]
def resolve_csv_path(model_label: str) -> str:
    """Return the path of the exported action CSV for ``model_label``.

    Each file stem registered for the label in MODEL_FILE_ALIASES is tried in
    order as ``<OUTPUT_DIR>/<stem>_action.csv``. A label with no registered
    aliases falls back to using the label itself as the file stem.

    Args:
        model_label: Display label of the model (e.g. an entry of MODEL_LABELS).

    Returns:
        The first candidate path that exists as a regular file.

    Raises:
        FileNotFoundError: if none of the candidate files exist; the message
            lists every path that was tried.
    """
    # Fall back to the label itself instead of raising a bare KeyError when a
    # new label has not been given an alias entry.
    file_stems = MODEL_FILE_ALIASES.get(model_label, [model_label])
    candidates = [
        os.path.join(OUTPUT_DIR, f"{file_stem}_action.csv") for file_stem in file_stems
    ]
    for candidate in candidates:
        if os.path.isfile(candidate):
            return candidate
    raise FileNotFoundError(
        f"Missing action CSV for {model_label} under {OUTPUT_DIR} "
        f"(tried: {', '.join(candidates)})"
    )
def _actions_to_speeds(actions, source):
    """Map an array of discrete action indices to VSL speeds via SPEED_MAP.

    Args:
        actions: Array read from an action CSV; every cell should hold an
            integer index into SPEED_MAP.
        source: Path of the CSV, used only in error messages.

    Returns:
        Array of the same shape holding speeds in km/h.

    Raises:
        ValueError: if the array is empty, contains missing/non-finite values,
            or holds an index outside ``0..len(SPEED_MAP) - 1``.
    """
    values = np.asarray(actions, dtype=float)
    if values.size == 0:
        raise ValueError(f"{source} contains no actions")
    if not np.isfinite(values).all():
        raise ValueError(f"{source} contains missing or non-finite action values")
    # astype(int) truncates toward zero, matching the previous int() conversion.
    indices = values.astype(int)
    # Negative indices would otherwise wrap around silently (e.g. -1 -> 120 km/h).
    if indices.min() < 0 or indices.max() >= len(SPEED_MAP):
        raise ValueError(
            f"{source} has action indices outside 0..{len(SPEED_MAP) - 1}"
        )
    return np.asarray(SPEED_MAP)[indices]


def _plot_heatmap(ax, vsl_speeds, model_label):
    """Draw one model's VSL heatmap on ``ax`` with a km/h colorbar."""
    image = ax.imshow(
        vsl_speeds.T,  # transpose so CSV rows run along the x (Time Step) axis
        aspect="auto",
        cmap="RdYlGn",
        vmin=min(SPEED_MAP),
        vmax=max(SPEED_MAP),
        origin="lower",
    )
    ax.set_title(f"{model_label} VSL Decision")
    ax.set_xlabel("Time Step")
    ax.set_ylabel("Control Edge Index")
    cbar = plt.colorbar(image, ax=ax, label="VSL (km/h)")
    cbar.set_ticks(SPEED_MAP)


# Build one heatmap panel per model in a two-column grid sized to MODEL_LABELS
# (four labels give the original 2x2 layout at 16x12 inches), then save a PNG.
n_panels = len(MODEL_LABELS)
n_cols = 2
n_rows = max(1, (n_panels + n_cols - 1) // n_cols)
# squeeze=False keeps `axes` 2-D even when only one row is needed.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 6 * n_rows), squeeze=False)
axes = axes.flatten()
for ax, model_label in zip(axes, MODEL_LABELS):
    csv_path = resolve_csv_path(model_label)
    vsl_speeds = _actions_to_speeds(pd.read_csv(csv_path).to_numpy(), csv_path)
    _plot_heatmap(ax, vsl_speeds, model_label)
# Hide grid cells left over when the label count is odd.
for ax in axes[n_panels:]:
    ax.set_visible(False)
fig.tight_layout()
fig.savefig(os.path.join(OUTPUT_DIR, "action_heatmaps.png"), dpi=150)
plt.close(fig)
print("Saved action heatmaps.")