""" 鍩轰簬 SUMO+TraCI 鐨?APPO 璁粌鑴氭湰 浣跨敤寰浠跨湡鐜璁粌 VSL 鎺у埗绛栫暐 """ import os import sys import copy import yaml import numpy as np import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt from datetime import datetime from tqdm import tqdm import torch from envs.edge_vsl_env import SUMOEdgeVSLEnvironment from envs.reward_system import REWARD_COMPONENT_COLUMNS, average_reward_components, init_reward_component_totals from agents.appo_agent import APPOAgent from utils.config import get_agent_config, get_training_config from utils.episode_artifacts import save_training_episode_artifacts from utils.logger import TrainingLogger from utils.plot import plot_training_curves from utils.run_dirs import resolve_run_dirs, write_shared_run_config from utils.seeding import derive_seed, resolve_base_seed, set_global_seed def train_sumo_appo(log_dir=None, checkpoint_dir=None, run_timestamp=None): """Train APPO on the SUMO VSL environment.""" with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f: config = yaml.safe_load(f) agent_config = get_agent_config(config, "appo") train_config = get_training_config(config) base_seed = resolve_base_seed(train_config) set_global_seed(base_seed) start_episode = 1 resolved_run_timestamp, checkpoint_dir, log_dir = resolve_run_dirs( "appo", log_dir=log_dir, checkpoint_dir=checkpoint_dir, run_timestamp=run_timestamp, ) os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True) runtime_config = copy.deepcopy(config) runtime_config.setdefault("runtime", {})["output_dir"] = log_dir runtime_config["runtime"]["evaluation_mode"] = False runtime_config["runtime"]["run_timestamp"] = resolved_run_timestamp write_shared_run_config( runtime_config, log_dir=log_dir, checkpoint_dir=checkpoint_dir, run_timestamp=run_timestamp, ) logger = TrainingLogger(log_dir, "appo") env = SUMOEdgeVSLEnvironment(runtime_config) state_dim = env.state_dim action_dims = [env.action_dim] * env.num_controlled_edges print("=" * 70) print("APPO training - SUMO+TraCI VSL") print("=" * 70) print(f" State dim: {state_dim}") print(f" Action dims: {action_dims}") print(f" Episode length: {env.episode_length}") print(f" Control interval: {env.control_interval}s") print(f" Learning rate: {agent_config.get('learning_rate', 3e-4)}") print(f" Device: {agent_config.get('device', 'cuda')}") print() agent = APPOAgent( state_dim=state_dim, action_dims=action_dims, edge_feature_dim=env.features_per_edge, total_edge_count=env.num_edges, controlled_start_index=env.controlled_edge_start_index, hidden_dim=agent_config.get("hidden_dim", 128), num_heads=agent_config.get("num_heads", 4), num_layers=agent_config.get("num_layers", 2), learning_rate=agent_config.get("learning_rate", 3e-4), gamma=agent_config.get("gamma", 0.99), gae_lambda=agent_config.get("gae_lambda", 0.95), clip_epsilon=agent_config.get("clip_epsilon", 0.2), value_coef=agent_config.get("value_coef", 0.5), entropy_coef=agent_config.get("entropy_coef", 0.02), max_grad_norm=agent_config.get("max_grad_norm", 0.5), ppo_epochs=agent_config.get("ppo_epochs", 10), minibatch_size=agent_config.get("batch_size", 64), device=agent_config.get("device", "cuda"), lr_schedule=agent_config.get("lr_schedule", "cosine"), total_episodes=train_config["num_episodes"] ) # 璁粌鍙傛暟 num_episodes = train_config["num_episodes"] save_freq = train_config.get("save_freq", 50) log_freq = train_config.get("log_freq", 10) # 缁熻鍙橀噺 episode_rewards = [] episode_throughputs = [] episode_mean_speeds = [] episode_speed_variance_norms = [] episode_ttc_risks = [] policy_losses = [] value_losses = [] entropies = [] best_reward = -float("inf") print("Starting training...\n") try: for episode in range(start_episode, num_episodes + 1): # 姣忎釜 episode 浣跨敤涓嶅悓 seed 寮曞叆闅忔満鎬? seed = derive_seed(base_seed, episode) state = env.reset(seed=seed) episode_reward = 0 episode_throughput = 0 episode_speed = 0 episode_speed_variance_norm = 0.0 episode_reward_components = init_reward_component_totals() episode_ttc_risk = 0.0 done = False step = 0 pbar = tqdm( total=env.episode_length, desc=f"Ep {episode}/{num_episodes}", leave=False, ) while not done: action, log_prob, value = agent.select_action(state, deterministic=False) next_state, reward, done, info = env.step(action) agent.store_transition(state, action, reward, value, log_prob, done) episode_reward += reward episode_throughput += info["throughput"] episode_speed += info["mean_speed_kmh"] episode_speed_variance_norm += info["speed_variance_norm"] for column in REWARD_COMPONENT_COLUMNS: episode_reward_components[column] += float(info.get(column, 0.0)) episode_ttc_risk += float(info.get("ttc_risk_rate", 0.0)) state = next_state step += 1 pbar.set_postfix( r=f"{episode_reward:.1f}", tp=f"{info['throughput']:.0f}", v=f"{info['mean_speed_kmh']:.1f}", ) pbar.update(1) pbar.close() # GAE 璁$畻鍜岀瓥鐣ユ洿鏂? if done: next_value = 0.0 else: with torch.no_grad(): next_state_tensor = torch.FloatTensor(next_state).unsqueeze(0).to(agent.device) next_value = agent.policy.get_value(next_state_tensor).item() train_stats = agent.update(next_value) # 璁板綍缁熻 avg_tp = episode_throughput / max(step, 1) avg_speed = episode_speed / max(step, 1) avg_speed_variance_norm = episode_speed_variance_norm / max(step, 1) avg_reward_components = average_reward_components(episode_reward_components, step) episode_rewards.append(episode_reward) episode_throughputs.append(avg_tp) episode_mean_speeds.append(avg_speed) episode_speed_variance_norms.append(avg_speed_variance_norm) episode_ttc_risks.append(episode_ttc_risk) if train_stats: policy_losses.append(train_stats["policy_loss"]) value_losses.append(train_stats["value_loss"]) entropies.append(train_stats["entropy"]) logger.log( episode, episode_reward, avg_tp, avg_speed, speed_variance_norm=avg_speed_variance_norm, reward_components=avg_reward_components, ttc_risk=episode_ttc_risk, policy_loss=train_stats["policy_loss"], value_loss=train_stats["value_loss"], entropy=train_stats["entropy"], ) else: logger.log( episode, episode_reward, avg_tp, avg_speed, speed_variance_norm=avg_speed_variance_norm, reward_components=avg_reward_components, ttc_risk=episode_ttc_risk, ) # 淇濆瓨鏈€浣虫ā鍨? episode_summary = { "episode": episode, "reward": float(episode_reward), "avg_throughput": float(avg_tp), "avg_mean_speed_kmh": float(avg_speed), "avg_speed_variance_norm": float(avg_speed_variance_norm), "ttc_risk": float(episode_ttc_risk), } for column, value in avg_reward_components.items(): episode_summary[f"avg_{column}"] = float(value) if train_stats: episode_summary.update( policy_loss=float(train_stats["policy_loss"]), value_loss=float(train_stats["value_loss"]), entropy=float(train_stats["entropy"]), ) save_training_episode_artifacts( log_dir=log_dir, episode=episode, episode_metrics=env.episode_metrics, control_edges=env.control_edges, summary=episode_summary, ) if episode_reward > best_reward: best_reward = episode_reward agent.save(os.path.join(checkpoint_dir, "model_best.pt")) # 瀹氭湡鏃ュ織 if episode % log_freq == 0: recent_rewards = episode_rewards[-log_freq:] print(f"\nEpisode {episode}/{num_episodes}") print(f" Reward: {episode_reward:.2f} (Avg: {np.mean(recent_rewards):.2f})") print(f" Throughput: {avg_tp:.1f} veh/h") print(f" Mean Speed: {avg_speed:.1f} km/h") print(f" Normalized Speed Variance: {avg_speed_variance_norm:.4f}") print( " Reward Components: " + ", ".join( f"{column}={avg_reward_components[column]:.3f}" for column in REWARD_COMPONENT_COLUMNS ) ) if train_stats: print(f" Policy Loss: {train_stats['policy_loss']:.4f}") print(f" Value Loss: {train_stats['value_loss']:.4f}") print(f" Entropy: {train_stats['entropy']:.4f}") # 瀹氭湡淇濆瓨 if episode % save_freq == 0: agent.save(os.path.join(checkpoint_dir, f"model_ep{episode}.pt")) except KeyboardInterrupt: print("\nTraining interrupted, saving current model...") agent.save(os.path.join(checkpoint_dir, "model_interrupted.pt")) finally: env.close() # 鏈€缁堜繚瀛? agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}.pt")) # 缁樺埗璁粌鏇茬嚎 plot_training_curves( episode_rewards, episode_throughputs, episode_mean_speeds, episode_speed_variance_norms, episode_ttc_risks, policy_losses, value_losses, save_path=os.path.join(log_dir, "training_curves.png"), ) print("=" * 70) print("Training complete") print(f" Best reward: {best_reward:.2f}") print(f" Model dir: {checkpoint_dir}") print(f" Log dir: {log_dir}") print("=" * 70)