ctm-dqn/plot_global_variance.py

34 lines
947 B
Python

"""绘制全局区段速度方差曲线"""
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use("Agg")
# Directory that holds the per-model speed CSVs; the figure is written here too.
output_dir = "evaluation_results"
# Labels of the runs to compare; each label maps to "<label>_speed.csv".
models = ["PPO", "APPO", "DQN", "No Control"]

fig, ax = plt.subplots(figsize=(12, 6))

# Shared appearance for every curve.
line_style = {"alpha": 0.8, "linewidth": 1.5}

# Draw one curve per model: the variance across the columns of each CSV row.
for model in models:
    speeds = pd.read_csv(os.path.join(output_dir, f"{model}_speed.csv"))
    # NOTE(review): every column takes part in the row-wise variance, so this
    # presumes the CSV holds only per-segment speeds with one row per time
    # step -- confirm there is no index/time column in the file.
    ax.plot(speeds.var(axis=1), label=model, **line_style)

# Axis decoration.
ax.set_title("Global Speed Variance Over Time", fontsize=14)
ax.set_xlabel("Time Step", fontsize=12)
ax.set_ylabel("Speed Variance (km/h)²", fontsize=12)
ax.grid(True, alpha=0.3)
ax.legend(fontsize=11)

# Write the figure to disk, then release it.
figure_path = os.path.join(output_dir, "global_speed_variance.png")
fig.tight_layout()
fig.savefig(figure_path, dpi=150)
plt.close(fig)

print(f"全局速度方差曲线已保存到 {output_dir}/global_speed_variance.png")