ctm-dqn/test_no_control.py

94 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
无控制基线测试 - SUMO环境
固定最高限速，不进行VSL控制
"""
import yaml
import numpy as np
from tqdm import tqdm
from sumo_vsl_environment import SUMOVSLEnvironment
def test_no_control(num_episodes=10, base_seed=42, config_path="config_sumo_vsl.yaml"):
    """Run the no-control baseline: a fixed maximum speed limit, no VSL control.

    Every step sends the same action (index 4 for each of five zones, which this
    script labels as 120 km/h) to ``SUMOVSLEnvironment``. For each episode it
    prints the cumulative reward and the per-step averages of
    ``info["throughput"]`` and ``info["mean_speed_kmh"]``. It then prints and
    returns the mean/std of those per-episode values.

    Args:
        num_episodes: Number of episodes to run (must be >= 1). Defaults to 10.
        base_seed: Episode ``k`` (1-based) is reset with ``seed=base_seed + k``.
        config_path: YAML config file, loaded with ``yaml.safe_load`` and
            passed to ``SUMOVSLEnvironment``.

    Returns:
        dict with keys ``reward_mean``, ``reward_std``, ``throughput_mean``,
        ``throughput_std``, ``speed_mean`` and ``speed_std``. The standard
        deviations come from ``np.std`` (population, ddof=0) across episodes.

    Raises:
        ValueError: If ``num_episodes`` is less than 1. Without this check the
            summary statistics would be NaN over an empty list.
    """
    if num_episodes < 1:
        raise ValueError(f"num_episodes must be >= 1, got {num_episodes}")

    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # No control: every zone is held at the highest speed limit.
    # Index 4 is labelled 120 km/h, and five zones are assumed here.
    # TODO confirm against SUMOVSLEnvironment's action space.
    # The action never changes, so it is built once rather than on every step.
    action = np.array([4, 4, 4, 4, 4])

    episode_rewards = []
    episode_throughputs = []
    episode_speeds = []

    env = SUMOVSLEnvironment(config)
    # try/finally ensures env.close() runs even if an episode raises or the run
    # is interrupted (e.g. Ctrl-C).
    try:
        print("=" * 70)
        print("无控制基线测试 - SUMO VSL环境")
        print("=" * 70)
        print(f" 测试回合数: {num_episodes}")
        print(" 控制策略: 固定最高限速 (120 km/h)")
        print()

        for episode in range(1, num_episodes + 1):
            env.reset(seed=base_seed + episode)
            episode_reward = 0
            episode_throughput = 0
            episode_speed = 0
            done = False
            step = 0
            # The context manager closes the progress bar even on error.
            with tqdm(total=env.episode_length,
                      desc=f"Episode {episode}/{num_episodes}",
                      leave=False) as pbar:
                while not done:
                    _, reward, done, info = env.step(action)
                    episode_reward += reward
                    episode_throughput += info["throughput"]
                    episode_speed += info["mean_speed_kmh"]
                    step += 1
                    pbar.set_postfix(r=f"{episode_reward:.1f}",
                                     tp=f"{info['throughput']:.0f}",
                                     v=f"{info['mean_speed_kmh']:.1f}")
                    pbar.update(1)

            # done starts False, so the loop body ran at least once and step >= 1.
            avg_tp = episode_throughput / step
            avg_speed = episode_speed / step
            episode_rewards.append(episode_reward)
            episode_throughputs.append(avg_tp)
            episode_speeds.append(avg_speed)
            print(f"Episode {episode}: Reward={episode_reward:.2f}, "
                  f"Throughput={avg_tp:.1f} veh/h, Speed={avg_speed:.1f} km/h")
    finally:
        env.close()

    # Compute each statistic once and reuse it for both the report and the
    # return value.
    summary = {
        "reward_mean": np.mean(episode_rewards),
        "reward_std": np.std(episode_rewards),
        "throughput_mean": np.mean(episode_throughputs),
        "throughput_std": np.std(episode_throughputs),
        "speed_mean": np.mean(episode_speeds),
        "speed_std": np.std(episode_speeds),
    }

    print("\n" + "=" * 70)
    print("无控制基线测试结果")
    print("=" * 70)
    print(f"累积奖励: {summary['reward_mean']:.2f} ± {summary['reward_std']:.2f}")
    print(f"平均通行量: {summary['throughput_mean']:.1f} ± {summary['throughput_std']:.1f} veh/h")
    print(f"平均速度: {summary['speed_mean']:.1f} ± {summary['speed_std']:.1f} km/h")
    print("=" * 70)
    return summary
if __name__ == "__main__":
    # Script entry point: run the no-control baseline with its default settings.
    # The summary dict it returns is not used here.
    test_no_control()