ctm-dqn/scripts/plot_speed_heatmap.py

137 lines
4.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""绘制无策略vs PPO策略的时空速度热力图"""
import yaml
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from sumo_edge_vsl_environment import SUMOEdgeVSLEnvironment
from ppo_agent import PPOAgent
def collect_speed_data(env, agent=None, seed=42):
    """Run one episode and record the mean speed of every detector cell.

    A "cell" is the set of detectors on one edge that share the same
    ``position``. Cells are listed in ``env.parser.edges`` order and, within
    an edge, in ascending position order. Edges without an entry in
    ``env.parser.edge_info`` are skipped.

    Args:
        env: Environment object. This function uses ``reset(seed=...)``,
            ``parser``, ``episode_length``, ``num_controlled_edges``,
            ``steps_per_action``, ``_decode_action``, ``_apply_vsl`` and
            ``_collect_state``.
        agent: Optional policy. When given, its
            ``select_action(state, deterministic=True)`` must return a
            3-tuple whose first element is the action. When ``None``, the
            fixed action index 4 is applied to every controlled edge.
        seed: Seed forwarded to ``env.reset``.

    Returns:
        ``np.ndarray`` with one row per control step and one column per
        cell, holding the per-cell speed multiplied by 3.6 (m/s -> km/h,
        given that TraCI reports speeds in m/s).
    """
    # Kept as a function-level import, as in the original, so the module can
    # be imported without SUMO's Python tools on the path.
    import traci

    # Speed used for a cell when none of its detectors yields a reading.
    # 33.33 m/s is ~120 km/h after the final conversion.
    fallback_speed = 33.33
    # NOTE(review): presumably index 4 decodes to the 120 km/h limit (the
    # plot titles this run "Fixed 120 km/h") - confirm against
    # env._decode_action.
    no_control_action_index = 4

    state = env.reset(seed=seed)

    # Build the cell list: [(edge_id, position, [detector ids])].
    parser = env.parser
    cells = []
    for edge_id in parser.edges:
        edge_info = parser.edge_info.get(edge_id)
        if not edge_info:
            continue
        # Group this edge's detectors by their position along the edge.
        pos_detectors = {}
        for det_id in edge_info.detectors.values():
            pos = parser.detectors[det_id].position
            pos_detectors.setdefault(pos, []).append(det_id)
        cells.extend(
            (edge_id, pos, pos_detectors[pos]) for pos in sorted(pos_detectors)
        )

    speed_data = []
    for step in tqdm(range(env.episode_length), desc="Collecting", leave=False):
        if agent is not None:
            action, _, _ = agent.select_action(state, deterministic=True)
        else:
            action = np.array(
                [no_control_action_index] * env.num_controlled_edges
            )

        # Apply the variable speed limits chosen for this control step.
        env._apply_vsl(env._decode_action(action))

        # Advance the simulation by one control interval (the original
        # comment states this is 60 s).
        for _ in range(env.steps_per_action):
            traci.simulationStep()

        # Read the last-interval mean speed of every cell.
        cell_speeds = []
        for _edge_id, _pos, det_ids in cells:
            speeds = []
            for det_id in det_ids:
                try:
                    spd = traci.inductionloop.getLastIntervalMeanSpeed(det_id)
                    if spd <= 0:
                        # No usable reading for the interval: fall back to
                        # the speed limit of the detector's lane.
                        spd = traci.lane.getMaxSpeed(
                            traci.inductionloop.getLaneID(det_id)
                        )
                except traci.TraCIException:
                    # Best effort: skip a detector TraCI cannot answer for,
                    # but let unrelated errors (and Ctrl-C) propagate.
                    continue
                speeds.append(spd)
            cell_speeds.append(np.mean(speeds) if speeds else fallback_speed)
        speed_data.append(cell_speeds)

        state = env._collect_state()
        if step == 0:
            print(f"Step 0 - Cell count: {len(cells)}")

    # Convert m/s to km/h.
    return np.array(speed_data) * 3.6
# Load the experiment configuration used to build both environments.
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Reuse previously collected data when a cache file exists; otherwise run
# both episodes (no control, then PPO) and cache the results.
cache_file = "speed_data_cache.npz"
if os.path.exists(cache_file):
    print("从缓存加载数据...")
    # np.load on an .npz archive returns a lazily-read NpzFile that keeps the
    # file open; the context manager closes it once both arrays are read.
    with np.load(cache_file) as data:
        no_control_speeds = data['no_control']
        ppo_speeds = data['ppo']
else:
    print("收集无策略数据...")
    env = SUMOEdgeVSLEnvironment(config)
    try:
        no_control_speeds = collect_speed_data(env, agent=None)
    finally:
        # Always close the environment, even if collection fails.
        env.close()

    print("收集PPO策略数据...")
    env = SUMOEdgeVSLEnvironment(config)
    try:
        # NOTE(review): device is hard-coded to "cuda"; this presumably
        # fails on machines without a GPU - confirm whether a CPU fallback
        # is wanted.
        agent = PPOAgent(
            state_dim=env.state_dim,
            action_dims=[env.action_dim] * env.num_controlled_edges,
            hidden_layers=[256, 256, 128],
            learning_rate=3e-4,
            device="cuda"
        )
        agent.load("checkpoints_sumo_vsl/20260324_100734/model_best.pt")
        print("已加载训练好的PPO模型")
        ppo_speeds = collect_speed_data(env, agent=agent)
    finally:
        # Also covers failures while building the agent or loading the
        # checkpoint, which previously left the environment open.
        env.close()

    # Cache both arrays so later runs can skip the simulation.
    np.savez(cache_file, no_control=no_control_speeds, ppo=ppo_speeds)
    print(f"数据已缓存到 {cache_file}")
# Draw the two runs side by side as heatmaps sharing one colour scale.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

print(f"无策略速度范围: {no_control_speeds.min():.1f} - {no_control_speeds.max():.1f} km/h")
print(f"PPO策略速度范围: {ppo_speeds.min():.1f} - {ppo_speeds.max():.1f} km/h")

# Upper colour limit: the largest speed observed in either run.
vmax = max(no_control_speeds.max(), ppo_speeds.max())

# (axes, data, title) for each panel, left to right.
panel_specs = [
    (axes[0], no_control_speeds, 'No Control (Fixed 120 km/h)'),
    (axes[1], ppo_speeds, 'PPO Control'),
]

heatmaps = []
for panel, speeds, heading in panel_specs:
    # Transpose so the array's first axis runs along x ("Time Step") and
    # its second along y ("Cell Position").
    image = panel.imshow(
        speeds.T,
        aspect='auto',
        cmap='RdYlGn',
        vmin=0,
        vmax=vmax,
        origin='lower',
    )
    panel.set_xlabel('Time Step')
    panel.set_ylabel('Cell Position')
    panel.set_title(heading)
    plt.colorbar(image, ax=panel, label='Speed (km/h)')
    heatmaps.append(image)

# Keep the original module-level names for the two image handles.
im1, im2 = heatmaps

plt.tight_layout()
plt.savefig('speed_heatmap_comparison.png', dpi=150)
print("图片已保存: speed_heatmap_comparison.png")