ctm-dqn/utils/plot.py

51 lines
2.3 KiB
Python

"""共享训练曲线绘图工具"""
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def plot_training_curves(rewards, throughputs, mean_speeds, hard_brakes,
                         policy_losses=None, value_losses=None,
                         save_path="training_curves.png"):
    """Plot training metrics on a grid of subplots and save the figure.

    The four metric series are drawn against an "Episode" x-axis, each as a
    faint raw curve plus a moving average (window of 20) once the series is
    longer than the window.

    Layout:
      * Without losses: a 2x2 grid holding the four metric plots.
      * With both loss series non-empty: a 2x4 grid. The top row holds the
        four metric plots; the bottom row holds policy loss, value loss, a
        text summary of the rewards, and one blank panel.

    Args:
        rewards: Sequence plotted as "Episode Reward".
        throughputs: Sequence plotted as "Throughput".
        mean_speeds: Sequence plotted as "Mean Speed".
        hard_brakes: Sequence plotted as "Hard Brakes".
        policy_losses: Optional sequence of policy-loss values.
        value_losses: Optional sequence of value-loss values.
        save_path: Destination path of the saved image.

    Returns:
        None. The figure is written to ``save_path`` and a confirmation line
        is printed.
    """
    window = 20

    def _has_data(series):
        # Explicit None/length test: truth-testing a NumPy array with more
        # than one element raises ValueError, so `a and b` is not safe here.
        return series is not None and len(series) > 0

    has_losses = _has_data(policy_losses) and _has_data(value_losses)
    ncols = 4 if has_losses else 2
    nrows = 2
    fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 10))

    def _plot(ax, data, color, title, ylabel):
        """Draw the raw series and, when long enough, its moving average."""
        ax.plot(data, alpha=0.4, color=color)
        if len(data) > window:
            ma = np.convolve(data, np.ones(window) / window, mode="valid")
            # mode="valid" drops the first window-1 points, so shift the
            # x-values to line the average up with the episode it ends on.
            ax.plot(range(window - 1, len(data)), ma, "r-", linewidth=2)
        ax.set_xlabel("Episode")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    # try/finally guarantees the figure is released even when drawing or
    # saving fails, and closes this figure only (not some other current one).
    try:
        if has_losses:
            _plot(axes[0, 0], rewards, "blue", "Episode Reward", "Total Reward")
            _plot(axes[0, 1], throughputs, "green", "Throughput", "Avg Throughput (veh/h)")
            _plot(axes[0, 2], mean_speeds, "orange", "Mean Speed", "Mean Speed (km/h)")
            _plot(axes[0, 3], hard_brakes, "red", "Hard Brakes", "Hard Brakes Count")

            axes[1, 0].plot(policy_losses, "b-", alpha=0.6)
            axes[1, 0].set_title("Policy Loss")
            axes[1, 0].grid(True, alpha=0.3)

            axes[1, 1].plot(value_losses, "r-", alpha=0.6)
            axes[1, 1].set_title("Value Loss")
            axes[1, 1].grid(True, alpha=0.3)

            # max() of an empty sequence raises, so fall back to a placeholder
            # when there are no rewards to summarize.
            if len(rewards) > 0:
                best = f"{max(rewards):.2f}"
                recent = f"{np.mean(rewards[-window:]):.2f}"
            else:
                best = recent = "n/a"
            summary = (f"Episodes: {len(rewards)}\nBest: {best}\n"
                       f"Avg(last{window}): {recent}")
            axes[1, 2].axis("off")
            axes[1, 2].text(0.1, 0.5, summary, fontsize=12, family="monospace",
                            verticalalignment="center",
                            transform=axes[1, 2].transAxes)
            axes[1, 3].axis("off")
        else:
            _plot(axes[0, 0], rewards, "blue", "Episode Reward", "Total Reward")
            _plot(axes[0, 1], throughputs, "green", "Throughput", "Avg Throughput (veh/h)")
            _plot(axes[1, 0], mean_speeds, "orange", "Mean Speed", "Mean Speed (km/h)")
            _plot(axes[1, 1], hard_brakes, "red", "Hard Brakes", "Hard Brakes Count")

        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    finally:
        plt.close(fig)

    print(f"训练曲线已保存: {save_path}")