96 lines
3.0 KiB
Python
96 lines
3.0 KiB
Python
"""Shared training-curve plotting utilities."""
|
|
import numpy as np
|
|
import matplotlib
|
|
|
|
matplotlib.use("Agg")
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
def _has_values(values):
    """Return True if *values* is not None and contains at least one element.

    Uses ``len`` instead of truthiness so NumPy arrays are accepted (the truth
    value of a multi-element array is ambiguous and raises ``ValueError``).
    """
    return values is not None and len(values) > 0


def _plot_metric(ax, data, color, title, ylabel, window):
    """Draw a raw per-episode series and, if long enough, its moving average."""
    ax.plot(data, alpha=0.4, color=color)
    if len(data) > window:
        kernel = np.ones(window) / window
        moving_avg = np.convolve(data, kernel, mode="valid")
        # "valid" mode yields len(data) - window + 1 points; the first one
        # averages episodes 0..window-1, so it is aligned with episode window-1.
        ax.plot(range(window - 1, len(data)), moving_avg, "r-", linewidth=2)
    ax.set_xlabel("Episode")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)


def _plot_loss(ax, losses, style, title):
    """Draw a raw loss curve (no moving average, no axis labels)."""
    ax.plot(losses, style, alpha=0.6)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)


def _reward_summary(rewards, window):
    """Build the multi-line text shown in the summary panel."""
    if len(rewards) == 0:
        # max() of an empty sequence raises; report the absence instead.
        return f"Episodes: 0\nBest: n/a\nAvg(last{window}): n/a"
    return (
        f"Episodes: {len(rewards)}\n"
        f"Best: {max(rewards):.2f}\n"
        f"Avg(last{window}): {np.mean(rewards[-window:]):.2f}"
    )


def plot_training_curves(
    rewards,
    throughputs,
    mean_speeds,
    speed_variance_norms,
    ttc_risks,
    policy_losses=None,
    value_losses=None,
    save_path="training_curves.png",
    window=20,
):
    """Plot per-episode training metrics in a grid and save the figure to disk.

    Each metric panel shows the raw per-episode series (semi-transparent) and,
    when the series is longer than ``window``, a ``window``-episode moving
    average drawn in red. One panel holds a text summary of the rewards.

    Layout:
      * Loss panels are drawn only when both ``policy_losses`` and
        ``value_losses`` are given and non-empty. In that case the grid is
        2x4. The top row is reward, throughput, mean speed and speed
        variance. The bottom row is TTC risk, policy loss, value loss and
        the summary.
      * Otherwise the grid is 2x3. The top row is reward, throughput and
        mean speed. The bottom row is speed variance, TTC risk and the
        summary.

    Args:
        rewards: Per-episode total rewards.
        throughputs: Per-episode average throughput (axis labelled veh/h).
        mean_speeds: Per-episode mean speed (axis labelled km/h).
        speed_variance_norms: Per-episode normalized speed variance.
        ttc_risks: Per-episode average TTC risk.
        policy_losses: Optional sequence of policy losses.
        value_losses: Optional sequence of value losses.
        save_path: Path of the image file to write.
        window: Moving-average window in episodes. It is also the number of
            trailing episodes averaged in the summary. Defaults to 20.

    Raises:
        ValueError: If ``window`` is less than 1.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")

    has_losses = _has_values(policy_losses) and _has_values(value_losses)
    ncols = 4 if has_losses else 3
    fig, axes = plt.subplots(2, ncols, figsize=(6 * ncols, 10))
    try:
        # Only the speed-variance, TTC and summary panels move between layouts.
        if has_losses:
            variance_ax, ttc_ax, summary_ax = axes[0, 3], axes[1, 0], axes[1, 3]
        else:
            variance_ax, ttc_ax, summary_ax = axes[1, 0], axes[1, 1], axes[1, 2]

        _plot_metric(
            axes[0, 0], rewards, "blue", "Episode Reward", "Total Reward", window
        )
        _plot_metric(
            axes[0, 1],
            throughputs,
            "green",
            "Throughput",
            "Avg Throughput (veh/h)",
            window,
        )
        _plot_metric(
            axes[0, 2],
            mean_speeds,
            "orange",
            "Mean Speed",
            "Mean Speed (km/h)",
            window,
        )
        _plot_metric(
            variance_ax,
            speed_variance_norms,
            "purple",
            "Normalized Speed Variance",
            "Normalized Variance",
            window,
        )
        _plot_metric(ttc_ax, ttc_risks, "red", "TTC Risk", "Average TTC Risk", window)

        if has_losses:
            _plot_loss(axes[1, 1], policy_losses, "b-", "Policy Loss")
            _plot_loss(axes[1, 2], value_losses, "r-", "Value Loss")

        summary_ax.axis("off")
        summary_ax.text(
            0.1,
            0.5,
            _reward_summary(rewards, window),
            fontsize=12,
            family="monospace",
            verticalalignment="center",
            transform=summary_ax.transAxes,
        )

        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    print(f"训练曲线已保存: {save_path}")
|