208 lines
6.4 KiB
Python
208 lines
6.4 KiB
Python
"""
Compare four methods: APPO, PPO, DQN, and a no-control baseline.
"""
|
||
import os
|
||
import yaml
|
||
import numpy as np
|
||
import matplotlib
|
||
matplotlib.use("Agg")
|
||
import matplotlib.pyplot as plt
|
||
from tqdm import tqdm
|
||
|
||
from sumo_vsl_environment import SUMOVSLEnvironment
|
||
from appo_v2_agent import APPOv2Agent
|
||
from ppo_agent import PPOAgent
|
||
from dqn_agent import DQNAgent
|
||
|
||
|
||
def _decode_joint_action(action_idx, num_zones=5, num_levels=5):
    """Decode a flat joint-action index into one discrete action per zone.

    The index is read as a base-``num_levels`` number, least-significant
    digit first: digit ``i`` becomes the action of zone ``i``.

    Args:
        action_idx: Flat action index returned by the DQN agent.
        num_zones: Number of digits (zones) to extract.
        num_levels: Number of discrete actions per zone (the base).

    Returns:
        ``np.ndarray`` with ``num_zones`` entries.
    """
    remainder = action_idx
    digits = []
    for _ in range(num_zones):
        digits.append(remainder % num_levels)
        remainder = remainder // num_levels
    return np.array(digits)


def test_method(env, agent, method_name, num_episodes=10, base_seed=42):
    """Evaluate a single control method and summarise its performance.

    Args:
        env: Environment providing ``reset(seed=...)``, ``step(action)``
            returning ``(next_state, reward, done, info)`` and an
            ``episode_length`` attribute (used only as the progress-bar
            total). ``info`` must contain the keys ``"throughput"`` and
            ``"mean_speed_kmh"``.
        agent: Agent to evaluate, or ``None`` for the no-control baseline.
        method_name: Display name. The exact value ``"DQN"`` selects the
            flat-index decoding path; any other name with a non-``None``
            agent expects ``select_action`` to return a 3-tuple whose first
            element is the action.
        num_episodes: Number of evaluation episodes.
        base_seed: Episode ``k`` (1-based) is reset with seed ``base_seed + k``.

    Returns:
        dict with ``reward_mean/std`` (per-episode cumulative reward),
        ``throughput_mean/std`` and ``speed_mean/std`` (per-episode step
        averages), plus the per-episode lists ``rewards``, ``throughputs``
        and ``speeds``. All numbers are plain Python floats so the result
        can be passed to ``json.dump`` directly.
    """
    print(f"\n{'='*70}")
    print(f"测试 {method_name}")
    print(f"{'='*70}")

    episode_rewards = []
    episode_throughputs = []
    episode_speeds = []

    for episode in range(1, num_episodes + 1):
        # A distinct, reproducible seed per episode.
        state = env.reset(seed=base_seed + episode)
        episode_reward = 0
        episode_throughput = 0
        episode_speed = 0
        done = False
        step = 0

        # Context manager guarantees the bar is closed even if a step raises.
        with tqdm(total=env.episode_length,
                  desc=f"Ep {episode}/{num_episodes}",
                  leave=False) as pbar:
            while not done:
                if agent is None:
                    # No-control baseline: the same fixed action (index 4)
                    # in all five zones.
                    # NOTE(review): presumably index 4 is the "no limit"
                    # setting - confirm against the environment.
                    action = np.array([4, 4, 4, 4, 4])
                elif method_name == "DQN":
                    action_idx = agent.select_action(state, deterministic=True)
                    action = _decode_joint_action(action_idx)
                else:
                    # PPO / APPO: only the action of the returned triple is used.
                    action, _, _ = agent.select_action(state, deterministic=True)

                state, reward, done, info = env.step(action)

                episode_reward += reward
                episode_throughput += info["throughput"]
                episode_speed += info["mean_speed_kmh"]
                step += 1

                pbar.set_postfix(r=f"{episode_reward:.1f}",
                                 tp=f"{info['throughput']:.0f}",
                                 v=f"{info['mean_speed_kmh']:.1f}")
                pbar.update(1)

        # The loop body runs at least once (done starts False), so step >= 1.
        avg_tp = episode_throughput / step
        avg_speed = episode_speed / step

        # Store built-in floats: NumPy scalars such as float32 are not
        # JSON-serialisable.
        episode_rewards.append(float(episode_reward))
        episode_throughputs.append(float(avg_tp))
        episode_speeds.append(float(avg_speed))

        print(f"Ep {episode}: R={episode_reward:.2f}, TP={avg_tp:.1f}, V={avg_speed:.1f}")

    return {
        "reward_mean": float(np.mean(episode_rewards)),
        "reward_std": float(np.std(episode_rewards)),
        "throughput_mean": float(np.mean(episode_throughputs)),
        "throughput_std": float(np.std(episode_throughputs)),
        "speed_mean": float(np.mean(episode_speeds)),
        "speed_std": float(np.std(episode_speeds)),
        "rewards": episode_rewards,
        "throughputs": episode_throughputs,
        "speeds": episode_speeds,
    }
|
||
|
||
|
||
def _json_default(obj):
    """Convert NumPy scalars/arrays to built-in types for ``json.dump``.

    Raises:
        TypeError: for any other unsupported object, matching the default
            behaviour of the ``json`` module.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def main():
    """Evaluate No-Control, APPO, PPO and DQN and report the comparison.

    Reads ``config_sumo_vsl.yaml``, builds the environment, evaluates each
    method with ``test_method`` (loading the agents from hard-coded
    checkpoint paths), prints a summary table, saves a bar-chart figure to
    ``method_comparison.png`` and dumps all results to
    ``comparison_results.json``.
    """
    import json

    def banner(title):
        # Section header printed before each evaluation.
        print("\n" + "="*70)
        print(title)
        print("="*70)

    with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    env = SUMOVSLEnvironment(config)
    results = {}

    # Close the environment even if building/loading an agent or an
    # evaluation run raises.
    try:
        state_dim = env.state_dim

        # 1. No-control baseline
        banner("开始测试:无控制基线")
        results["No-Control"] = test_method(env, None, "No-Control")

        # 2. APPO
        # NOTE(review): device="cuda" is hard-coded for all agents;
        # presumably a CUDA-capable GPU is required - confirm.
        banner("开始测试:APPO")
        appo_agent = APPOv2Agent(
            state_dim=state_dim,
            num_zones=5,
            num_actions_per_zone=5,
            hidden_dim=128,
            num_heads=4,
            num_layers=2,
            device="cuda",
        )
        appo_agent.load("checkpoints_sumo_appo/20260322_203511/model_best.pt")
        results["APPO"] = test_method(env, appo_agent, "APPO")

        # 3. PPO
        banner("开始测试:PPO")
        ppo_agent = PPOAgent(
            state_dim=state_dim,
            action_dims=[5] * 5,
            hidden_layers=[512, 256],
            device="cuda",
        )
        ppo_agent.load("checkpoints_sumo_vsl/20260319_174256/model_best.pt")
        results["PPO"] = test_method(env, ppo_agent, "PPO")

        # 4. DQN (one flat action per joint 5-zone x 5-level combination)
        banner("开始测试:DQN")
        dqn_agent = DQNAgent(
            state_dim=state_dim,
            num_actions=5**5,
            hidden_dim=256,
            device="cuda",
        )
        dqn_agent.load("checkpoints_sumo_dqn/20260323_062719/model_best.pt")
        results["DQN"] = test_method(env, dqn_agent, "DQN")
    finally:
        env.close()

    methods = ["No-Control", "APPO", "PPO", "DQN"]
    colors = ["gray", "blue", "green", "orange"]

    # Print the comparison table
    banner("四种方法对比结果")
    print(f"{'方法':<15} {'累积奖励':<20} {'通行量(veh/h)':<20} {'平均速度(km/h)':<20}")
    print("-"*70)

    for method in methods:
        r = results[method]
        print(f"{method:<15} {r['reward_mean']:>8.2f}±{r['reward_std']:<8.2f} "
              f"{r['throughput_mean']:>8.1f}±{r['throughput_std']:<8.1f} "
              f"{r['speed_mean']:>8.1f}±{r['speed_std']:<8.1f}")

    print("="*70)

    # Draw the comparison figure: one bar chart (mean with std error bars)
    # per metric.
    panels = [
        ("reward", "Cumulative Reward", "Cumulative Reward Comparison"),
        ("throughput", "Throughput (veh/h)", "Throughput Comparison"),
        ("speed", "Mean Speed (km/h)", "Mean Speed Comparison"),
    ]
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    for ax, (key, ylabel, title) in zip(axes, panels):
        means = [results[m][f"{key}_mean"] for m in methods]
        stds = [results[m][f"{key}_std"] for m in methods]
        ax.bar(methods, means, yerr=stds, color=colors, alpha=0.7, capsize=5)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig("method_comparison.png", dpi=150, bbox_inches="tight")
    # Release the figure once it has been written to disk.
    plt.close(fig)
    print("\n对比图已保存: method_comparison.png")

    # Save the detailed results; _json_default handles NumPy values that
    # the json module cannot serialise on its own.
    with open("comparison_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False, default=_json_default)
    print("详细结果已保存: comparison_results.json")
|
||
|
||
|
||
# Run the comparison only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|