""" 鍩轰簬 SUMO+TraCI 鐨?TD3 璁粌鑴氭湰 浣跨敤 Stable-Baselines3 鐨?TD3 绠楁硶 """ import os import copy import yaml import numpy as np import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt from datetime import datetime from tqdm import tqdm from envs.edge_vsl_env import SUMOEdgeVSLEnvironment from envs.reward_system import REWARD_COMPONENT_COLUMNS, average_reward_components, init_reward_component_totals from agents.td3_agent import TD3Agent from utils.config import get_agent_config, get_training_config from utils.episode_artifacts import save_training_episode_artifacts from utils.logger import TrainingLogger from utils.plot import plot_training_curves from utils.run_dirs import resolve_run_dirs, write_shared_run_config from utils.seeding import derive_seed, resolve_base_seed, set_global_seed def train_sumo_td3( log_dir=None, checkpoint_dir=None, run_timestamp=None, model_name: str = "td3", config_key: str = "td3", display_name: str = "TD3", agent_class=TD3Agent, ): """Train TD3 on the SUMO VSL environment.""" with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f: config = yaml.safe_load(f) agent_config = get_agent_config(config, config_key) train_config = get_training_config(config) base_seed = resolve_base_seed(train_config) set_global_seed(base_seed) resolved_run_timestamp, checkpoint_dir, log_dir = resolve_run_dirs( model_name, log_dir=log_dir, checkpoint_dir=checkpoint_dir, run_timestamp=run_timestamp, ) os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True) runtime_config = copy.deepcopy(config) runtime_config.setdefault("runtime", {})["output_dir"] = log_dir runtime_config["runtime"]["evaluation_mode"] = False runtime_config["runtime"]["run_timestamp"] = resolved_run_timestamp write_shared_run_config( runtime_config, log_dir=log_dir, checkpoint_dir=checkpoint_dir, run_timestamp=run_timestamp, ) logger = TrainingLogger(log_dir, model_name) env = SUMOEdgeVSLEnvironment(runtime_config) state_dim = env.state_dim action_dims = [env.action_dim] * env.num_controlled_edges print("=" * 70) print(f"{display_name} training - SUMO+TraCI VSL") print("=" * 70) print(f" State dim: {state_dim}") print(f" Action dims: {action_dims}") print(f" Episode length: {env.episode_length}") print(f" Control interval: {env.control_interval}s") print(f" Learning rate: {agent_config.get('learning_rate', 3e-4)}") print(f" Device: {agent_config.get('device', 'cuda')}") print() common_kwargs = dict( state_dim=state_dim, action_dims=action_dims, learning_rate=agent_config.get("learning_rate", 3e-4), buffer_size=agent_config.get("buffer_size", 100000), learning_starts=agent_config.get("learning_starts", 1000), batch_size=agent_config.get("batch_size", 256), tau=agent_config.get("tau", 0.005), gamma=agent_config.get("gamma", 0.99), exploration_sigma=agent_config.get("exploration_sigma", 0.1), device=agent_config.get("device", "cuda"), actor_hidden_dims=agent_config.get("actor_hidden_dims"), critic_hidden_dims=agent_config.get("critic_hidden_dims"), activation_fn=agent_config.get("activation_fn", "relu"), ) if "policy_delay" in agent_config: common_kwargs["policy_delay"] = agent_config.get("policy_delay", 2) if config_key == "sctd3": common_kwargs.update( edge_feature_dim=env.features_per_edge, total_edge_count=env.num_edges, controlled_start_index=env.controlled_edge_start_index, extractor_feature_dim=agent_config.get("extractor_feature_dim", 128), extractor_edge_hidden_dim=agent_config.get("extractor_edge_hidden_dim", 16), extractor_global_hidden_dim=agent_config.get("extractor_global_hidden_dim", 32), extractor_spatial_blocks=agent_config.get("extractor_spatial_blocks", 1), extractor_kernel_size=agent_config.get("extractor_kernel_size", 3), ) common_kwargs["seed"] = base_seed agent = agent_class(**common_kwargs) num_episodes = train_config["num_episodes"] save_freq = train_config.get("save_freq", 50) log_freq = train_config.get("log_freq", 10) episode_rewards = [] episode_throughputs = [] episode_mean_speeds = [] episode_speed_variance_norms = [] episode_ttc_risks = [] best_reward = -float("inf") print("Starting training...\n") try: for episode in range(1, num_episodes + 1): seed = derive_seed(base_seed, episode) state = env.reset(seed=seed) episode_reward = 0 episode_throughput = 0 episode_speed = 0 episode_speed_variance_norm = 0.0 episode_reward_components = init_reward_component_totals() episode_ttc_risk = 0.0 done = False step = 0 pbar = tqdm( total=env.episode_length, desc=f"Ep {episode}/{num_episodes}", leave=False, ) while not done: action, _, _ = agent.select_action(state, deterministic=False) next_state, reward, done, info = env.step(action) agent.store_transition(state, action, reward, next_state, done) agent.update() episode_reward += reward episode_throughput += info["throughput"] episode_speed += info["mean_speed_kmh"] episode_speed_variance_norm += info["speed_variance_norm"] for column in REWARD_COMPONENT_COLUMNS: episode_reward_components[column] += float(info.get(column, 0.0)) episode_ttc_risk += float(info.get("ttc_risk_rate", 0.0)) state = next_state step += 1 pbar.set_postfix( r=f"{episode_reward:.1f}", tp=f"{info['throughput']:.0f}", v=f"{info['mean_speed_kmh']:.1f}", ) pbar.update(1) pbar.close() avg_tp = episode_throughput / max(step, 1) avg_speed = episode_speed / max(step, 1) avg_speed_variance_norm = episode_speed_variance_norm / max(step, 1) avg_reward_components = average_reward_components(episode_reward_components, step) episode_rewards.append(episode_reward) episode_throughputs.append(avg_tp) episode_mean_speeds.append(avg_speed) episode_speed_variance_norms.append(avg_speed_variance_norm) episode_ttc_risks.append(episode_ttc_risk) logger.log( episode, episode_reward, avg_tp, avg_speed, speed_variance_norm=avg_speed_variance_norm, reward_components=avg_reward_components, ttc_risk=episode_ttc_risk, ) episode_summary = { "episode": episode, "reward": float(episode_reward), "avg_throughput": float(avg_tp), "avg_mean_speed_kmh": float(avg_speed), "avg_speed_variance_norm": float(avg_speed_variance_norm), "ttc_risk": float(episode_ttc_risk), } for column, value in avg_reward_components.items(): episode_summary[f"avg_{column}"] = float(value) save_training_episode_artifacts( log_dir=log_dir, episode=episode, episode_metrics=env.episode_metrics, control_edges=env.control_edges, summary=episode_summary, ) if episode_reward > best_reward: best_reward = episode_reward agent.save(os.path.join(checkpoint_dir, "model_best")) if episode % log_freq == 0: recent_rewards = episode_rewards[-log_freq:] print(f"\nEpisode {episode}/{num_episodes}") print(f" Reward: {episode_reward:.2f} (Avg: {np.mean(recent_rewards):.2f})") print(f" Throughput: {avg_tp:.1f} veh/h") print(f" Mean Speed: {avg_speed:.1f} km/h") print(f" Normalized Speed Variance: {avg_speed_variance_norm:.4f}") print( " Reward Components: " + ", ".join( f"{column}={avg_reward_components[column]:.3f}" for column in REWARD_COMPONENT_COLUMNS ) ) if episode % save_freq == 0: agent.save(os.path.join(checkpoint_dir, f"model_ep{episode}")) except KeyboardInterrupt: print("\nTraining interrupted, saving current model...") agent.save(os.path.join(checkpoint_dir, "model_interrupted")) finally: env.close() agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}")) plot_training_curves( episode_rewards, episode_throughputs, episode_mean_speeds, episode_speed_variance_norms, episode_ttc_risks, save_path=os.path.join(log_dir, "training_curves.png"), ) print("=" * 70) print("Training complete") print(f" Best reward: {best_reward:.2f}") print(f" Model dir: {checkpoint_dir}") print(f" Log dir: {log_dir}") print("=" * 70)