"""从CSV绘制决策时空图"""
|
|
import os
|
|
import numpy as np
|
|
import pandas as pd
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib
|
|
matplotlib.use("Agg")
|
|
|
|
output_dir = "evaluation_results"
|
|
models = ["PPO", "APPO", "DQN", "No Control"]
|
|
speed_map = [40, 60, 80, 100, 120]
|
|
|
|
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
|
|
axes = axes.flatten()
|
|
|
|
for idx, model in enumerate(models):
|
|
csv_path = os.path.join(output_dir, f"{model}_action.csv")
|
|
df = pd.read_csv(csv_path)
|
|
actions = df.values
|
|
vsl_speeds = np.vectorize(lambda x: speed_map[int(x)])(actions)
|
|
|
|
im = axes[idx].imshow(vsl_speeds.T, aspect="auto", cmap="RdYlGn",
|
|
vmin=40, vmax=120, origin="lower")
|
|
axes[idx].set_title(f"{model} VSL Decision")
|
|
axes[idx].set_xlabel("Time Step")
|
|
axes[idx].set_ylabel("Control Edge Index")
|
|
cbar = plt.colorbar(im, ax=axes[idx], label="VSL (km/h)")
|
|
cbar.set_ticks([40, 60, 80, 100, 120])
|
|
|
|
plt.tight_layout()
|
|
plt.savefig(os.path.join(output_dir, "action_heatmaps.png"), dpi=150)
|
|
print("决策时空图已保存")
|