增加绘制模型控制action时空图
This commit is contained in:
parent
d4eadccbf3
commit
f653c61c03
|
|
@ -64,6 +64,7 @@ def evaluate_model(model_type, checkpoint_dir, config):
|
||||||
state = env.reset(seed=42)
|
state = env.reset(seed=42)
|
||||||
speed_data = []
|
speed_data = []
|
||||||
density_data = []
|
density_data = []
|
||||||
|
action_data = []
|
||||||
brake_counts = []
|
brake_counts = []
|
||||||
speed_variances = []
|
speed_variances = []
|
||||||
throughputs = []
|
throughputs = []
|
||||||
|
|
@ -77,6 +78,7 @@ def evaluate_model(model_type, checkpoint_dir, config):
|
||||||
else:
|
else:
|
||||||
action, _, _ = agent.select_action(state, deterministic=True)
|
action, _, _ = agent.select_action(state, deterministic=True)
|
||||||
|
|
||||||
|
action_data.append(action.copy())
|
||||||
state, reward, done, info = env.step(action)
|
state, reward, done, info = env.step(action)
|
||||||
|
|
||||||
if not done:
|
if not done:
|
||||||
|
|
@ -132,6 +134,7 @@ def evaluate_model(model_type, checkpoint_dir, config):
|
||||||
return {
|
return {
|
||||||
"speed_data": np.array(speed_data),
|
"speed_data": np.array(speed_data),
|
||||||
"density_data": np.array(density_data),
|
"density_data": np.array(density_data),
|
||||||
|
"action_data": np.array(action_data),
|
||||||
"brake_counts": brake_counts,
|
"brake_counts": brake_counts,
|
||||||
"speed_variances": speed_variances,
|
"speed_variances": speed_variances,
|
||||||
"throughputs": throughputs
|
"throughputs": throughputs
|
||||||
|
|
@ -150,6 +153,9 @@ def plot_results(results_dict, output_dir):
|
||||||
density_df = pd.DataFrame(results["density_data"])
|
density_df = pd.DataFrame(results["density_data"])
|
||||||
density_df.to_csv(os.path.join(output_dir, f"{name}_density.csv"), index=False)
|
density_df.to_csv(os.path.join(output_dir, f"{name}_density.csv"), index=False)
|
||||||
|
|
||||||
|
action_df = pd.DataFrame(results["action_data"])
|
||||||
|
action_df.to_csv(os.path.join(output_dir, f"{name}_action.csv"), index=False)
|
||||||
|
|
||||||
metrics_df = pd.DataFrame({
|
metrics_df = pd.DataFrame({
|
||||||
"brake_counts": results["brake_counts"],
|
"brake_counts": results["brake_counts"],
|
||||||
"speed_variances": results["speed_variances"],
|
"speed_variances": results["speed_variances"],
|
||||||
|
|
@ -252,6 +258,7 @@ def evaluate_no_control(config):
|
||||||
state = env.reset(seed=42)
|
state = env.reset(seed=42)
|
||||||
speed_data = []
|
speed_data = []
|
||||||
density_data = []
|
density_data = []
|
||||||
|
action_data = []
|
||||||
brake_counts = []
|
brake_counts = []
|
||||||
speed_variances = []
|
speed_variances = []
|
||||||
throughputs = []
|
throughputs = []
|
||||||
|
|
@ -259,6 +266,7 @@ def evaluate_no_control(config):
|
||||||
|
|
||||||
while not done:
|
while not done:
|
||||||
action = np.array([4] * env.num_edges)
|
action = np.array([4] * env.num_edges)
|
||||||
|
action_data.append(action.copy())
|
||||||
state, reward, done, info = env.step(action)
|
state, reward, done, info = env.step(action)
|
||||||
|
|
||||||
if not done:
|
if not done:
|
||||||
|
|
@ -306,6 +314,7 @@ def evaluate_no_control(config):
|
||||||
return {
|
return {
|
||||||
"speed_data": np.array(speed_data),
|
"speed_data": np.array(speed_data),
|
||||||
"density_data": np.array(density_data),
|
"density_data": np.array(density_data),
|
||||||
|
"action_data": np.array(action_data),
|
||||||
"brake_counts": brake_counts,
|
"brake_counts": brake_counts,
|
||||||
"speed_variances": speed_variances,
|
"speed_variances": speed_variances,
|
||||||
"throughputs": throughputs
|
"throughputs": throughputs
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,32 @@
|
||||||
|
"""从CSV绘制决策时空图"""
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
|
||||||
|
output_dir = "evaluation_results"
|
||||||
|
models = ["PPO", "APPO", "DQN", "No Control"]
|
||||||
|
speed_map = [40, 60, 80, 100, 120]
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
|
||||||
|
axes = axes.flatten()
|
||||||
|
|
||||||
|
for idx, model in enumerate(models):
|
||||||
|
csv_path = os.path.join(output_dir, f"{model}_action.csv")
|
||||||
|
df = pd.read_csv(csv_path)
|
||||||
|
actions = df.values
|
||||||
|
vsl_speeds = np.vectorize(lambda x: speed_map[int(x)])(actions)
|
||||||
|
|
||||||
|
im = axes[idx].imshow(vsl_speeds.T, aspect="auto", cmap="RdYlGn",
|
||||||
|
vmin=40, vmax=120, origin="lower")
|
||||||
|
axes[idx].set_title(f"{model} VSL Decision")
|
||||||
|
axes[idx].set_xlabel("Time Step")
|
||||||
|
axes[idx].set_ylabel("Control Edge Index")
|
||||||
|
cbar = plt.colorbar(im, ax=axes[idx], label="VSL (km/h)")
|
||||||
|
cbar.set_ticks([40, 60, 80, 100, 120])
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(os.path.join(output_dir, "action_heatmaps.png"), dpi=150)
|
||||||
|
print("决策时空图已保存")
|
||||||
Loading…
Reference in New Issue