# Source: ctm-dqn/scripts/evaluate_models.py (220 lines, 7.3 KiB, Python)

"""Evaluate trained models using the statistics reported in the environment's ``info`` dict."""
import os
import yaml
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use("Agg")
import pandas as pd
from envs.edge_vsl_env import SUMOEdgeVSLEnvironment
from agents.ppo_agent import PPOAgent
from agents.appo_agent import APPOAgent
from agents.dqn_agent import DQNAgent
from agents.ddpg_agent import DDPGAgent
def _build_agent(model_type, env, checkpoint_dir, device="cuda"):
    """Construct the agent for ``model_type`` and load its best checkpoint.

    Args:
        model_type: One of "PPO", "APPO", "DQN" or "DDPG".
        env: Environment exposing ``state_dim``, ``action_dim`` and ``num_edges``.
        checkpoint_dir: Directory that holds the saved checkpoint.
        device: Device string forwarded to the agent constructor.

    Returns:
        The constructed agent with its weights loaded.

    Raises:
        ValueError: If ``model_type`` is not recognized.
    """
    if model_type == "PPO":
        agent = PPOAgent(
            state_dim=env.state_dim,
            action_dims=[env.action_dim] * env.num_edges,
            hidden_layers=[256, 256, 128],
            device=device,
        )
        checkpoint_name = "model_best.pt"
    elif model_type == "APPO":
        agent = APPOAgent(
            state_dim=env.state_dim,
            action_dims=[env.action_dim] * env.num_edges,
            hidden_dim=128,
            device=device,
            total_episodes=500,
        )
        checkpoint_name = "model_best.pt"
    elif model_type == "DQN":
        # NOTE(review): 5 actions per edge is hard-coded here while the other
        # agents use env.action_dim; presumably it matches the trained DQN
        # checkpoint — confirm before changing.
        agent = DQNAgent(
            state_dim=env.state_dim,
            num_edges=env.num_edges,
            num_actions_per_edge=5,
            hidden_dim=256,
            device=device,
        )
        checkpoint_name = "model_best.pt"
    elif model_type == "DDPG":
        agent = DDPGAgent(
            state_dim=env.state_dim,
            action_dims=[env.action_dim] * env.num_edges,
            device=device,
        )
        # DDPG is loaded from a path without a ".pt" extension; presumably
        # DDPGAgent.load appends its own suffix(es) — TODO confirm.
        checkpoint_name = "model_best"
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    agent.load(os.path.join(checkpoint_dir, checkpoint_name))
    return agent


def _select_action(agent, model_type, state):
    """Return the deterministic action for ``state``.

    The agents' ``select_action`` methods return differently shaped tuples,
    so the extra outputs are discarded according to ``model_type``.
    """
    if model_type in ("PPO", "APPO"):
        action, _ = agent.select_action(state, deterministic=True)
    elif model_type in ("DDPG", "TD3"):
        action, _, _ = agent.select_action(state, deterministic=True)
    else:  # DQN
        action = agent.select_action(state, deterministic=True)
    return action


def evaluate_model(model_type, checkpoint_dir, config, device="cuda", seed=42):
    """Evaluate a single trained model for one episode.

    Builds a fresh environment from ``config``, loads the model's best
    checkpoint and runs it deterministically until the episode ends. Per-step
    statistics are recorded from the ``info`` dict returned by ``env.step``.

    Args:
        model_type: One of "PPO", "APPO", "DQN" or "DDPG".
        checkpoint_dir: Directory containing the model's checkpoint.
        config: Configuration passed to ``SUMOEdgeVSLEnvironment``.
        device: Device string for the agent (default "cuda").
        seed: Seed passed to ``env.reset`` (default 42).

    Returns:
        A dict with these entries:
        - "edge_speeds_ms", "edge_occs", "actions_kmh": numpy arrays stacked
          over time steps.
        - "throughputs", "mean_speeds", "speed_stds", "num_hard_brakes",
          "densities": per-step lists.

    Raises:
        ValueError: If ``model_type`` is not recognized.
    """
    # (info key, result key) pairs, in the order they are recorded each step.
    # "edge_speeds_kmh" is stored as the action trace; presumably it holds the
    # speed limits applied at that step — confirm against the environment.
    key_map = (
        ("edge_speeds_ms", "edge_speeds_ms"),
        ("edge_occupancies", "edge_occs"),
        ("edge_speeds_kmh", "actions_kmh"),
        ("throughput", "throughputs"),
        ("mean_speed_kmh", "mean_speeds"),
        ("speed_std", "speed_stds"),
        ("num_hard_brakes", "num_hard_brakes"),
        ("density", "densities"),
    )
    array_keys = {"edge_speeds_ms", "edge_occs", "actions_kmh"}
    collected = {result_key: [] for _, result_key in key_map}

    env = SUMOEdgeVSLEnvironment(config)
    # Close the environment even if agent loading or stepping fails, so the
    # simulator is not leaked.
    try:
        agent = _build_agent(model_type, env, checkpoint_dir, device)
        state = env.reset(seed=seed)
        done = False
        while not done:
            action = _select_action(agent, model_type, state)
            # The reward is not used during evaluation.
            next_state, _, done, info = env.step(action)
            for info_key, result_key in key_map:
                collected[result_key].append(info[info_key])
            state = next_state
    finally:
        env.close()

    return {
        key: np.array(values) if key in array_keys else values
        for key, values in collected.items()
    }
def _save_result_csvs(results_dict, output_dir):
    """Write the per-model speed (km/h), occupancy, action and metric CSVs."""
    for name, results in results_dict.items():
        tables = {
            "speed": pd.DataFrame(results["edge_speeds_ms"] * 3.6),
            "occupancy": pd.DataFrame(results["edge_occs"]),
            "action": pd.DataFrame(results["actions_kmh"]),
            "metrics": pd.DataFrame({
                "throughput": results["throughputs"],
                "mean_speed": results["mean_speeds"],
                "speed_std": results["speed_stds"],
                "num_hard_brakes": results["num_hard_brakes"],
                "density": results["densities"],
            }),
        }
        for suffix, frame in tables.items():
            frame.to_csv(os.path.join(output_dir, f"{name}_{suffix}.csv"), index=False)


def _plot_speed_heatmaps(results_dict, output_dir):
    """Save one time-by-edge speed heatmap (km/h) per model into a single figure.

    The grid has at most two columns and as many rows as needed, so any
    number of models fits. Unused trailing panels are hidden.
    """
    num_panels = max(len(results_dict), 1)
    ncols = min(num_panels, 2)
    nrows = (num_panels + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows, ncols, figsize=(8 * ncols, 6 * nrows), squeeze=False)
    axes = axes.flatten()
    for ax, (name, results) in zip(axes, results_dict.items()):
        speed_kmh = results["edge_speeds_ms"] * 3.6
        # Transpose so that time runs along x and edge index along y.
        im = ax.imshow(speed_kmh.T, aspect="auto", cmap="RdYlGn",
                       vmin=0, vmax=120, origin="lower")
        ax.set_title(f"{name} Speed Heatmap")
        ax.set_xlabel("Time Step")
        ax.set_ylabel("Edge Index")
        fig.colorbar(im, ax=ax, label="Speed (km/h)")
    for ax in axes[len(results_dict):]:
        ax.set_visible(False)
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "speed_heatmaps.png"), dpi=150)
    plt.close(fig)


def _plot_metric_comparison(results_dict, output_dir):
    """Overlay each model's per-step scalar metrics in a 2x2 comparison figure."""
    # (result key, panel title, y-axis label), laid out row-major.
    # NOTE(review): the "m/s" unit for speed_std is taken from the original
    # label; the environment's actual unit is not visible here.
    panels = (
        ("num_hard_brakes", "Hard Brakes", "Count"),
        ("speed_stds", "Speed Std Dev", "Std (m/s)"),
        ("throughputs", "Throughput", "veh/h"),
        ("mean_speeds", "Mean Speed", "km/h"),
    )
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    for ax, (key, title, ylabel) in zip(axes.flat, panels):
        for name, results in results_dict.items():
            ax.plot(results[key], label=name, alpha=0.7)
        ax.set_title(title)
        ax.set_xlabel("Time Step")
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "metrics_comparison.png"), dpi=150)
    plt.close(fig)


def _print_summary(results_dict):
    """Print each model's episode-averaged metrics to stdout."""
    print("\n" + "="*70)
    print("Evaluation Summary")
    print("="*70)
    for name, results in results_dict.items():
        print(f"\n{name}:")
        print(f" Avg Throughput: {np.mean(results['throughputs']):.1f} veh/h")
        print(f" Avg Speed: {np.mean(results['mean_speeds']):.1f} km/h")
        print(f" Avg Speed Std: {np.mean(results['speed_stds']):.2f} m/s")
        print(f" Avg Hard Brakes: {np.mean(results['num_hard_brakes']):.1f}")
        print(f" Avg Density: {np.mean(results['densities']):.2f} veh/km")
    print("="*70)


def plot_results(results_dict, output_dir):
    """Save per-model CSVs and comparison plots, then print a summary.

    Args:
        results_dict: Mapping of model name to the dict returned by
            ``evaluate_model``.
        output_dir: Directory to write into. It is created if missing.

    Outputs written to ``output_dir``:
        - ``{name}_speed.csv``, ``{name}_occupancy.csv``, ``{name}_action.csv``
          and ``{name}_metrics.csv`` for each model.
        - ``speed_heatmaps.png`` and ``metrics_comparison.png``.
    """
    os.makedirs(output_dir, exist_ok=True)
    _save_result_csvs(results_dict, output_dir)
    _plot_speed_heatmaps(results_dict, output_dir)
    _plot_metric_comparison(results_dict, output_dir)
    _print_summary(results_dict)
if __name__ == "__main__":
    # Load the shared environment/training configuration.
    with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    # Model name -> directory that holds its saved checkpoint.
    models = {
        "PPO": "checkpoints/ppo/final",
        "APPO": "checkpoints/appo/final",
        "DQN": "checkpoints/dqn/final",
        "DDPG": "checkpoints/ddpg/final",
    }
    output_dir = "results"
    results_dict = {}
    for name, checkpoint_dir in models.items():
        if not os.path.exists(checkpoint_dir):
            # Report the skip rather than silently dropping the model.
            print(f"Skipping {name}: checkpoint directory {checkpoint_dir} not found")
            continue
        print(f"Evaluating {name}...")
        results_dict[name] = evaluate_model(name, checkpoint_dir, config)
    if results_dict:
        plot_results(results_dict, output_dir)
        print(f"\nResults saved to {output_dir}/")
    else:
        print("No checkpoints found; nothing to evaluate.")