增加绘制模型控制action时空图

This commit is contained in:
Zihan Ye 2026-03-29 02:29:01 +08:00
parent d4eadccbf3
commit f653c61c03
2 changed files with 41 additions and 0 deletions

View File

@ -64,6 +64,7 @@ def evaluate_model(model_type, checkpoint_dir, config):
state = env.reset(seed=42)
speed_data = []
density_data = []
action_data = []
brake_counts = []
speed_variances = []
throughputs = []
@ -77,6 +78,7 @@ def evaluate_model(model_type, checkpoint_dir, config):
else:
action, _, _ = agent.select_action(state, deterministic=True)
action_data.append(action.copy())
state, reward, done, info = env.step(action)
if not done:
@ -132,6 +134,7 @@ def evaluate_model(model_type, checkpoint_dir, config):
return {
"speed_data": np.array(speed_data),
"density_data": np.array(density_data),
"action_data": np.array(action_data),
"brake_counts": brake_counts,
"speed_variances": speed_variances,
"throughputs": throughputs
@ -150,6 +153,9 @@ def plot_results(results_dict, output_dir):
density_df = pd.DataFrame(results["density_data"])
density_df.to_csv(os.path.join(output_dir, f"{name}_density.csv"), index=False)
action_df = pd.DataFrame(results["action_data"])
action_df.to_csv(os.path.join(output_dir, f"{name}_action.csv"), index=False)
metrics_df = pd.DataFrame({
"brake_counts": results["brake_counts"],
"speed_variances": results["speed_variances"],
@ -252,6 +258,7 @@ def evaluate_no_control(config):
state = env.reset(seed=42)
speed_data = []
density_data = []
action_data = []
brake_counts = []
speed_variances = []
throughputs = []
@ -259,6 +266,7 @@ def evaluate_no_control(config):
while not done:
action = np.array([4] * env.num_edges)
action_data.append(action.copy())
state, reward, done, info = env.step(action)
if not done:
@ -306,6 +314,7 @@ def evaluate_no_control(config):
return {
"speed_data": np.array(speed_data),
"density_data": np.array(density_data),
"action_data": np.array(action_data),
"brake_counts": brake_counts,
"speed_variances": speed_variances,
"throughputs": throughputs

32
plot_actions.py Normal file
View File

@ -0,0 +1,32 @@
"""从CSV绘制决策时空图"""
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use("Agg")
output_dir = "evaluation_results"
models = ["PPO", "APPO", "DQN", "No Control"]
speed_map = [40, 60, 80, 100, 120]
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
axes = axes.flatten()
for idx, model in enumerate(models):
csv_path = os.path.join(output_dir, f"{model}_action.csv")
df = pd.read_csv(csv_path)
actions = df.values
vsl_speeds = np.vectorize(lambda x: speed_map[int(x)])(actions)
im = axes[idx].imshow(vsl_speeds.T, aspect="auto", cmap="RdYlGn",
vmin=40, vmax=120, origin="lower")
axes[idx].set_title(f"{model} VSL Decision")
axes[idx].set_xlabel("Time Step")
axes[idx].set_ylabel("Control Edge Index")
cbar = plt.colorbar(im, ax=axes[idx], label="VSL (km/h)")
cbar.set_ticks([40, 60, 80, 100, 120])
plt.tight_layout()
plt.savefig(os.path.join(output_dir, "action_heatmaps.png"), dpi=150)
print("决策时空图已保存")