添加模型评估功能
This commit is contained in:
parent
999430422d
commit
f64d74b986
|
|
@ -0,0 +1,336 @@
|
|||
"""评估训练好的模型并绘制速度时空图"""
|
||||
import os
|
||||
import yaml
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
from sumo_edge_vsl_environment import SUMOEdgeVSLEnvironment
|
||||
from ppo_agent import PPOAgent
|
||||
from appo_sumo_agent import APPOSUMOAgent
|
||||
from dqn_agent import DQNAgent
|
||||
import traci
|
||||
|
||||
def evaluate_model(model_type, checkpoint_dir, config):
|
||||
"""评估单个模型"""
|
||||
env = SUMOEdgeVSLEnvironment(config)
|
||||
|
||||
if model_type == "PPO":
|
||||
agent = PPOAgent(
|
||||
state_dim=env.state_dim,
|
||||
action_dims=[env.action_dim] * env.num_edges,
|
||||
hidden_layers=[256, 256, 128],
|
||||
device="cuda"
|
||||
)
|
||||
agent.load(os.path.join(checkpoint_dir, "model_best.pt"))
|
||||
elif model_type == "APPO":
|
||||
agent = APPOSUMOAgent(
|
||||
state_dim=env.state_dim,
|
||||
action_dims=[env.action_dim] * env.num_edges,
|
||||
hidden_dim=128,
|
||||
device="cuda",
|
||||
total_episodes=500
|
||||
)
|
||||
agent.load(os.path.join(checkpoint_dir, "model_best.pt"))
|
||||
elif model_type == "DQN":
|
||||
agents = []
|
||||
for i in range(env.num_edges):
|
||||
agent_i = DQNAgent(
|
||||
state_dim=env.state_dim,
|
||||
num_actions=5,
|
||||
hidden_dim=256,
|
||||
device="cuda"
|
||||
)
|
||||
agent_i.load(os.path.join(checkpoint_dir, f"model_best_edge{i}.pt"))
|
||||
agents.append(agent_i)
|
||||
agent = agents
|
||||
|
||||
# 收集所有车检器组位置(只包含有检测器的位置)
|
||||
detector_groups = []
|
||||
for edge_id in env.control_edges:
|
||||
edge_info = env.parser.edge_info[edge_id]
|
||||
for pos in edge_info.detector_positions:
|
||||
pos_idx = edge_info.detector_positions.index(pos)
|
||||
# 检查该位置是否有检测器
|
||||
has_detector = False
|
||||
for lane_idx in edge_info.traffic_lane_indices:
|
||||
if edge_info.detectors.get((lane_idx, pos_idx)):
|
||||
has_detector = True
|
||||
break
|
||||
if has_detector:
|
||||
detector_groups.append((edge_id, pos))
|
||||
|
||||
# 运行仿真
|
||||
state = env.reset(seed=42)
|
||||
speed_data = []
|
||||
density_data = []
|
||||
brake_counts = []
|
||||
speed_variances = []
|
||||
throughputs = []
|
||||
|
||||
done = False
|
||||
step = 0
|
||||
|
||||
while not done:
|
||||
if model_type == "DQN":
|
||||
action = np.array([ag.select_action(state) for ag in agent])
|
||||
else:
|
||||
action, _, _ = agent.select_action(state, deterministic=True)
|
||||
|
||||
state, reward, done, info = env.step(action)
|
||||
|
||||
if not done:
|
||||
step_speeds = []
|
||||
step_densities = []
|
||||
for edge_id, pos in detector_groups:
|
||||
edge_info = env.parser.edge_info[edge_id]
|
||||
pos_idx = edge_info.detector_positions.index(pos)
|
||||
|
||||
# 获取该位置所有车道的速度和占有率并取平均
|
||||
speeds = []
|
||||
occupancies = []
|
||||
for lane_idx in edge_info.traffic_lane_indices:
|
||||
det_id = edge_info.detectors.get((lane_idx, pos_idx))
|
||||
if det_id:
|
||||
spd = traci.inductionloop.getLastIntervalMeanSpeed(det_id)
|
||||
occ = traci.inductionloop.getLastIntervalOccupancy(det_id)
|
||||
if spd >= 0:
|
||||
speeds.append(spd * 3.6)
|
||||
if occ >= 0:
|
||||
occupancies.append(occ)
|
||||
|
||||
# 如果没有检测器数据,使用edge的实时速度作为备用
|
||||
if speeds:
|
||||
avg_speed = np.mean(speeds)
|
||||
else:
|
||||
edge_speed = traci.edge.getLastStepMeanSpeed(edge_id)
|
||||
avg_speed = edge_speed * 3.6 if edge_speed >= 0 else 0
|
||||
step_speeds.append(avg_speed)
|
||||
|
||||
avg_occupancy = np.mean(occupancies) if occupancies else 0
|
||||
step_densities.append(avg_occupancy)
|
||||
|
||||
speed_data.append(step_speeds)
|
||||
density_data.append(step_densities)
|
||||
|
||||
total_brakes = 0
|
||||
all_speeds = []
|
||||
for edge_id, edge_info in env.parser.edge_info.items():
|
||||
for det_id in edge_info.detectors.values():
|
||||
total_brakes += traci.inductionloop.getLastIntervalVehicleNumber(det_id)
|
||||
spd = traci.inductionloop.getLastIntervalMeanSpeed(det_id)
|
||||
if spd >= 0:
|
||||
all_speeds.append(spd * 3.6)
|
||||
|
||||
brake_counts.append(total_brakes)
|
||||
speed_variances.append(np.var(all_speeds) if all_speeds else 0)
|
||||
throughputs.append(info["throughput"])
|
||||
step += 1
|
||||
|
||||
env.close()
|
||||
|
||||
return {
|
||||
"speed_data": np.array(speed_data),
|
||||
"density_data": np.array(density_data),
|
||||
"brake_counts": brake_counts,
|
||||
"speed_variances": speed_variances,
|
||||
"throughputs": throughputs
|
||||
}
|
||||
|
||||
def plot_results(results_dict, output_dir):
|
||||
"""绘制对比图并保存CSV"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 保存CSV数据
|
||||
import pandas as pd
|
||||
for name, results in results_dict.items():
|
||||
speed_df = pd.DataFrame(results["speed_data"])
|
||||
speed_df.to_csv(os.path.join(output_dir, f"{name}_speed.csv"), index=False)
|
||||
|
||||
density_df = pd.DataFrame(results["density_data"])
|
||||
density_df.to_csv(os.path.join(output_dir, f"{name}_density.csv"), index=False)
|
||||
|
||||
metrics_df = pd.DataFrame({
|
||||
"brake_counts": results["brake_counts"],
|
||||
"speed_variances": results["speed_variances"],
|
||||
"throughputs": results["throughputs"]
|
||||
})
|
||||
metrics_df.to_csv(os.path.join(output_dir, f"{name}_metrics.csv"), index=False)
|
||||
|
||||
# 速度时空图 - 2x2布局
|
||||
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
|
||||
axes = axes.flatten()
|
||||
|
||||
for idx, (name, results) in enumerate(results_dict.items()):
|
||||
im = axes[idx].imshow(results["speed_data"].T, aspect="auto", cmap="RdYlGn",
|
||||
vmin=0, vmax=120, origin="lower")
|
||||
axes[idx].set_title(f"{name} Speed Heatmap")
|
||||
axes[idx].set_xlabel("Time Step")
|
||||
axes[idx].set_ylabel("Detector Group Index")
|
||||
plt.colorbar(im, ax=axes[idx], label="Speed (km/h)")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(output_dir, "speed_heatmaps.png"), dpi=150)
|
||||
plt.close()
|
||||
|
||||
# 密度时空图 - 2x2布局
|
||||
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
|
||||
axes = axes.flatten()
|
||||
|
||||
for idx, (name, results) in enumerate(results_dict.items()):
|
||||
im = axes[idx].imshow(results["density_data"].T, aspect="auto", cmap="YlOrRd",
|
||||
vmin=0, vmax=100, origin="lower")
|
||||
axes[idx].set_title(f"{name} Density Heatmap")
|
||||
axes[idx].set_xlabel("Time Step")
|
||||
axes[idx].set_ylabel("Detector Group Index")
|
||||
plt.colorbar(im, ax=axes[idx], label="Occupancy (%)")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(output_dir, "density_heatmaps.png"), dpi=150)
|
||||
plt.close()
|
||||
|
||||
# 关键指标曲线
|
||||
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
|
||||
|
||||
for name, results in results_dict.items():
|
||||
axes[0, 0].plot(results["brake_counts"], label=name, alpha=0.7)
|
||||
axes[0, 1].plot(results["speed_variances"], label=name, alpha=0.7)
|
||||
axes[1, 0].plot(results["throughputs"], label=name, alpha=0.7)
|
||||
|
||||
axes[0, 0].set_title("Brake Counts Over Time")
|
||||
axes[0, 0].set_xlabel("Time Step")
|
||||
axes[0, 0].set_ylabel("Brake Count")
|
||||
axes[0, 0].legend()
|
||||
axes[0, 0].grid(True, alpha=0.3)
|
||||
|
||||
axes[0, 1].set_title("Speed Variance Over Time")
|
||||
axes[0, 1].set_xlabel("Time Step")
|
||||
axes[0, 1].set_ylabel("Variance (km/h)²")
|
||||
axes[0, 1].legend()
|
||||
axes[0, 1].grid(True, alpha=0.3)
|
||||
|
||||
axes[1, 0].set_title("Throughput Over Time")
|
||||
axes[1, 0].set_xlabel("Time Step")
|
||||
axes[1, 0].set_ylabel("Throughput (veh/h)")
|
||||
axes[1, 0].legend()
|
||||
axes[1, 0].grid(True, alpha=0.3)
|
||||
|
||||
# 统计摘要
|
||||
summary_text = ""
|
||||
for name, results in results_dict.items():
|
||||
summary_text += f"{name}:\n"
|
||||
summary_text += f" Avg Brake: {np.mean(results['brake_counts']):.1f}\n"
|
||||
summary_text += f" Avg Variance: {np.mean(results['speed_variances']):.2f}\n"
|
||||
summary_text += f" Avg Throughput: {np.mean(results['throughputs']):.1f}\n\n"
|
||||
|
||||
axes[1, 1].text(0.1, 0.5, summary_text, fontsize=10, family="monospace",
|
||||
verticalalignment="center", transform=axes[1, 1].transAxes)
|
||||
axes[1, 1].axis("off")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(output_dir, "metrics_comparison.png"), dpi=150)
|
||||
plt.close()
|
||||
|
||||
def evaluate_no_control(config):
|
||||
"""评估无管控基准"""
|
||||
env = SUMOEdgeVSLEnvironment(config)
|
||||
|
||||
# 收集所有车检器组位置(只包含有检测器的位置)
|
||||
detector_groups = []
|
||||
for edge_id in env.control_edges:
|
||||
edge_info = env.parser.edge_info[edge_id]
|
||||
for pos in edge_info.detector_positions:
|
||||
pos_idx = edge_info.detector_positions.index(pos)
|
||||
has_detector = False
|
||||
for lane_idx in edge_info.traffic_lane_indices:
|
||||
if edge_info.detectors.get((lane_idx, pos_idx)):
|
||||
has_detector = True
|
||||
break
|
||||
if has_detector:
|
||||
detector_groups.append((edge_id, pos))
|
||||
|
||||
state = env.reset(seed=42)
|
||||
speed_data = []
|
||||
density_data = []
|
||||
brake_counts = []
|
||||
speed_variances = []
|
||||
throughputs = []
|
||||
done = False
|
||||
|
||||
while not done:
|
||||
action = np.array([4] * env.num_edges)
|
||||
state, reward, done, info = env.step(action)
|
||||
|
||||
if not done:
|
||||
step_speeds = []
|
||||
step_densities = []
|
||||
for edge_id, pos in detector_groups:
|
||||
edge_info = env.parser.edge_info[edge_id]
|
||||
pos_idx = edge_info.detector_positions.index(pos)
|
||||
speeds = []
|
||||
occupancies = []
|
||||
for lane_idx in edge_info.traffic_lane_indices:
|
||||
det_id = edge_info.detectors.get((lane_idx, pos_idx))
|
||||
if det_id:
|
||||
spd = traci.inductionloop.getLastIntervalMeanSpeed(det_id)
|
||||
occ = traci.inductionloop.getLastIntervalOccupancy(det_id)
|
||||
if spd >= 0:
|
||||
speeds.append(spd * 3.6)
|
||||
if occ >= 0:
|
||||
occupancies.append(occ)
|
||||
if speeds:
|
||||
avg_speed = np.mean(speeds)
|
||||
else:
|
||||
edge_speed = traci.edge.getLastStepMeanSpeed(edge_id)
|
||||
avg_speed = edge_speed * 3.6 if edge_speed >= 0 else 0
|
||||
step_speeds.append(avg_speed)
|
||||
|
||||
avg_occupancy = np.mean(occupancies) if occupancies else 0
|
||||
step_densities.append(avg_occupancy)
|
||||
|
||||
speed_data.append(step_speeds)
|
||||
density_data.append(step_densities)
|
||||
total_brakes = 0
|
||||
all_speeds = []
|
||||
for edge_id, edge_info in env.parser.edge_info.items():
|
||||
for det_id in edge_info.detectors.values():
|
||||
total_brakes += traci.inductionloop.getLastIntervalVehicleNumber(det_id)
|
||||
spd = traci.inductionloop.getLastIntervalMeanSpeed(det_id)
|
||||
if spd >= 0:
|
||||
all_speeds.append(spd * 3.6)
|
||||
brake_counts.append(total_brakes)
|
||||
speed_variances.append(np.var(all_speeds) if all_speeds else 0)
|
||||
throughputs.append(info["throughput"])
|
||||
|
||||
env.close()
|
||||
return {
|
||||
"speed_data": np.array(speed_data),
|
||||
"density_data": np.array(density_data),
|
||||
"brake_counts": brake_counts,
|
||||
"speed_variances": speed_variances,
|
||||
"throughputs": throughputs
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
models = {
|
||||
"PPO": "checkpoints_sumo_vsl/final",
|
||||
"APPO": "checkpoints_sumo_appo/final",
|
||||
"DQN": "checkpoints_sumo_dqn/final"
|
||||
# "PPO": "checkpoints_sumo_vsl/20260324_202539",
|
||||
# "APPO": "checkpoints_sumo_appo/20260324_202514",
|
||||
# "DQN": "checkpoints_sumo_dqn/20260324_202548"
|
||||
}
|
||||
|
||||
results = {}
|
||||
for name, ckpt in models.items():
|
||||
print(f"Evaluating {name}...")
|
||||
results[name] = evaluate_model(name, ckpt, config)
|
||||
|
||||
print("Evaluating No Control...")
|
||||
results["No Control"] = evaluate_no_control(config)
|
||||
|
||||
plot_results(results, "evaluation_results")
|
||||
print("Evaluation complete! Results saved to evaluation_results/")
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
"""从CSV重新绘制密度时空图"""
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
|
||||
output_dir = "evaluation_results"
|
||||
|
||||
# 读取所有模型的密度数据
|
||||
models = ["PPO", "APPO", "DQN", "No Control"]
|
||||
density_data = {}
|
||||
|
||||
for model in models:
|
||||
csv_path = os.path.join(output_dir, f"{model}_density.csv")
|
||||
df = pd.read_csv(csv_path)
|
||||
density_data[model] = df.values
|
||||
|
||||
# 计算全局密度范围以统一颜色映射
|
||||
all_densities = np.concatenate([data.flatten() for data in density_data.values()])
|
||||
vmin = 10
|
||||
vmax = 50
|
||||
|
||||
print(f"密度范围: {vmin:.2f}% - {vmax:.2f}%")
|
||||
|
||||
# 绘制2x2密度热图
|
||||
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
|
||||
axes = axes.flatten()
|
||||
|
||||
for idx, model in enumerate(models):
|
||||
im = axes[idx].imshow(density_data[model].T, aspect="auto", cmap="YlOrRd",
|
||||
vmin=vmin, vmax=vmax, origin="lower")
|
||||
axes[idx].set_title(f"{model} Density Heatmap")
|
||||
axes[idx].set_xlabel("Time Step")
|
||||
axes[idx].set_ylabel("Detector Group Index")
|
||||
plt.colorbar(im, ax=axes[idx], label="Occupancy (%)")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(output_dir, "density_heatmaps_adjusted.png"), dpi=150)
|
||||
plt.close()
|
||||
|
||||
print(f"密度热图已保存到 {output_dir}/density_heatmaps_adjusted.png")
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
"""绘制全局区段速度方差曲线"""
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
|
||||
output_dir = "evaluation_results"
|
||||
models = ["PPO", "APPO", "DQN", "No Control"]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(12, 6))
|
||||
|
||||
for model in models:
|
||||
csv_path = os.path.join(output_dir, f"{model}_speed.csv")
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
# 计算每个时刻所有区段的速度方差
|
||||
variances = df.var(axis=1)
|
||||
|
||||
ax.plot(variances, label=model, alpha=0.8, linewidth=1.5)
|
||||
|
||||
ax.set_xlabel("Time Step", fontsize=12)
|
||||
ax.set_ylabel("Speed Variance (km/h)²", fontsize=12)
|
||||
ax.set_title("Global Speed Variance Over Time", fontsize=14)
|
||||
ax.legend(fontsize=11)
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(output_dir, "global_speed_variance.png"), dpi=150)
|
||||
plt.close()
|
||||
|
||||
print(f"全局速度方差曲线已保存到 {output_dir}/global_speed_variance.png")
|
||||
Loading…
Reference in New Issue