ctm-dqn/utils/plot.py

96 lines
3.0 KiB
Python

"""Shared training-curve plotting utilities."""
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def plot_training_curves(
    rewards,
    throughputs,
    mean_speeds,
    speed_variance_norms,
    ttc_risks,
    policy_losses=None,
    value_losses=None,
    save_path="training_curves.png",
    window=20,
):
    """Plot per-episode training metrics in a 2-row grid and save it to disk.

    Five metric panels (reward, throughput, mean speed, normalized speed
    variance, TTC risk) are always drawn. Each shows the raw series plus a
    ``window``-episode moving average once the series is longer than
    ``window``. When both ``policy_losses`` and ``value_losses`` are
    non-empty, two loss panels are added and the grid widens from 3 to 4
    columns. The last cell of the grid always holds a text summary of
    ``rewards``.

    Args:
        rewards: Per-episode total reward (list or 1-D array).
        throughputs: Per-episode throughput values (axis label: veh/h).
        mean_speeds: Per-episode mean speed values (axis label: km/h).
        speed_variance_norms: Per-episode normalized speed variance.
        ttc_risks: Per-episode average TTC risk.
        policy_losses: Optional policy-loss history. It is plotted only
            when ``value_losses`` is also non-empty.
        value_losses: Optional value-loss history.
        save_path: Output image path.
        window: Moving-average window length. It is also the number of
            trailing episodes averaged for the "Avg(lastN)" summary line.
            Defaults to 20.

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")

    def _has_data(series):
        # Use len() rather than truthiness so NumPy arrays work as well as lists.
        return series is not None and len(series) > 0

    has_losses = _has_data(policy_losses) and _has_data(value_losses)
    ncols = 4 if has_losses else 3
    fig, axes = plt.subplots(2, ncols, figsize=(6 * ncols, 10))

    def _plot(ax, data, color, title, ylabel):
        # Draw the faint raw series, then a bold moving-average overlay.
        ax.plot(data, alpha=0.4, color=color)
        if len(data) > window:
            ma = np.convolve(data, np.ones(window) / window, mode="valid")
            # A "valid" convolution drops the first window-1 points, so shift x.
            ax.plot(range(window - 1, len(data)), ma, "r-", linewidth=2)
        ax.set_xlabel("Episode")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    try:
        # Row-major order. In both the 3- and 4-column grids the five metric
        # panels fill flat cells 0-4, the loss panels (if any) follow, and the
        # summary takes the final cell.
        flat_axes = axes.ravel()

        metric_panels = [
            (rewards, "blue", "Episode Reward", "Total Reward"),
            (throughputs, "green", "Throughput", "Avg Throughput (veh/h)"),
            (mean_speeds, "orange", "Mean Speed", "Mean Speed (km/h)"),
            (
                speed_variance_norms,
                "purple",
                "Normalized Speed Variance",
                "Normalized Variance",
            ),
            (ttc_risks, "red", "TTC Risk", "Average TTC Risk"),
        ]
        for ax, (data, color, title, ylabel) in zip(flat_axes, metric_panels):
            _plot(ax, data, color, title, ylabel)

        if has_losses:
            loss_panels = (
                (flat_axes[5], policy_losses, "b-", "Policy Loss"),
                (flat_axes[6], value_losses, "r-", "Value Loss"),
            )
            for ax, losses, style, title in loss_panels:
                ax.plot(losses, style, alpha=0.6)
                ax.set_title(title)
                ax.grid(True, alpha=0.3)

        if len(rewards) > 0:
            best = f"{max(rewards):.2f}"
            recent = f"{np.mean(rewards[-window:]):.2f}"
        else:
            # max() of an empty sequence raises, so show placeholders instead.
            best = recent = "N/A"
        summary = (
            f"Episodes: {len(rewards)}\n"
            f"Best: {best}\n"
            f"Avg(last{window}): {recent}"
        )

        summary_ax = flat_axes[-1]
        summary_ax.axis("off")
        summary_ax.text(
            0.1,
            0.5,
            summary,
            fontsize=12,
            family="monospace",
            verticalalignment="center",
            transform=summary_ax.transAxes,
        )

        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    finally:
        # Close this specific figure, even on error, so repeated calls during
        # training do not accumulate open figures.
        plt.close(fig)

    # Message reads: "Training curves saved: <path>".
    print(f"训练曲线已保存: {save_path}")