137 lines
4.5 KiB
Python
137 lines
4.5 KiB
Python
"""绘制无策略vs PPO策略的时空速度热力图"""
|
||
import yaml
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
from tqdm import tqdm
|
||
import os
|
||
from sumo_edge_vsl_environment import SUMOEdgeVSLEnvironment
|
||
from ppo_agent import PPOAgent
|
||
|
||
def collect_speed_data(env, agent=None, seed=42, no_control_action=4,
                       default_speed=33.33):
    """Run one episode and record per-cell mean speeds for a space-time heatmap.

    A "cell" is the group of induction-loop detectors sharing the same
    position on an edge. Cells are ordered by the order of ``env.parser.edges``
    and, within an edge, by ascending detector position.

    Args:
        env: an already-constructed SUMO VSL environment; this function calls
            ``env.reset(seed=seed)`` itself and drives TraCI directly.
        agent: optional policy whose ``select_action(state, deterministic=True)``
            returns ``(action, _, _)``. When ``None``, every controlled edge gets
            ``no_control_action``.
        seed: seed forwarded to ``env.reset``; use the same seed for runs that
            are compared against each other.
        no_control_action: action index applied to every controlled edge when
            ``agent`` is None. The plotting code labels this case "Fixed 120 km/h";
            NOTE(review): confirm that mapping against ``env._decode_action``.
        default_speed: speed in m/s used for a cell when none of its detectors
            could be read (33.33 m/s is 120 km/h).

    Returns:
        ``np.ndarray`` of speeds in km/h, one row per decision step and one
        column per cell.
    """
    import traci

    state = env.reset(seed=seed)
    cells = _build_cells(env.parser)
    print(f"Step 0 - Cell count: {len(cells)}")

    speed_rows = []
    for _step in tqdm(range(env.episode_length), desc="Collecting", leave=False):
        if agent is not None:
            action, _, _ = agent.select_action(state, deterministic=True)
        else:
            action = np.array([no_control_action] * env.num_controlled_edges)

        # Apply the VSL limits, then advance SUMO by one decision interval
        # (env.steps_per_action simulation steps).
        env._apply_vsl(env._decode_action(action))
        for _sim_step in range(env.steps_per_action):
            traci.simulationStep()

        # Mean speed of each cell over the interval that just finished.
        speed_rows.append([
            _read_cell_speed(traci, det_ids, default_speed)
            for _edge_id, _pos, det_ids in cells
        ])

        state = env._collect_state()

    # TraCI reports m/s; convert to km/h for plotting.
    return np.array(speed_rows) * 3.6


def _build_cells(parser):
    """Group each edge's detectors by position: return [(edge_id, pos, [det_ids])]."""
    from collections import defaultdict

    cells = []
    for edge_id in parser.edges:
        edge_info = parser.edge_info.get(edge_id)
        if not edge_info:
            continue
        # One cell per distinct detector position on this edge (all lanes together).
        by_position = defaultdict(list)
        for det_id in edge_info.detectors.values():
            by_position[parser.detectors[det_id].position].append(det_id)
        cells.extend((edge_id, pos, by_position[pos]) for pos in sorted(by_position))
    return cells


def _read_cell_speed(traci, det_ids, default_speed):
    """Return the mean last-interval speed (m/s) over det_ids, or default_speed."""
    speeds = []
    for det_id in det_ids:
        try:
            spd = traci.inductionloop.getLastIntervalMeanSpeed(det_id)
            if spd < 0:
                # SUMO reports a negative sentinel (-1) when no vehicle passed;
                # treat an empty detector as free flow at the lane's current limit.
                # A genuine 0 m/s reading is kept: it marks stopped traffic.
                spd = traci.lane.getMaxSpeed(traci.inductionloop.getLaneID(det_id))
        except traci.exceptions.TraCIException:
            # Best effort: skip detectors TraCI cannot query, but let fatal
            # connection errors and KeyboardInterrupt propagate.
            continue
        speeds.append(spd)
    return np.mean(speeds) if speeds else default_speed
|
||
|
||
# Load the shared simulation/training configuration.
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# 检查缓存文件 (check for cached results)
# NOTE: the cache is keyed only by its file name. Delete it after changing the
# checkpoint, the config, or collect_speed_data, otherwise stale data is plotted.
cache_file = "speed_data_cache.npz"
checkpoint_path = "checkpoints_sumo_vsl/20260324_100734/model_best.pt"

if os.path.exists(cache_file):
    print("从缓存加载数据...")
    # The context manager closes the NpzFile handle; indexing reads each array
    # fully into memory, so the arrays stay valid after the block.
    with np.load(cache_file) as data:
        no_control_speeds = data['no_control']
        ppo_speeds = data['ppo']
else:
    # Both runs use collect_speed_data's default seed, so they see the same
    # randomness and are directly comparable.
    print("收集无策略数据...")
    env = SUMOEdgeVSLEnvironment(config)
    try:
        no_control_speeds = collect_speed_data(env, agent=None)
    finally:
        # Release the environment even if collection fails
        # (presumably stops the SUMO process — confirm in env.close).
        env.close()

    print("收集PPO策略数据...")
    env = SUMOEdgeVSLEnvironment(config)
    try:
        # Network sizes must match the ones used to train the checkpoint.
        agent = PPOAgent(
            state_dim=env.state_dim,
            action_dims=[env.action_dim] * env.num_controlled_edges,
            hidden_layers=[256, 256, 128],
            learning_rate=3e-4,
            device="cuda",
        )
        agent.load(checkpoint_path)
        print("已加载训练好的PPO模型")
        ppo_speeds = collect_speed_data(env, agent=agent)
    finally:
        env.close()

    # 保存缓存 (save cache)
    np.savez(cache_file, no_control=no_control_speeds, ppo=ppo_speeds)
    print(f"数据已缓存到 {cache_file}")
|
||
|
||
# 绘图 (plot the two space-time speed heatmaps side by side)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

print(f"无策略速度范围: {no_control_speeds.min():.1f} - {no_control_speeds.max():.1f} km/h")
print(f"PPO策略速度范围: {ppo_speeds.min():.1f} - {ppo_speeds.max():.1f} km/h")

# Shared colour scale from 0 to the highest observed speed, so both panels
# use the same colour for the same speed.
vmax = max(no_control_speeds.max(), ppo_speeds.max())

# Build both panels from one code path so their styling cannot drift apart.
panels = (
    (axes[0], no_control_speeds, 'No Control (Fixed 120 km/h)'),
    (axes[1], ppo_speeds, 'PPO Control'),
)
for ax, speeds, title in panels:
    # Transpose so time runs along x and cell index along y (origin at bottom).
    im = ax.imshow(speeds.T, aspect='auto', cmap='RdYlGn',
                   vmin=0, vmax=vmax, origin='lower')
    ax.set_xlabel('Time Step')
    ax.set_ylabel('Cell Position')
    ax.set_title(title)
    plt.colorbar(im, ax=ax, label='Speed (km/h)')

plt.tight_layout()
plt.savefig('speed_heatmap_comparison.png', dpi=150)
print("图片已保存: speed_heatmap_comparison.png")
|