56 lines
1.6 KiB
Python
56 lines
1.6 KiB
Python
"""Plot VSL action heatmaps from exported evaluation CSV files."""
|
|
import os
|
|
|
|
import matplotlib
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
matplotlib.use("Agg")
|
|
|
|
# Folder that holds the exported ``<stem>_action.csv`` files; the combined
# heatmap PNG is written here as well.
OUTPUT_DIR = "evaluation_results"

# Models to plot, one subplot each, in this order.
MODEL_LABELS = ["PPO", "APPO", "DQN", "No Control"]

# Candidate CSV file stems per model label, tried in order. Every model
# currently uses its own label as its single stem; extend an entry's list to
# accept alternative file names.
MODEL_FILE_ALIASES = {label: [label] for label in MODEL_LABELS}

# Speed limit in km/h selected by each discrete action index
# (SPEED_MAP[action_index] -> limit).
SPEED_MAP = [40, 60, 80, 100, 120]
|
|
|
|
|
|
def resolve_csv_path(model_label: str) -> str:
    """Locate the exported action CSV for ``model_label``.

    Each file stem registered for the label in ``MODEL_FILE_ALIASES`` is tried
    in order as ``OUTPUT_DIR/<stem>_action.csv``; the first one that exists as
    a regular file is returned.

    Raises:
        KeyError: ``model_label`` has no entry in ``MODEL_FILE_ALIASES``.
        FileNotFoundError: none of the candidate files exists.
    """
    # The outermost iterable of a generator expression is evaluated eagerly,
    # so an unknown label raises KeyError right here, as before.
    candidate_paths = (
        os.path.join(OUTPUT_DIR, f"{stem}_action.csv")
        for stem in MODEL_FILE_ALIASES[model_label]
    )
    match = next((path for path in candidate_paths if os.path.isfile(path)), None)
    if match is None:
        raise FileNotFoundError(f"Missing action CSV for {model_label} under {OUTPUT_DIR}")
    return match
|
|
|
|
|
|
# Render one VSL heatmap per model on a 2x2 grid and save it as a PNG in
# OUTPUT_DIR. In each panel the x-axis is the CSV row (labelled time step),
# the y-axis is the CSV column (labelled control edge) and the colour is the
# speed limit selected by the action stored in that cell.
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
axes = axes.flatten()

# Lookup table: speed_lookup[action_index] -> speed limit in km/h. Indexing a
# NumPy array maps the whole action matrix in one vectorized step
# (np.vectorize is a Python-level loop and rejects size-0 input).
speed_lookup = np.asarray(SPEED_MAP)
speed_min, speed_max = min(SPEED_MAP), max(SPEED_MAP)

for idx, model_label in enumerate(MODEL_LABELS):
    ax = axes[idx]
    csv_path = resolve_csv_path(model_label)
    df = pd.read_csv(csv_path)

    # Missing cells cannot be mapped to a speed; name the offending file
    # instead of failing inside the conversion.
    if df.isna().to_numpy().any():
        raise ValueError(f"{csv_path} contains missing action values")

    # astype(int) truncates toward zero, like the per-cell int() it replaces.
    action_indices = df.to_numpy().astype(int)

    # Reject indices outside SPEED_MAP: negative values would otherwise wrap
    # around silently (e.g. -1 -> the highest limit) and plot a wrong speed.
    out_of_range = (action_indices < 0) | (action_indices >= len(speed_lookup))
    if out_of_range.any():
        bad_values = np.unique(action_indices[out_of_range]).tolist()
        raise ValueError(
            f"{csv_path} contains action indices outside "
            f"0..{len(speed_lookup) - 1}: {bad_values}"
        )

    vsl_speeds = speed_lookup[action_indices]

    # Transpose so CSV columns run along the y-axis and rows along the x-axis;
    # the colour scale spans exactly the selectable speed limits.
    im = ax.imshow(
        vsl_speeds.T,
        aspect="auto",
        cmap="RdYlGn",
        vmin=speed_min,
        vmax=speed_max,
        origin="lower",
    )
    ax.set_title(f"{model_label} VSL Decision")
    ax.set_xlabel("Time Step")
    ax.set_ylabel("Control Edge Index")

    cbar = fig.colorbar(im, ax=ax, label="VSL (km/h)")
    # One tick per selectable speed limit, kept in sync with SPEED_MAP.
    cbar.set_ticks(SPEED_MAP)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, "action_heatmaps.png"), dpi=150)
# Release the figure once it has been written to disk.
plt.close(fig)

print("Saved action heatmaps.")
|