ctm-dqn/evaluate_models.py

346 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""评估训练好的模型并绘制速度时空图"""
import os
import yaml
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use("Agg")
from sumo_edge_vsl_environment import SUMOEdgeVSLEnvironment
from ppo_agent import PPOAgent
from appo_sumo_agent import APPOSUMOAgent
from dqn_agent import DQNAgent
import traci
def _build_agent(model_type, env, checkpoint_dir, device):
    """Construct the agent(s) for ``model_type`` and load their best checkpoints.

    Returns a single agent for "PPO"/"APPO", or a list holding one DQNAgent per
    controlled edge (``env.num_edges``) for "DQN".

    Raises:
        ValueError: if ``model_type`` is not "PPO", "APPO" or "DQN".
    """
    if model_type == "PPO":
        agent = PPOAgent(
            state_dim=env.state_dim,
            action_dims=[env.action_dim] * env.num_edges,
            hidden_layers=[256, 256, 128],
            device=device,
        )
    elif model_type == "APPO":
        agent = APPOSUMOAgent(
            state_dim=env.state_dim,
            action_dims=[env.action_dim] * env.num_edges,
            hidden_dim=128,
            device=device,
            total_episodes=500,
        )
    elif model_type == "DQN":
        agents = []
        for edge_idx in range(env.num_edges):
            # NOTE(review): num_actions is fixed at 5 while PPO/APPO use
            # env.action_dim; kept so existing checkpoints load — confirm the
            # two agree.
            edge_agent = DQNAgent(
                state_dim=env.state_dim,
                num_actions=5,
                hidden_dim=256,
                device=device,
            )
            edge_agent.load(os.path.join(checkpoint_dir, f"model_best_edge{edge_idx}.pt"))
            agents.append(edge_agent)
        return agents
    else:
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected 'PPO', 'APPO' or 'DQN'"
        )
    agent.load(os.path.join(checkpoint_dir, "model_best.pt"))
    return agent


def _collect_detector_groups(env):
    """Return ``(edge_id, pos_idx)`` for every detector position on a control edge.

    Only positions where at least one traffic lane has a detector (a truthy
    ``edge_info.detectors[(lane_idx, pos_idx)]``) are kept. ``pos_idx`` comes
    from ``enumerate`` so it stays correct even if two entries of
    ``detector_positions`` compare equal (``list.index`` would return the first).
    """
    groups = []
    for edge_id in env.control_edges:
        edge_info = env.parser.edge_info[edge_id]
        for pos_idx, _pos in enumerate(edge_info.detector_positions):
            if any(
                edge_info.detectors.get((lane_idx, pos_idx))
                for lane_idx in edge_info.traffic_lane_indices
            ):
                groups.append((edge_id, pos_idx))
    return groups


def _sample_detector_groups(env, detector_groups):
    """Return per-group lists of mean speed (km/h) and mean occupancy.

    Speeds are read with ``traci.inductionloop.getLastIntervalMeanSpeed`` and
    multiplied by 3.6; negative speed/occupancy readings are skipped. If no lane
    at a position yields a speed, the edge's ``getLastStepMeanSpeed`` is used
    instead (0 when that is negative too). Occupancy falls back to 0.
    """
    speeds_kmh = []
    occupancies = []
    for edge_id, pos_idx in detector_groups:
        edge_info = env.parser.edge_info[edge_id]
        lane_speeds = []
        lane_occupancies = []
        for lane_idx in edge_info.traffic_lane_indices:
            det_id = edge_info.detectors.get((lane_idx, pos_idx))
            if not det_id:
                continue
            spd = traci.inductionloop.getLastIntervalMeanSpeed(det_id)
            occ = traci.inductionloop.getLastIntervalOccupancy(det_id)
            if spd >= 0:
                lane_speeds.append(spd * 3.6)
            if occ >= 0:
                lane_occupancies.append(occ)
        if lane_speeds:
            speeds_kmh.append(np.mean(lane_speeds))
        else:
            # Fallback when no detector reported a speed this interval.
            edge_speed = traci.edge.getLastStepMeanSpeed(edge_id)
            speeds_kmh.append(edge_speed * 3.6 if edge_speed >= 0 else 0)
        occupancies.append(np.mean(lane_occupancies) if lane_occupancies else 0)
    return speeds_kmh, occupancies


def _sample_network_metrics(env):
    """Return ``(vehicle_count, speed_variance)`` over every detector in the network.

    NOTE(review): the first value is stored as "brake_counts", but it is the sum
    of ``getLastIntervalVehicleNumber`` — a vehicle count, not brake events.
    The variance is over detector mean speeds in km/h (0 if none reported).
    """
    vehicle_count = 0
    all_speeds = []
    for edge_info in env.parser.edge_info.values():
        for det_id in edge_info.detectors.values():
            vehicle_count += traci.inductionloop.getLastIntervalVehicleNumber(det_id)
            spd = traci.inductionloop.getLastIntervalMeanSpeed(det_id)
            if spd >= 0:
                all_speeds.append(spd * 3.6)
    return vehicle_count, (np.var(all_speeds) if all_speeds else 0)


def evaluate_model(model_type, checkpoint_dir, config, device="cuda", seed=42):
    """Roll out one trained model for a single episode and record traffic data.

    Args:
        model_type: "PPO", "APPO" or "DQN".
        checkpoint_dir: directory containing ``model_best.pt`` (PPO/APPO) or
            ``model_best_edge{i}.pt`` for each controlled edge (DQN).
        config: configuration passed to ``SUMOEdgeVSLEnvironment``.
        device: device string passed to the agent constructors.
        seed: seed passed to ``env.reset``.

    Returns:
        dict with ``speed_data`` and ``density_data`` (one row per sampled step,
        one column per detector group), ``action_data`` (one row per step), and
        the per-step lists ``brake_counts``, ``speed_variances`` and
        ``throughputs``.

    Raises:
        ValueError: if ``model_type`` is not supported.
    """
    # Validate before creating the environment so nothing is left running.
    if model_type not in ("PPO", "APPO", "DQN"):
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected 'PPO', 'APPO' or 'DQN'"
        )
    env = SUMOEdgeVSLEnvironment(config)
    agent = _build_agent(model_type, env, checkpoint_dir, device)
    detector_groups = _collect_detector_groups(env)

    speed_data, density_data, action_data = [], [], []
    brake_counts, speed_variances, throughputs = [], [], []
    try:
        state = env.reset(seed=seed)
        done = False
        while not done:
            if model_type == "DQN":
                # NOTE(review): DQNAgent.select_action gets no deterministic
                # flag; confirm it acts greedily at evaluation time.
                action = np.array([ag.select_action(state) for ag in agent])
            else:
                action, _, _ = agent.select_action(state, deterministic=True)
            action_data.append(action.copy())
            state, _reward, done, info = env.step(action)
            if done:
                # TraCI is not queried after the terminal step, so action_data
                # ends up one row longer than the other series.
                break
            step_speeds, step_densities = _sample_detector_groups(env, detector_groups)
            speed_data.append(step_speeds)
            density_data.append(step_densities)
            vehicle_count, speed_variance = _sample_network_metrics(env)
            brake_counts.append(vehicle_count)
            speed_variances.append(speed_variance)
            throughputs.append(info["throughput"])
    finally:
        # Always shut the simulation down, even if the rollout raised.
        env.close()

    return {
        "speed_data": np.array(speed_data),
        "density_data": np.array(density_data),
        "action_data": np.array(action_data),
        "brake_counts": brake_counts,
        "speed_variances": speed_variances,
        "throughputs": throughputs,
    }
def _save_result_csvs(results_dict, output_dir):
    """Write each run's speed/density/action matrices and per-step metrics to CSV."""
    import pandas as pd

    for name, results in results_dict.items():
        for key, suffix in (
            ("speed_data", "speed"),
            ("density_data", "density"),
            ("action_data", "action"),
        ):
            pd.DataFrame(results[key]).to_csv(
                os.path.join(output_dir, f"{name}_{suffix}.csv"), index=False
            )
        metrics_df = pd.DataFrame({
            "brake_counts": results["brake_counts"],
            "speed_variances": results["speed_variances"],
            "throughputs": results["throughputs"],
        })
        metrics_df.to_csv(os.path.join(output_dir, f"{name}_metrics.csv"), index=False)


def _plot_heatmap_grid(results_dict, key, title, cmap, vmax, cbar_label, out_path):
    """Draw one heatmap panel per run (time step on x, detector group on y) and save it.

    The grid has two columns and as many rows as needed (2x2 for four runs), and
    unused panels are hidden. A run whose ``results[key]`` is not a non-empty 2-D
    array gets a "No data" placeholder, because ``imshow`` rejects such input.
    """
    n_runs = len(results_dict)
    ncols = 2 if n_runs > 1 else 1
    nrows = max(1, (n_runs + ncols - 1) // ncols)
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(8 * ncols, 6 * nrows), squeeze=False
    )
    axes = axes.flatten()
    for ax, (name, results) in zip(axes, results_dict.items()):
        data = np.asarray(results[key])
        ax.set_title(f"{name} {title}")
        ax.set_xlabel("Time Step")
        ax.set_ylabel("Detector Group Index")
        if data.ndim != 2 or data.size == 0:
            ax.text(0.5, 0.5, "No data", ha="center", va="center",
                    transform=ax.transAxes)
            continue
        im = ax.imshow(data.T, aspect="auto", cmap=cmap,
                       vmin=0, vmax=vmax, origin="lower")
        fig.colorbar(im, ax=ax, label=cbar_label)
    for ax in axes[n_runs:]:
        ax.axis("off")
    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)


def _plot_metric_curves(results_dict, out_path):
    """Plot per-step metric curves for all runs plus a text summary of their means."""
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    panels = (
        (axes[0, 0], "brake_counts", "Brake Counts Over Time", "Brake Count"),
        (axes[0, 1], "speed_variances", "Speed Variance Over Time", "Variance (km/h)²"),
        (axes[1, 0], "throughputs", "Throughput Over Time", "Throughput (veh/h)"),
    )
    for ax, key, title, ylabel in panels:
        for name, results in results_dict.items():
            ax.plot(results[key], label=name, alpha=0.7)
        ax.set_title(title)
        ax.set_xlabel("Time Step")
        ax.set_ylabel(ylabel)
        if results_dict:  # legend() warns when there are no labelled lines
            ax.legend()
        ax.grid(True, alpha=0.3)

    # Summary panel: mean of each metric per run.
    summary_text = "".join(
        f"{name}:\n"
        f" Avg Brake: {np.mean(results['brake_counts']):.1f}\n"
        f" Avg Variance: {np.mean(results['speed_variances']):.2f}\n"
        f" Avg Throughput: {np.mean(results['throughputs']):.1f}\n\n"
        for name, results in results_dict.items()
    )
    axes[1, 1].text(0.1, 0.5, summary_text, fontsize=10, family="monospace",
                    verticalalignment="center", transform=axes[1, 1].transAxes)
    axes[1, 1].axis("off")
    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)


def plot_results(results_dict, output_dir):
    """Save per-run CSVs and comparison figures into ``output_dir``.

    Args:
        results_dict: mapping of run name -> result dict as returned by
            ``evaluate_model`` / ``evaluate_no_control``.
        output_dir: directory to create (if needed) and write into.

    Writes ``{name}_speed.csv``, ``{name}_density.csv``, ``{name}_action.csv``,
    ``{name}_metrics.csv`` per run, plus ``speed_heatmaps.png``,
    ``density_heatmaps.png`` and ``metrics_comparison.png``.
    """
    os.makedirs(output_dir, exist_ok=True)
    _save_result_csvs(results_dict, output_dir)
    _plot_heatmap_grid(
        results_dict, "speed_data", "Speed Heatmap", "RdYlGn", 120,
        "Speed (km/h)", os.path.join(output_dir, "speed_heatmaps.png"),
    )
    _plot_heatmap_grid(
        results_dict, "density_data", "Density Heatmap", "YlOrRd", 100,
        "Occupancy (%)", os.path.join(output_dir, "density_heatmaps.png"),
    )
    _plot_metric_curves(results_dict, os.path.join(output_dir, "metrics_comparison.png"))
def evaluate_no_control(config, no_control_action=4, seed=42):
    """Run one episode of the uncontrolled baseline and record traffic data.

    Every controlled edge receives action index ``no_control_action`` at every
    step. NOTE(review): 4 is presumably the unrestricted / highest speed-limit
    action of the environment's action space — confirm against
    SUMOEdgeVSLEnvironment.

    Args:
        config: configuration passed to ``SUMOEdgeVSLEnvironment``.
        no_control_action: action index applied to every edge each step.
        seed: seed passed to ``env.reset``.

    Returns:
        dict with the same keys and layout as ``evaluate_model`` returns.
    """
    env = SUMOEdgeVSLEnvironment(config)

    # (edge_id, pos_idx) for every detector position on a control edge that has
    # at least one lane detector. enumerate yields the true index even if two
    # positions compare equal (list.index would return the first match).
    detector_groups = []
    for edge_id in env.control_edges:
        edge_info = env.parser.edge_info[edge_id]
        for pos_idx, _pos in enumerate(edge_info.detector_positions):
            if any(
                edge_info.detectors.get((lane_idx, pos_idx))
                for lane_idx in edge_info.traffic_lane_indices
            ):
                detector_groups.append((edge_id, pos_idx))

    speed_data, density_data, action_data = [], [], []
    brake_counts, speed_variances, throughputs = [], [], []
    try:
        env.reset(seed=seed)
        done = False
        while not done:
            action = np.array([no_control_action] * env.num_edges)
            action_data.append(action.copy())
            _state, _reward, done, info = env.step(action)
            if done:
                # TraCI is not queried after the terminal step.
                break

            # Per detector group: lane-averaged speed (km/h) and occupancy.
            step_speeds, step_densities = [], []
            for edge_id, pos_idx in detector_groups:
                edge_info = env.parser.edge_info[edge_id]
                lane_speeds, lane_occupancies = [], []
                for lane_idx in edge_info.traffic_lane_indices:
                    det_id = edge_info.detectors.get((lane_idx, pos_idx))
                    if not det_id:
                        continue
                    spd = traci.inductionloop.getLastIntervalMeanSpeed(det_id)
                    occ = traci.inductionloop.getLastIntervalOccupancy(det_id)
                    if spd >= 0:
                        lane_speeds.append(spd * 3.6)
                    if occ >= 0:
                        lane_occupancies.append(occ)
                if lane_speeds:
                    step_speeds.append(np.mean(lane_speeds))
                else:
                    # Fall back to the edge's last-step mean speed.
                    edge_speed = traci.edge.getLastStepMeanSpeed(edge_id)
                    step_speeds.append(edge_speed * 3.6 if edge_speed >= 0 else 0)
                step_densities.append(np.mean(lane_occupancies) if lane_occupancies else 0)
            speed_data.append(step_speeds)
            density_data.append(step_densities)

            # Network-wide metrics over every detector. NOTE(review): stored as
            # "brake_counts" but this sums getLastIntervalVehicleNumber.
            vehicle_count = 0
            all_speeds = []
            for edge_info in env.parser.edge_info.values():
                for det_id in edge_info.detectors.values():
                    vehicle_count += traci.inductionloop.getLastIntervalVehicleNumber(det_id)
                    spd = traci.inductionloop.getLastIntervalMeanSpeed(det_id)
                    if spd >= 0:
                        all_speeds.append(spd * 3.6)
            brake_counts.append(vehicle_count)
            speed_variances.append(np.var(all_speeds) if all_speeds else 0)
            throughputs.append(info["throughput"])
    finally:
        # Always shut the simulation down, even if the rollout raised.
        env.close()

    return {
        "speed_data": np.array(speed_data),
        "density_data": np.array(density_data),
        "action_data": np.array(action_data),
        "brake_counts": brake_counts,
        "speed_variances": speed_variances,
        "throughputs": throughputs,
    }
if __name__ == "__main__":
    # Shared experiment configuration for every evaluation run.
    with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as cfg_file:
        config = yaml.safe_load(cfg_file)

    # Checkpoint directory per trained model type. Timestamped alternatives:
    #   checkpoints_sumo_vsl/20260324_202539, checkpoints_sumo_appo/20260324_202514,
    #   checkpoints_sumo_dqn/20260324_202548
    models = {
        "PPO": "checkpoints_sumo_vsl/final",
        "APPO": "checkpoints_sumo_appo/final",
        "DQN": "checkpoints_sumo_dqn/final",
    }
    output_dir = "evaluation_results"

    results = {}
    for model_name, checkpoint_dir in models.items():
        print(f"Evaluating {model_name}...")
        results[model_name] = evaluate_model(model_name, checkpoint_dir, config)

    print("Evaluating No Control...")
    results["No Control"] = evaluate_no_control(config)

    plot_results(results, output_dir)
    print(f"Evaluation complete! Results saved to {output_dir}/")