# ctm-dqn/plot_density_from_csv.py

"""从CSV重新绘制密度时空图"""
import math
import os

import numpy as np
import pandas as pd

# Re-draw the density space-time heatmaps from the per-model CSV files exported
# to ``output_dir``. Run this file as a script; importing it only defines the
# helpers below and has no side effects.

# Directory holding the ``<model>_density.csv`` inputs; the figure is saved here too.
output_dir = "evaluation_results"
# Models to plot, in subplot order (row-major).
models = ["PPO", "APPO", "DQN", "No Control"]
# Fixed colour-scale limits (occupancy, %) shared by every subplot so colours are
# comparable across models; values outside this range saturate at the end colours.
vmin = 10
vmax = 50


def load_density_data(directory, model_names):
    """Read ``<directory>/<model>_density.csv`` for every model.

    Args:
        directory: Folder containing the CSV files.
        model_names: Iterable of model names; each maps to one CSV file.

    Returns:
        Dict mapping model name -> 2-D ndarray of the CSV's values (the first
        CSV row is consumed as the header by ``pd.read_csv``), in the order of
        ``model_names``.

    Raises:
        FileNotFoundError: If one of the CSV files does not exist.
    """
    return {
        name: pd.read_csv(os.path.join(directory, f"{name}_density.csv")).to_numpy()
        for name in model_names
    }


def subplot_grid_shape(n_panels):
    """Return ``(rows, cols)`` of a near-square grid with room for *n_panels* axes.

    Four panels give a 2x2 layout. At least one cell is always returned.
    """
    ncols = max(1, math.ceil(math.sqrt(n_panels)))
    nrows = max(1, math.ceil(n_panels / ncols))
    return nrows, ncols


def plot_density_heatmaps(density_data, out_path, vmin=None, vmax=None):
    """Draw one density heatmap per model on a shared colour scale and save it.

    Args:
        density_data: Dict mapping model name -> 2-D array. Each array is
            transposed before plotting, so its rows run along the "Time Step"
            (x) axis and its columns along the "Detector Group Index" (y) axis.
        out_path: Path of the PNG file to write (150 dpi).
        vmin, vmax: Colour-scale limits. When ``None``, the global minimum /
            maximum over all models (ignoring NaN) is used instead.

    Returns:
        The ``(vmin, vmax)`` actually used for the colour scale.

    Raises:
        ValueError: If *density_data* is empty.
    """
    # Imported lazily so the loading helpers work without matplotlib; the
    # non-interactive Agg backend must be selected before pyplot is imported.
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    if not density_data:
        raise ValueError("density_data is empty; nothing to plot")

    if vmin is None or vmax is None:
        all_values = np.concatenate([np.ravel(v) for v in density_data.values()])
        if vmin is None:
            vmin = float(np.nanmin(all_values))
        if vmax is None:
            vmax = float(np.nanmax(all_values))

    nrows, ncols = subplot_grid_shape(len(density_data))
    # 8x6 inches per panel, i.e. a 16x12 figure for the 2x2 grid.
    fig, axes = plt.subplots(nrows, ncols, figsize=(8 * ncols, 6 * nrows),
                             squeeze=False)
    try:
        axes = axes.flatten()
        for ax, (model, data) in zip(axes, density_data.items()):
            im = ax.imshow(np.asarray(data).T, aspect="auto", cmap="YlOrRd",
                           vmin=vmin, vmax=vmax, origin="lower")
            ax.set_title(f"{model} Density Heatmap")
            ax.set_xlabel("Time Step")
            ax.set_ylabel("Detector Group Index")
            fig.colorbar(im, ax=ax, label="Occupancy (%)")
        # Hide grid cells left over when the model count does not fill the grid.
        for ax in axes[len(density_data):]:
            ax.set_visible(False)
        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    return vmin, vmax


def main():
    """Load every model's density CSV from ``output_dir`` and save the heatmap figure."""
    density_data = load_density_data(output_dir, models)
    print(f"密度范围: {vmin:.2f}% - {vmax:.2f}%")
    out_path = os.path.join(output_dir, "density_heatmaps_adjusted.png")
    plot_density_heatmaps(density_data, out_path, vmin=vmin, vmax=vmax)
    print(f"密度热图已保存到 {output_dir}/density_heatmaps_adjusted.png")


if __name__ == "__main__":
    main()