# ctm-dqn/compare_all_methods.py
#
# Repository-viewer metadata (kept as comments so the module parses):
#   208 lines, 6.4 KiB, Python.
#
# Viewer warning: this file contains Unicode characters that might be
# confused with other characters (presumably the full-width punctuation in
# the Chinese strings below). If that is intentional, the warning can be
# safely ignored.


"""
对比四种方法APPO、PPO、DQN、无控制基线
"""
import os
import yaml
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from tqdm import tqdm
from sumo_vsl_environment import SUMOVSLEnvironment
from appo_v2_agent import APPOv2Agent
from ppo_agent import PPOAgent
from dqn_agent import DQNAgent
def _decode_joint_action(action_idx, num_zones=5, num_actions_per_zone=5):
    """Decode a flat DQN action index into one discrete action per zone.

    The index is read as a base-``num_actions_per_zone`` number whose least
    significant digit is the action for zone 0.

    Args:
        action_idx: Flat action index produced by the DQN agent.
        num_zones: Number of digits (zones) to extract.
        num_actions_per_zone: Radix, i.e. number of choices per zone.

    Returns:
        A 1-D NumPy array of length ``num_zones``.
    """
    per_zone = []
    remainder = action_idx
    for _ in range(num_zones):
        per_zone.append(remainder % num_actions_per_zone)
        remainder = remainder // num_actions_per_zone
    return np.array(per_zone)


def test_method(env, agent, method_name, num_episodes=10, base_seed=42):
    """Evaluate a single control method over several seeded episodes.

    Args:
        env: Environment exposing ``reset(seed=...)``, ``episode_length`` and
            ``step(action)`` returning ``(next_state, reward, done, info)``,
            where ``info`` has the keys ``"throughput"`` and
            ``"mean_speed_kmh"``.
        agent: Agent exposing ``select_action(state, deterministic=True)``,
            or ``None`` to run the fixed no-control baseline.
        method_name: Label used in the output. The exact value ``"DQN"``
            selects the flat-index decoding path; any other non-``None``
            agent is treated as PPO/APPO, whose ``select_action`` returns a
            3-tuple with the action first.
        num_episodes: Number of evaluation episodes.
        base_seed: Episode ``i`` (1-based) is reset with ``base_seed + i``.

    Returns:
        A dict with the mean/std of the episode reward, per-step average
        throughput and per-step average speed, plus the raw per-episode
        lists. All numbers are plain Python floats so the dict can be
        written with ``json.dump``.
    """
    print(f"\n{'='*70}")
    print(f"测试 {method_name}")
    print(f"{'='*70}")

    episode_rewards = []
    episode_throughputs = []
    episode_speeds = []

    for episode in range(1, num_episodes + 1):
        seed = base_seed + episode
        state = env.reset(seed=seed)
        episode_reward = 0
        episode_throughput = 0
        episode_speed = 0
        done = False
        step = 0

        # The context manager closes the bar even if the environment or the
        # agent raises in the middle of an episode.
        with tqdm(total=env.episode_length,
                  desc=f"Ep {episode}/{num_episodes}",
                  leave=False) as pbar:
            while not done:
                if agent is None:
                    # No-control baseline: the same fixed action (index 4)
                    # in each of the five zones.
                    action = np.array([4, 4, 4, 4, 4])
                elif method_name == "DQN":
                    action_idx = agent.select_action(state, deterministic=True)
                    action = _decode_joint_action(action_idx)
                else:
                    # PPO or APPO: only the first returned element is used.
                    action, _, _ = agent.select_action(state, deterministic=True)

                next_state, reward, done, info = env.step(action)
                episode_reward += reward
                episode_throughput += info["throughput"]
                episode_speed += info["mean_speed_kmh"]
                state = next_state
                step += 1

                pbar.set_postfix(r=f"{episode_reward:.1f}",
                                 tp=f"{info['throughput']:.0f}",
                                 v=f"{info['mean_speed_kmh']:.1f}")
                pbar.update(1)

        # The loop body always runs at least once, so ``step`` is >= 1 here.
        avg_tp = episode_throughput / step
        avg_speed = episode_speed / step

        # Store plain floats: the environment may hand back NumPy scalar
        # types, which ``json.dump`` cannot serialize.
        episode_rewards.append(float(episode_reward))
        episode_throughputs.append(float(avg_tp))
        episode_speeds.append(float(avg_speed))
        print(f"Ep {episode}: R={episode_reward:.2f}, TP={avg_tp:.1f}, V={avg_speed:.1f}")

    return {
        "reward_mean": float(np.mean(episode_rewards)),
        "reward_std": float(np.std(episode_rewards)),
        "throughput_mean": float(np.mean(episode_throughputs)),
        "throughput_std": float(np.std(episode_throughputs)),
        "speed_mean": float(np.mean(episode_speeds)),
        "speed_std": float(np.std(episode_speeds)),
        "rewards": episode_rewards,
        "throughputs": episode_throughputs,
        "speeds": episode_speeds,
    }
def _json_default(obj):
    """Convert NumPy scalars/arrays to plain Python values for ``json.dump``.

    Raises:
        TypeError: For any other unsupported type, as ``json`` expects.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _plot_comparison(results, methods, colors, output_path):
    """Draw one bar chart (mean with std error bars) per metric and save it."""
    # (result-key prefix, y-axis label, panel title)
    panels = [
        ("reward", "Cumulative Reward", "Cumulative Reward Comparison"),
        ("throughput", "Throughput (veh/h)", "Throughput Comparison"),
        ("speed", "Mean Speed (km/h)", "Mean Speed Comparison"),
    ]
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    try:
        for ax, (key, ylabel, title) in zip(axes, panels):
            means = [results[m][f"{key}_mean"] for m in methods]
            stds = [results[m][f"{key}_std"] for m in methods]
            ax.bar(methods, means, yerr=stds, color=colors, alpha=0.7, capsize=5)
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.grid(True, alpha=0.3, axis='y')
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure so it does not stay registered with pyplot.
        plt.close(fig)


def main():
    """Evaluate the baseline, APPO, PPO and DQN, then report the comparison.

    Reads ``config_sumo_vsl.yaml``, runs ``test_method`` for each method on
    one shared environment, prints a summary table, and writes
    ``method_comparison.png`` and ``comparison_results.json`` to the current
    directory.
    """
    import json

    methods = ["No-Control", "APPO", "PPO", "DQN"]
    colors = ["gray", "blue", "green", "orange"]
    figure_path = "method_comparison.png"
    results_path = "comparison_results.json"

    with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    results = {}
    env = SUMOVSLEnvironment(config)
    # Close the environment even if a checkpoint fails to load or an
    # evaluation run raises part-way through.
    try:
        state_dim = env.state_dim

        # NOTE(review): every agent below is created with device="cuda";
        # this presumably requires a GPU - confirm whether a CPU fallback
        # is needed.

        # 1. No-control baseline
        print("\n" + "="*70)
        print("开始测试：无控制基线")
        print("="*70)
        results["No-Control"] = test_method(env, None, "No-Control")

        # 2. APPO
        print("\n" + "="*70)
        print("开始测试APPO")
        print("="*70)
        appo_agent = APPOv2Agent(
            state_dim=state_dim,
            num_zones=5,
            num_actions_per_zone=5,
            hidden_dim=128,
            num_heads=4,
            num_layers=2,
            device="cuda",
        )
        appo_agent.load("checkpoints_sumo_appo/20260322_203511/model_best.pt")
        results["APPO"] = test_method(env, appo_agent, "APPO")

        # 3. PPO
        print("\n" + "="*70)
        print("开始测试PPO")
        print("="*70)
        ppo_agent = PPOAgent(
            state_dim=state_dim,
            action_dims=[5] * 5,
            hidden_layers=[512, 256],
            device="cuda",
        )
        ppo_agent.load("checkpoints_sumo_vsl/20260319_174256/model_best.pt")
        results["PPO"] = test_method(env, ppo_agent, "PPO")

        # 4. DQN (one flat action index per joint 5-zone x 5-action choice)
        print("\n" + "="*70)
        print("开始测试DQN")
        print("="*70)
        dqn_agent = DQNAgent(
            state_dim=state_dim,
            num_actions=5**5,
            hidden_dim=256,
            device="cuda",
        )
        dqn_agent.load("checkpoints_sumo_dqn/20260323_062719/model_best.pt")
        results["DQN"] = test_method(env, dqn_agent, "DQN")
    finally:
        env.close()

    # Print the comparison table
    print("\n" + "="*70)
    print("四种方法对比结果")
    print("="*70)
    print(f"{'方法':<15} {'累积奖励':<20} {'通行量(veh/h)':<20} {'平均速度(km/h)':<20}")
    print("-"*70)
    for method in methods:
        r = results[method]
        print(f"{method:<15} {r['reward_mean']:>8.2f}±{r['reward_std']:<8.2f} "
              f"{r['throughput_mean']:>8.1f}±{r['throughput_std']:<8.1f} "
              f"{r['speed_mean']:>8.1f}±{r['speed_std']:<8.1f}")
    print("="*70)

    # Draw the comparison figure
    _plot_comparison(results, methods, colors, figure_path)
    print(f"\n对比图已保存: {figure_path}")

    # Save the detailed results; ``default`` handles NumPy scalars/arrays
    # that ``json`` would otherwise reject.
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False, default=_json_default)
    print(f"详细结果已保存: {results_path}")
# Run the full comparison only when executed as a script, not on import.
if __name__ == "__main__":
    main()