"""Shared training-curve plotting utilities.""" import numpy as np import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt def plot_training_curves( rewards, throughputs, mean_speeds, speed_stds, hard_brakes, policy_losses=None, value_losses=None, save_path="training_curves.png", ): window = 20 has_losses = policy_losses and value_losses ncols = 4 if has_losses else 3 nrows = 2 fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 10)) def _plot(ax, data, color, title, ylabel): ax.plot(data, alpha=0.4, color=color) if len(data) > window: ma = np.convolve(data, np.ones(window) / window, mode="valid") ax.plot(range(window - 1, len(data)), ma, "r-", linewidth=2) ax.set_xlabel("Episode") ax.set_ylabel(ylabel) ax.set_title(title) ax.grid(True, alpha=0.3) summary = ( f"Episodes: {len(rewards)}\n" f"Best: {max(rewards):.2f}\n" f"Avg(last20): {np.mean(rewards[-20:]):.2f}" ) if has_losses: _plot(axes[0, 0], rewards, "blue", "Episode Reward", "Total Reward") _plot(axes[0, 1], throughputs, "green", "Throughput", "Avg Throughput (veh/h)") _plot(axes[0, 2], mean_speeds, "orange", "Mean Speed", "Mean Speed (km/h)") _plot(axes[0, 3], speed_stds, "purple", "Speed Std", "Speed Std (km/h)") _plot(axes[1, 0], hard_brakes, "red", "Hard Brakes", "Hard Brakes Count") axes[1, 1].plot(policy_losses, "b-", alpha=0.6) axes[1, 1].set_title("Policy Loss") axes[1, 1].grid(True, alpha=0.3) axes[1, 2].plot(value_losses, "r-", alpha=0.6) axes[1, 2].set_title("Value Loss") axes[1, 2].grid(True, alpha=0.3) axes[1, 3].axis("off") axes[1, 3].text( 0.1, 0.5, summary, fontsize=12, family="monospace", verticalalignment="center", transform=axes[1, 3].transAxes, ) else: _plot(axes[0, 0], rewards, "blue", "Episode Reward", "Total Reward") _plot(axes[0, 1], throughputs, "green", "Throughput", "Avg Throughput (veh/h)") _plot(axes[0, 2], mean_speeds, "orange", "Mean Speed", "Mean Speed (km/h)") _plot(axes[1, 0], speed_stds, "purple", "Speed Std", "Speed Std (km/h)") _plot(axes[1, 1], hard_brakes, "red", "Hard Brakes", "Hard Brakes Count") axes[1, 2].axis("off") axes[1, 2].text( 0.1, 0.5, summary, fontsize=12, family="monospace", verticalalignment="center", transform=axes[1, 2].transAxes, ) plt.tight_layout() plt.savefig(save_path, dpi=150, bbox_inches="tight") plt.close() print(f"训练曲线已保存: {save_path}")