"""
PPO training script based on SUMO + TraCI.

Trains a variable speed limit (VSL) control policy in a microscopic
traffic simulation environment.
"""

import os
from datetime import datetime

import yaml
import numpy as np
import matplotlib

matplotlib.use("Agg")  # Non-interactive backend: figures are only written to disk
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch

from sumo_edge_vsl_environment import SUMOEdgeVSLEnvironment
from ppo_agent import PPOAgent
from training_logger import TrainingLogger


def train_sumo_ppo():
    """Main PPO training loop for the SUMO environment."""
    # Load configuration
    with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    agent_config = config.get("agent", {})
    train_config = config["training"]
    start_episode = 1

    # Each run gets its own timestamped checkpoint and log directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint_dir = os.path.join("checkpoints_sumo_ppo", timestamp)
    log_dir = os.path.join("logs_sumo_ppo", timestamp)
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Keep a copy of the config next to the checkpoints for reproducibility
    with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
        yaml.dump(config, f, allow_unicode=True)

    logger = TrainingLogger(log_dir, "ppo")

    # Create the environment
    env = SUMOEdgeVSLEnvironment(config)
    state_dim = env.state_dim
    # One discrete action head per controlled edge
    action_dims = [env.action_dim] * env.num_edges

    # Hyperparameters that are both printed and passed to the agent
    # (defaults apply when a key is missing from the config)
    hidden_layers = agent_config.get("hidden_layers", [512, 256])
    learning_rate = agent_config.get("learning_rate", 3e-4)
    device = agent_config.get("device", "cuda")

    print("=" * 70)
    print("PPO Training - SUMO+TraCI VSL Environment")
    print("=" * 70)
    print(f"  State dim: {state_dim}")
    print(f"  Controlled edges: {env.num_edges}")
    print(f"  Actions per edge: {env.action_dim}")
    print(f"  Steps per episode: {env.episode_length}")
    print(f"  Control interval: {env.control_interval}s")
    print(f"  Hidden layers: {hidden_layers}")
    print(f"  Learning rate: {learning_rate}")
    print(f"  Device: {device}")
    print()

    # Create the PPO agent
    agent = PPOAgent(
        state_dim=state_dim,
        action_dims=action_dims,
        hidden_layers=hidden_layers,
        learning_rate=learning_rate,
        gamma=agent_config.get("gamma", 0.99),
        gae_lambda=agent_config.get("gae_lambda", 0.95),
        clip_epsilon=agent_config.get("clip_epsilon", 0.2),
        value_coef=agent_config.get("value_coef", 0.5),
        entropy_coef=agent_config.get("entropy_coef", 0.02),
        max_grad_norm=agent_config.get("max_grad_norm", 0.5),
        ppo_epochs=agent_config.get("ppo_epochs", 10),
        minibatch_size=agent_config.get("batch_size", 64),
        device=device,
    )

    # Training parameters
    num_episodes = train_config["num_episodes"]
    save_freq = train_config.get("save_freq", 50)
    log_freq = train_config.get("log_freq", 10)
    base_seed = train_config.get("random_seed", 42)

    # Per-episode statistics
    episode_rewards = []
    episode_throughputs = []
    episode_mean_speeds = []
    episode_hard_brakes = []
    policy_losses = []
    value_losses = []
    entropies = []
    best_reward = -float("inf")
    interrupted = False

    print("Starting training...\n")

    try:
        for episode in range(start_episode, num_episodes + 1):
            # Use a different seed for each episode to introduce randomness
            seed = base_seed + episode
            state = env.reset(seed=seed)

            episode_reward = 0
            episode_throughput = 0
            episode_speed = 0
            episode_brakes = 0
            done = False
            step = 0

            pbar = tqdm(
                total=env.episode_length,
                desc=f"Ep {episode}/{num_episodes}",
                leave=False,
            )

            # Roll out one full episode with the stochastic policy
            while not done:
                action, log_prob, value = agent.select_action(state, deterministic=False)
                next_state, reward, done, info = env.step(action)
                agent.store_transition(state, action, reward, value, log_prob, done)

                episode_reward += reward
                episode_throughput += info["throughput"]
                episode_speed += info["mean_speed_kmh"]
                episode_brakes += info["num_hard_brakes"]

                state = next_state
                step += 1

                pbar.set_postfix(
                    r=f"{episode_reward:.1f}",
                    tp=f"{info['throughput']:.0f}",
                    v=f"{info['mean_speed_kmh']:.1f}",
                )
                pbar.update(1)

            pbar.close()

            # GAE computation and policy update.
            # The loop above only exits once done is True, so the bootstrap
            # value is 0 here; the else branch matters only if the rollout is
            # ever cut off before the episode terminates.
            if done:
                next_value = 0.0
            else:
                with torch.no_grad():
                    next_state_tensor = torch.FloatTensor(next_state).unsqueeze(0).to(agent.device)
                    next_value = agent.policy.get_value(next_state_tensor).item()

            train_stats = agent.update(next_value)

            # Record statistics (throughput and speed are per-step averages)
            avg_tp = episode_throughput / max(step, 1)
            avg_speed = episode_speed / max(step, 1)
            episode_rewards.append(episode_reward)
            episode_throughputs.append(avg_tp)
            episode_mean_speeds.append(avg_speed)
            episode_hard_brakes.append(episode_brakes)

            if train_stats:
                policy_losses.append(train_stats["policy_loss"])
                value_losses.append(train_stats["value_loss"])
                entropies.append(train_stats["entropy"])
                logger.log(
                    episode, episode_reward, avg_tp, avg_speed, episode_brakes,
                    train_stats["policy_loss"], train_stats["value_loss"], train_stats["entropy"],
                )
            else:
                logger.log(episode, episode_reward, avg_tp, avg_speed, episode_brakes)

            # Save the best model so far
            if episode_reward > best_reward:
                best_reward = episode_reward
                agent.save(os.path.join(checkpoint_dir, "model_best.pt"))

            # Periodic console log
            if episode % log_freq == 0:
                recent_rewards = episode_rewards[-log_freq:]
                print(f"\nEpisode {episode}/{num_episodes}")
                print(f"  Reward: {episode_reward:.2f} (Avg: {np.mean(recent_rewards):.2f})")
                print(f"  Throughput: {avg_tp:.1f} veh/h")
                print(f"  Mean Speed: {avg_speed:.1f} km/h")
                if train_stats:
                    print(f"  Policy Loss: {train_stats['policy_loss']:.4f}")
                    print(f"  Value Loss: {train_stats['value_loss']:.4f}")
                    print(f"  Entropy: {train_stats['entropy']:.4f}")

            # Periodic checkpoint
            if episode % save_freq == 0:
                agent.save(os.path.join(checkpoint_dir, f"model_ep{episode}.pt"))

    except KeyboardInterrupt:
        interrupted = True
        print("\nTraining interrupted, saving current model...")
        agent.save(os.path.join(checkpoint_dir, "model_interrupted.pt"))
    finally:
        env.close()

    # Final save. Skipped after an interrupt (model_interrupted.pt was already
    # written) so the file name never claims more episodes than were run.
    if not interrupted:
        agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}.pt"))

    # Plot training curves
    _plot_training_curves(
        episode_rewards,
        episode_throughputs,
        episode_mean_speeds,
        episode_hard_brakes,
        policy_losses,
        value_losses,
        save_path=os.path.join(log_dir, "training_curves.png"),
    )

    print("=" * 70)
    print("Training finished!")
    print(f"  Best reward: {best_reward:.2f}")
    print(f"  Model dir: {checkpoint_dir}")
    print(f"  Log dir: {log_dir}")
    print("=" * 70)


def _plot_training_curves(
    rewards,
    throughputs,
    mean_speeds,
    hard_brakes,
    policy_losses,
    value_losses,
    save_path: str,
):
    """Plot per-episode training curves plus a text summary and save the figure."""
    # max() and the summary means below need at least one completed episode
    if not rewards:
        print("No completed episodes, skipping training curves.")
        return

    fig, axes = plt.subplots(2, 4, figsize=(24, 10))
    window = 20  # Moving-average window, in episodes

    # Top row: raw per-episode metrics with a moving-average overlay
    top_row = [
        (rewards, "blue", "Total Reward", "Episode Reward"),
        (throughputs, "green", "Avg Throughput (veh/h)", "Throughput"),
        (mean_speeds, "orange", "Mean Speed (km/h)", "Mean Speed"),
        (hard_brakes, "red", "Hard Brakes Count", "Hard Brakes"),
    ]
    for ax, (data, color, ylabel, title) in zip(axes[0], top_row):
        ax.plot(data, alpha=0.4, color=color)
        if len(data) > window:
            # mode="valid" yields len(data) - window + 1 points; each average
            # is plotted at the last episode of its window
            ma = np.convolve(data, np.ones(window) / window, mode="valid")
            ax.plot(range(window - 1, len(data)), ma, "r-", linewidth=2)
        ax.set_xlabel("Episode")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    # Bottom row: PPO losses (present only when updates returned statistics)
    loss_panels = [
        (axes[1, 0], policy_losses, "b-", "Policy Loss"),
        (axes[1, 1], value_losses, "r-", "Value Loss"),
    ]
    for ax, losses, style, label in loss_panels:
        if losses:
            ax.plot(losses, style, alpha=0.6)
        ax.set_xlabel("Episode")
        ax.set_ylabel(label)
        ax.set_title(label)
        ax.grid(True, alpha=0.3)

    # Summary text over the last `window` episodes
    axes[1, 2].axis("off")
    summary = (
        f"Training Summary\n"
        f"{'=' * 30}\n"
        f"Episodes: {len(rewards)}\n"
        f"Best Reward: {max(rewards):.2f}\n"
        f"Final Avg Reward: {np.mean(rewards[-window:]):.2f}\n"
        f"Final Avg Throughput: {np.mean(throughputs[-window:]):.1f} veh/h\n"
        f"Final Avg Speed: {np.mean(mean_speeds[-window:]):.1f} km/h\n"
        f"Final Avg Hard Brakes: {np.mean(hard_brakes[-window:]):.1f}"
    )
    axes[1, 2].text(
        0.1, 0.5, summary,
        fontsize=12, family="monospace",
        verticalalignment="center", transform=axes[1, 2].transAxes,
    )
    axes[1, 3].axis("off")

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)  # Release the figure's memory
    print(f"Training curves saved: {save_path}")


if __name__ == "__main__":
    train_sumo_ppo()