"""DQN training script for the SUMO variable-speed-limit (VSL) environment."""

import os
import sys
import yaml
import numpy as np
import matplotlib

# Select the non-interactive backend before pyplot is imported.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from datetime import datetime
from tqdm import tqdm

from envs.edge_vsl_env import SUMOEdgeVSLEnvironment
from agents.dqn_agent import DQNAgent
from utils.config import get_agent_config, get_training_config
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves

# Reward-component keys read from the per-step ``info`` dict and averaged
# over the steps of an episode.
_REWARD_TERMS = ("r_flow", "r_var", "r_brake", "r_penalty")


def _run_episode(env, agent, seed, desc, losses):
    """Roll out one training episode, storing and learning from every step.

    Args:
        env: The SUMO VSL environment; reset with ``seed`` before the rollout.
        agent: The DQN agent; a single agent picks the actions of all edges.
        seed: Seed passed to ``env.reset``.
        desc: Label shown on the per-episode progress bar.
        losses: List that every loss reported by ``agent.update()`` is
            appended to (mutated in place).

    Returns:
        ``(episode_reward, hard_brakes, averages)`` where ``episode_reward``
        and ``hard_brakes`` are episode totals and ``averages`` maps
        ``"throughput"``, ``"speed"``, ``"speed_std"`` and each key of
        ``_REWARD_TERMS`` to its mean value per step.
    """
    state = env.reset(seed=seed)
    episode_reward = 0
    hard_brakes = 0
    # Running per-step sums; divided by the step count once the episode ends.
    sums = dict.fromkeys(("throughput", "speed", "speed_std") + _REWARD_TERMS, 0)
    done = False
    step = 0

    # The context manager closes the bar even if the env or the agent raises.
    with tqdm(total=env.episode_length, desc=desc, leave=False) as pbar:
        while not done:
            # One agent chooses the actions for every controlled edge at once.
            action = agent.select_action(state)
            next_state, reward, done, info = env.step(action)

            # Store the transition, then take a learning step.
            agent.store_transition(state, action, reward, next_state, done)
            train_stats = agent.update()
            if train_stats:
                losses.append(train_stats["loss"])

            episode_reward += reward
            hard_brakes += info["num_hard_brakes"]
            sums["throughput"] += info["throughput"]
            sums["speed"] += info["mean_speed_kmh"]
            # Presumably speed_std is reported in m/s; x3.6 converts it to
            # km/h, the unit used when it is printed.
            sums["speed_std"] += info["speed_std"] * 3.6
            for term in _REWARD_TERMS:
                sums[term] += info[term]

            state = next_state
            step += 1
            pbar.set_postfix(r=f"{episode_reward:.1f}", tp=f"{info['throughput']:.0f}", v=f"{info['mean_speed_kmh']:.1f}")
            pbar.update(1)

    # max(step, 1) guards against division by zero for an empty episode.
    averages = {key: total / max(step, 1) for key, total in sums.items()}
    return episode_reward, hard_brakes, averages


def train_sumo_dqn(config_path="config_sumo_vsl.yaml"):
    """Train a single DQN agent that controls every VSL edge of the SUMO env.

    Args:
        config_path: YAML config file to load. The default is the path the
            script has always used.

    Checkpoints and a copy of the config go to ``checkpoints/dqn/<timestamp>``.
    Logs and the training-curve plot go to ``logs/dqn/<timestamp>``.
    ``model_best.pt`` tracks the highest episode reward. ``model_ep<N>.pt`` is
    written every ``save_freq`` episodes and once more when all episodes
    finish. Ctrl-C stops training early, saves ``model_interrupted.pt``, and
    still plots whatever curves were collected. Any other exception closes
    the environment and propagates.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    agent_config = get_agent_config(config, "dqn")
    train_config = get_training_config(config)
    start_episode = 1

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint_dir = os.path.join("checkpoints", "dqn", timestamp)
    log_dir = os.path.join("logs", "dqn", timestamp)
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Keep a copy of the exact config next to the checkpoints; allow_unicode
    # keeps non-ASCII values readable instead of \u-escaped.
    with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
        yaml.dump(config, f, allow_unicode=True)

    logger = TrainingLogger(log_dir, "dqn")

    num_episodes = train_config["num_episodes"]
    save_freq = train_config.get("save_freq", 50)
    log_freq = train_config.get("log_freq", 10)
    base_seed = train_config.get("random_seed", 42)

    episode_rewards = []
    episode_throughputs = []
    episode_mean_speeds = []
    episode_hard_brakes = []
    losses = []
    best_reward = -float("inf")
    completed = False

    env = SUMOEdgeVSLEnvironment(config)
    # Everything after env creation sits inside try/finally, so the env is
    # closed even if building the agent fails.
    try:
        state_dim = env.state_dim
        # DQN uses a single network that covers every controlled edge.
        num_edges = env.num_edges
        num_actions_per_edge = env.action_dim

        print("=" * 70)
        print("DQN训练 - SUMO VSL环境")
        print("=" * 70)
        print(f" 状态维度: {state_dim}")
        print(f" 控制边数: {num_edges}")
        print(f" 每边动作数: {num_actions_per_edge}")
        print(f" Episode步数: {env.episode_length}")
        print()

        # A single DQN agent for all edges.
        agent = DQNAgent(
            state_dim=state_dim,
            num_edges=num_edges,
            num_actions_per_edge=num_actions_per_edge,
            hidden_dim=agent_config.get("hidden_dim", 256),
            learning_rate=agent_config.get("learning_rate", 1e-3),
            gamma=agent_config.get("gamma", 0.99),
            epsilon_start=agent_config.get("epsilon_start", 1.0),
            epsilon_end=agent_config.get("epsilon_end", 0.01),
            epsilon_decay=agent_config.get("epsilon_decay", 200),
            buffer_size=agent_config.get("buffer_size", 10000),
            batch_size=agent_config.get("batch_size", 64),
            target_update=agent_config.get("target_update", 10),
            # NOTE(review): defaults to "cuda"; confirm DQNAgent falls back
            # when no GPU is available.
            device=agent_config.get("device", "cuda"),
        )

        print("开始训练...\n")

        try:
            for episode in range(start_episode, num_episodes + 1):
                episode_reward, episode_brakes, avg = _run_episode(
                    env,
                    agent,
                    seed=base_seed + episode,
                    desc=f"Ep {episode}/{num_episodes}",
                    losses=losses,
                )

                # Target-network sync is counted in episodes here.
                if episode % agent.target_update == 0:
                    agent.update_target_network()

                episode_rewards.append(episode_reward)
                episode_throughputs.append(avg["throughput"])
                episode_mean_speeds.append(avg["speed"])
                episode_hard_brakes.append(episode_brakes)

                # Smoothed loss over the most recent 100 updates (None until
                # the agent has started learning).
                loss_val = np.mean(losses[-100:]) if losses else None

                logger.log(
                    episode,
                    episode_reward,
                    avg["throughput"],
                    avg["speed"],
                    speed_std=avg["speed_std"],
                    r_flow=avg["r_flow"],
                    r_var=avg["r_var"],
                    r_brake=avg["r_brake"],
                    r_penalty=avg["r_penalty"],
                    hard_brakes=episode_brakes,
                    value_loss=loss_val,
                )

                if episode_reward > best_reward:
                    best_reward = episode_reward
                    agent.save(os.path.join(checkpoint_dir, "model_best.pt"))

                if episode % log_freq == 0:
                    recent_rewards = episode_rewards[-log_freq:]
                    print(f"\nEpisode {episode}/{num_episodes}")
                    print(f" Reward: {episode_reward:.2f} (Avg: {np.mean(recent_rewards):.2f})")
                    print(f" Throughput: {avg['throughput']:.1f} veh/h")
                    print(f" Mean Speed: {avg['speed']:.1f} km/h")
                    print(f" Speed Std: {avg['speed_std']:.2f} km/h")
                    print(f" R(flow/var/brake/pen): {avg['r_flow']:.3f} / {avg['r_var']:.3f} / {avg['r_brake']:.3f} / {avg['r_penalty']:.3f}")
                    if loss_val is not None:
                        print(f" Loss: {loss_val:.4f}")

                if episode % save_freq == 0:
                    agent.save(os.path.join(checkpoint_dir, f"model_ep{episode}.pt"))

            completed = True
        except KeyboardInterrupt:
            print("\n训练被中断")
            agent.save(os.path.join(checkpoint_dir, "model_interrupted.pt"))
    finally:
        env.close()

    # Only a run that reached the last episode may claim the final-episode
    # checkpoint name; an interrupted run already saved model_interrupted.pt.
    if completed:
        agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}.pt"))

    # Plot training curves; skipped when no episode finished (e.g. Ctrl-C
    # during the first episode), since there is nothing to draw.
    if episode_rewards:
        plot_training_curves(
            episode_rewards,
            episode_throughputs,
            episode_mean_speeds,
            episode_hard_brakes,
            save_path=os.path.join(log_dir, "training_curves.png"),
        )

    print("=" * 70)
    print("训练完成!")
    print(f" 最佳奖励: {best_reward:.2f}")
    print(f" 模型目录: {checkpoint_dir}")
    print(f" 日志目录: {log_dir}")
    print("=" * 70)


if __name__ == "__main__":
    train_sumo_dqn()