# ctm-dqn/scripts/plot_actions.py
# (page metadata: 33 lines, 1.0 KiB, Python)

"""从CSV绘制决策时空图"""
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: the figure is only saved to disk

# Directory holding the per-model "<model>_action.csv" inputs; the combined
# figure is written back into the same directory.
output_dir = "evaluation_results"
# One subplot per model, drawn in this order.
models = ["PPO", "APPO", "DQN", "No Control"]
# Lookup table: integer action index -> variable speed limit in km/h.
speed_map = [40, 60, 80, 100, 120]

# Two subplots per row; for the four models above this reproduces the
# original 2x2 grid at 16x12 inches.
ncols = 2
nrows = max(1, (len(models) + ncols - 1) // ncols)
fig, axes = plt.subplots(nrows, ncols, figsize=(16, 6 * nrows), squeeze=False)
axes = axes.flatten()

# NumPy lookup table so the index -> speed mapping is one fancy-indexing op.
speed_lut = np.asarray(speed_map)
n_levels = len(speed_map)
vmin, vmax = min(speed_map), max(speed_map)

for ax, model in zip(axes, models):
    csv_path = os.path.join(output_dir, f"{model}_action.csv")
    df = pd.read_csv(csv_path)
    # Every CSV cell is treated as an action index into speed_map: rows become
    # the time axis, columns the control-edge axis.
    # NOTE(review): assumes the CSV holds only action columns; an index/time
    # column would be plotted as an extra edge -- confirm against the writer.
    actions = df.to_numpy(dtype=float)

    if actions.size == 0:
        # Header-only CSV: keep the grid layout but mark the panel as empty.
        ax.set_title(f"{model} VSL Decision (no data)")
        ax.set_axis_off()
        continue

    # Reject out-of-range / NaN indices explicitly; a negative index would
    # otherwise wrap around silently (e.g. -1 -> the last speed level).
    valid = (actions >= 0) & (actions < n_levels)
    if not valid.all():
        raise ValueError(
            f"{csv_path}: action indices must lie in [0, {n_levels - 1}]"
        )
    # astype(int) truncates like int(); all values are non-negative here.
    vsl_speeds = speed_lut[actions.astype(int)]

    # Transpose so edges run up the y-axis and time along the x-axis.
    # "nearest" keeps each pixel on an actual speed level instead of letting
    # the default anti-aliasing blend neighbouring levels when downsampling.
    im = ax.imshow(
        vsl_speeds.T,
        aspect="auto",
        cmap="RdYlGn",
        vmin=vmin,
        vmax=vmax,
        origin="lower",
        interpolation="nearest",
    )
    ax.set_title(f"{model} VSL Decision")
    ax.set_xlabel("Time Step")
    ax.set_ylabel("Control Edge Index")
    cbar = fig.colorbar(im, ax=ax, label="VSL (km/h)")
    cbar.set_ticks(speed_map)

# Hide grid cells left over when the model count is odd.
for unused_ax in axes[len(models):]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.savefig(os.path.join(output_dir, "action_heatmaps.png"), dpi=150)
plt.close(fig)
print("决策时空图已保存")