"""Plot global speed variance curves from exported evaluation CSV files.""" import os import matplotlib import matplotlib.pyplot as plt import pandas as pd matplotlib.use("Agg") OUTPUT_DIR = "evaluation_results" MODEL_LABELS = ["PPO", "APPO", "DQN", "No Control"] MODEL_FILE_ALIASES = { "PPO": ["PPO"], "APPO": ["APPO"], "DQN": ["DQN"], "No Control": ["No Control"], } def resolve_csv_path(model_label: str) -> str: for file_stem in MODEL_FILE_ALIASES[model_label]: candidate = os.path.join(OUTPUT_DIR, f"{file_stem}_speed.csv") if os.path.isfile(candidate): return candidate raise FileNotFoundError(f"Missing speed CSV for {model_label} under {OUTPUT_DIR}") fig, ax = plt.subplots(figsize=(12, 6)) for model_label in MODEL_LABELS: csv_path = resolve_csv_path(model_label) df = pd.read_csv(csv_path) variances = df.var(axis=1) ax.plot(variances, label=model_label, alpha=0.8, linewidth=1.5) ax.set_xlabel("Time Step", fontsize=12) ax.set_ylabel("Speed Variance (km/h)^2", fontsize=12) ax.set_title("Global Speed Variance Over Time", fontsize=14) ax.legend(fontsize=11) ax.grid(True, alpha=0.3) plt.tight_layout() plt.savefig(os.path.join(OUTPUT_DIR, "global_speed_variance.png"), dpi=150) plt.close() print(f"Saved global speed variance plot to {OUTPUT_DIR}/global_speed_variance.png")