增加绘制模型控制action时空图
This commit is contained in:
parent
d4eadccbf3
commit
f653c61c03
|
|
@ -64,6 +64,7 @@ def evaluate_model(model_type, checkpoint_dir, config):
|
|||
state = env.reset(seed=42)
|
||||
speed_data = []
|
||||
density_data = []
|
||||
action_data = []
|
||||
brake_counts = []
|
||||
speed_variances = []
|
||||
throughputs = []
|
||||
|
|
@ -77,6 +78,7 @@ def evaluate_model(model_type, checkpoint_dir, config):
|
|||
else:
|
||||
action, _, _ = agent.select_action(state, deterministic=True)
|
||||
|
||||
action_data.append(action.copy())
|
||||
state, reward, done, info = env.step(action)
|
||||
|
||||
if not done:
|
||||
|
|
@ -132,6 +134,7 @@ def evaluate_model(model_type, checkpoint_dir, config):
|
|||
return {
|
||||
"speed_data": np.array(speed_data),
|
||||
"density_data": np.array(density_data),
|
||||
"action_data": np.array(action_data),
|
||||
"brake_counts": brake_counts,
|
||||
"speed_variances": speed_variances,
|
||||
"throughputs": throughputs
|
||||
|
|
@ -150,6 +153,9 @@ def plot_results(results_dict, output_dir):
|
|||
density_df = pd.DataFrame(results["density_data"])
|
||||
density_df.to_csv(os.path.join(output_dir, f"{name}_density.csv"), index=False)
|
||||
|
||||
action_df = pd.DataFrame(results["action_data"])
|
||||
action_df.to_csv(os.path.join(output_dir, f"{name}_action.csv"), index=False)
|
||||
|
||||
metrics_df = pd.DataFrame({
|
||||
"brake_counts": results["brake_counts"],
|
||||
"speed_variances": results["speed_variances"],
|
||||
|
|
@ -252,6 +258,7 @@ def evaluate_no_control(config):
|
|||
state = env.reset(seed=42)
|
||||
speed_data = []
|
||||
density_data = []
|
||||
action_data = []
|
||||
brake_counts = []
|
||||
speed_variances = []
|
||||
throughputs = []
|
||||
|
|
@ -259,6 +266,7 @@ def evaluate_no_control(config):
|
|||
|
||||
while not done:
|
||||
action = np.array([4] * env.num_edges)
|
||||
action_data.append(action.copy())
|
||||
state, reward, done, info = env.step(action)
|
||||
|
||||
if not done:
|
||||
|
|
@ -306,6 +314,7 @@ def evaluate_no_control(config):
|
|||
return {
|
||||
"speed_data": np.array(speed_data),
|
||||
"density_data": np.array(density_data),
|
||||
"action_data": np.array(action_data),
|
||||
"brake_counts": brake_counts,
|
||||
"speed_variances": speed_variances,
|
||||
"throughputs": throughputs
|
||||
|
|
|
|||
|
|
@ -0,0 +1,32 @@
|
|||
"""从CSV绘制决策时空图"""
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
|
||||
output_dir = "evaluation_results"
|
||||
models = ["PPO", "APPO", "DQN", "No Control"]
|
||||
speed_map = [40, 60, 80, 100, 120]
|
||||
|
||||
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
|
||||
axes = axes.flatten()
|
||||
|
||||
for idx, model in enumerate(models):
|
||||
csv_path = os.path.join(output_dir, f"{model}_action.csv")
|
||||
df = pd.read_csv(csv_path)
|
||||
actions = df.values
|
||||
vsl_speeds = np.vectorize(lambda x: speed_map[int(x)])(actions)
|
||||
|
||||
im = axes[idx].imshow(vsl_speeds.T, aspect="auto", cmap="RdYlGn",
|
||||
vmin=40, vmax=120, origin="lower")
|
||||
axes[idx].set_title(f"{model} VSL Decision")
|
||||
axes[idx].set_xlabel("Time Step")
|
||||
axes[idx].set_ylabel("Control Edge Index")
|
||||
cbar = plt.colorbar(im, ax=axes[idx], label="VSL (km/h)")
|
||||
cbar.set_ticks([40, 60, 80, 100, 120])
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(output_dir, "action_heatmaps.png"), dpi=150)
|
||||
print("决策时空图已保存")
|
||||
Loading…
Reference in New Issue