# ctm-dqn/scripts/plot_global_variance.py
#
# Source-viewer metadata: 47 lines, 1.3 KiB, Python

"""Plot global speed variance curves from exported evaluation CSV files."""
import os

import matplotlib
import pandas as pd

# Select the non-interactive Agg backend *before* pyplot is imported so the
# script renders straight to PNG without a display. Doing this after the
# pyplot import depends on matplotlib switching backends after the fact,
# which older releases did not do.
matplotlib.use("Agg")

import matplotlib.pyplot as plt  # noqa: E402  (must follow matplotlib.use)
# Directory that holds the exported "<stem>_speed.csv" inputs and that also
# receives the rendered plot.
OUTPUT_DIR = "evaluation_results"

# Models to plot; curves are drawn (and appear in the legend) in this order.
MODEL_LABELS = ["PPO", "APPO", "DQN", "No Control"]

# Candidate CSV file stems for each model label, tried in order by
# resolve_csv_path(). Append extra stems here if an exporter wrote a file
# under a different name.
MODEL_FILE_ALIASES = {
    "PPO": ["PPO"],
    "APPO": ["APPO"],
    "DQN": ["DQN"],
    "No Control": ["No Control"],
}
def resolve_csv_path(model_label: str) -> str:
    """Return the path of the exported speed CSV for *model_label*.

    Each file stem listed for the label in ``MODEL_FILE_ALIASES`` is tried in
    order as ``<OUTPUT_DIR>/<stem>_speed.csv``; the first existing file wins.
    A label with no alias entry falls back to using the label itself as the
    file stem (which is what every current entry maps to anyway).

    Args:
        model_label: Display name of the model, e.g. one of ``MODEL_LABELS``.

    Returns:
        Path to the first candidate CSV that exists on disk.

    Raises:
        FileNotFoundError: If none of the candidate files exists.
    """
    file_stems = MODEL_FILE_ALIASES.get(model_label, [model_label])
    candidates = [
        os.path.join(OUTPUT_DIR, f"{file_stem}_speed.csv") for file_stem in file_stems
    ]
    for candidate in candidates:
        if os.path.isfile(candidate):
            return candidate
    # List every path that was checked so a naming mismatch is easy to spot.
    raise FileNotFoundError(
        f"Missing speed CSV for {model_label} under {OUTPUT_DIR} "
        f"(tried: {', '.join(candidates)})"
    )
# Draw one variance-over-time curve per model on a shared set of axes.
fig, ax = plt.subplots(figsize=(12, 6))
for model_label in MODEL_LABELS:
    csv_path = resolve_csv_path(model_label)
    df = pd.read_csv(csv_path)
    # Row-wise variance: each row is plotted as one time step (see the x-axis
    # label), and the variance is taken across that row's columns using
    # pandas' default ddof=1 (sample variance).
    # numeric_only=True skips non-numeric columns; pandas >= 2.0 would
    # otherwise raise TypeError on them. It is a no-op for all-numeric CSVs.
    # NOTE(review): a numeric time/index column in the CSV would be folded
    # into the variance -- confirm the exporter writes only speed columns.
    variances = df.var(axis=1, numeric_only=True)
    ax.plot(variances, label=model_label, alpha=0.8, linewidth=1.5)

ax.set_xlabel("Time Step", fontsize=12)
ax.set_ylabel("Speed Variance (km/h)^2", fontsize=12)
ax.set_title("Global Speed Variance Over Time", fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
fig.tight_layout()

# Build the output path once so the saved file and the reported path agree
# on every platform.
output_path = os.path.join(OUTPUT_DIR, "global_speed_variance.png")
fig.savefig(output_path, dpi=150)
plt.close(fig)
print(f"Saved global speed variance plot to {output_path}")