44 lines
1.3 KiB
Python
44 lines
1.3 KiB
Python
"""从CSV重新绘制密度时空图"""
|
|
import os
|
|
import numpy as np
|
|
import pandas as pd
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib
|
|
matplotlib.use("Agg")
|
|
|
|
# Directory that holds the per-model "<model>_density.csv" inputs and that
# receives the rendered figure.
output_dir = "evaluation_results"

# Models to plot, in panel order (row-major over the 2x2 grid).
models = ["PPO", "APPO", "DQN", "No Control"]

# Fixed colour-scale limits shared by every panel so the heatmaps can be
# compared directly. These are deliberately hard-coded rather than derived
# from the data.
vmin = 10
vmax = 50


def load_density_data(data_dir, model_names):
    """Load one density matrix per model from CSV.

    Args:
        data_dir: Directory containing the ``<model>_density.csv`` files.
        model_names: Iterable of model names; each name is used verbatim as
            the file-name prefix.

    Returns:
        dict mapping each model name to the 2-D array returned by
        ``DataFrame.values`` for its CSV, in the order of ``model_names``.

    Raises:
        FileNotFoundError: Propagated from ``pandas.read_csv`` when a model's
            CSV file is missing.
    """
    return {
        name: pd.read_csv(os.path.join(data_dir, f"{name}_density.csv")).values
        for name in model_names
    }


def plot_density_heatmaps(density_data, model_names, save_path, color_min, color_max):
    """Draw a 2x2 grid of density heatmaps and save it as an image.

    Args:
        density_data: Mapping of model name to a 2-D array as loaded from the
            CSV; the array is transposed before drawing, so CSV rows run along
            the x axis and CSV columns along the y axis.
        model_names: Names to plot, one panel each (at most four).
        save_path: Destination path of the image file.
        color_min: Lower limit of the shared colour scale.
        color_max: Upper limit of the shared colour scale.

    Raises:
        ValueError: If more than four model names are given.
        KeyError: If a name in ``model_names`` is absent from ``density_data``.
    """
    model_names = list(model_names)
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    panels = axes.flatten()
    if len(model_names) > len(panels):
        # Fail up front instead of with an IndexError part-way through plotting.
        raise ValueError(
            f"expected at most {len(panels)} models, got {len(model_names)}"
        )

    for ax, name in zip(panels, model_names):
        im = ax.imshow(
            density_data[name].T,
            aspect="auto",
            cmap="YlOrRd",
            vmin=color_min,
            vmax=color_max,
            origin="lower",
        )
        ax.set_title(f"{name} Density Heatmap")
        ax.set_xlabel("Time Step")
        ax.set_ylabel("Detector Group Index")
        plt.colorbar(im, ax=ax, label="Occupancy (%)")

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)


def main():
    """Load every model's density CSV and render the comparison figure."""
    density_data = load_density_data(output_dir, models)

    print(f"密度范围: {vmin:.2f}% - {vmax:.2f}%")

    plot_density_heatmaps(
        density_data,
        models,
        os.path.join(output_dir, "density_heatmaps_adjusted.png"),
        color_min=vmin,
        color_max=vmax,
    )

    print(f"密度热图已保存到 {output_dir}/density_heatmaps_adjusted.png")


if __name__ == "__main__":
    main()