"""Temporal Credit Assignment MAPPO training script for SUMO + TraCI VSL."""

import copy
import os

import matplotlib
import numpy as np
import yaml
from tqdm import tqdm

# Select the non-interactive backend before the project imports below, in
# case any of them (e.g. utils.plot) pulls in pyplot.
matplotlib.use("Agg")

from agents.tcamappo_agent import TCAMAPPOAgent  # noqa: E402
from envs.edge_vsl_env import SUMOEdgeVSLEnvironment  # noqa: E402
from utils.config import get_agent_config, get_training_config  # noqa: E402
from utils.episode_artifacts import save_training_episode_artifacts  # noqa: E402
from utils.logger import TrainingLogger  # noqa: E402
from utils.plot import plot_training_curves  # noqa: E402
from utils.run_dirs import resolve_run_dirs  # noqa: E402

# Config file loaded when the caller does not pass ``config_path``
# (resolved relative to the current working directory).
DEFAULT_CONFIG_PATH = "config_sumo_vsl.yaml"

# Keys copied out of the dict returned by ``agent.update`` when it is truthy.
_LOSS_KEYS = ("policy_loss", "value_loss", "entropy")

# Per-step sums that are divided by the episode length to get averages.
_AVERAGED_KEYS = (
    "throughput",
    "speed",
    "speed_std",
    "r_flow",
    "r_var",
    "r_brake",
    "r_penalty",
)


def _print_run_header(env, agent_config):
    """Print the environment dimensions and key hyper-parameters of the run."""
    critic_heads = agent_config.get("critic_num_heads", 4)
    critic_layers = agent_config.get("critic_num_layers", 2)

    print("=" * 70)
    print("TCA-MAPPO training - SUMO+TraCI VSL environment")
    print("=" * 70)
    print(f" State dim: {env.state_dim}")
    print(f" Agents: {env.num_edges}")
    print(f" Actions per agent: {env.action_dim}")
    print(f" Episode steps: {env.episode_length}")
    print(f" Control interval: {env.control_interval}s")
    print(f" Hidden dim: {agent_config.get('hidden_dim', 256)}")
    print(f" History window: {agent_config.get('history_window', 6)}")
    print(f" Critic heads/layers: {critic_heads}/{critic_layers}")
    print(f" LR: {agent_config.get('learning_rate', 3e-4)}")
    print(f" Device: {agent_config.get('device', 'cuda')}")
    print()


def _build_agent(env, agent_config, total_episodes):
    """Construct a TCAMAPPOAgent sized to ``env``, filling config defaults."""
    return TCAMAPPOAgent(
        state_dim=env.state_dim,
        num_agents=env.num_edges,
        num_actions=env.action_dim,
        edge_feature_dim=env.features_per_edge,
        hidden_dim=agent_config.get("hidden_dim", 256),
        critic_hidden_dim=agent_config.get("critic_hidden_dim", 256),
        history_window=agent_config.get("history_window", 6),
        critic_num_heads=agent_config.get("critic_num_heads", 4),
        critic_num_layers=agent_config.get("critic_num_layers", 2),
        critic_dropout=agent_config.get("critic_dropout", 0.05),
        learning_rate=agent_config.get("learning_rate", 3e-4),
        gamma=agent_config.get("gamma", 0.99),
        gae_lambda=agent_config.get("gae_lambda", 0.95),
        clip_epsilon=agent_config.get("clip_epsilon", 0.2),
        value_coef=agent_config.get("value_coef", 0.5),
        entropy_coef=agent_config.get("entropy_coef", 0.01),
        max_grad_norm=agent_config.get("max_grad_norm", 0.5),
        ppo_epochs=agent_config.get("ppo_epochs", 4),
        # The config key is "batch_size"; the agent calls it minibatch_size.
        minibatch_size=agent_config.get("batch_size", 15),
        device=agent_config.get("device", "cuda"),
        lr_schedule=agent_config.get("lr_schedule", "cosine"),
        total_episodes=total_episodes,
    )


def _run_episode(env, agent, seed, desc):
    """Roll out one stochastic episode, storing every transition in ``agent``.

    Args:
        env: The SUMO edge-VSL environment.
        agent: The TCA-MAPPO agent that acts and buffers the rollout.
        seed: Seed forwarded to ``env.reset``.
        desc: Label for the progress bar.

    Returns:
        ``(totals, hard_brakes, steps, last_state, done)``. ``totals`` holds
        per-episode sums of the reward and of the per-step ``info`` metrics.
        ``hard_brakes`` is the summed ``info["num_hard_brakes"]``. ``steps``
        is the number of env steps taken. ``last_state`` is the final
        observation. ``done`` is the last done flag.
    """
    state = env.reset(seed=seed)
    agent.reset_episode()

    totals = {"reward": 0.0}
    totals.update(dict.fromkeys(_AVERAGED_KEYS, 0.0))
    hard_brakes = 0
    steps = 0
    done = False

    # Context manager so the bar is closed even if a step raises or the
    # user hits Ctrl-C mid-episode.
    with tqdm(total=env.episode_length, desc=desc, leave=False) as pbar:
        while not done:
            # Snapshot the history stack before this step's
            # update_temporal_context call, so the stored transition pairs
            # with the pre-step context (presumably what select_action
            # conditions on).
            # NOTE(review): relies on a private agent method.
            history_token_stack = agent._get_history_stack().copy()
            action, log_probs, value = agent.select_action(state, deterministic=False)
            next_state, reward, done, info = env.step(action)
            agent.store_transition(
                state, history_token_stack, action, reward, value, log_probs, done
            )
            agent.update_temporal_context(state, action, reward, info)

            totals["reward"] += reward
            totals["throughput"] += info["throughput"]
            totals["speed"] += info["mean_speed_kmh"]
            # Scaled by 3.6, presumably m/s -> km/h (reported as km/h below).
            totals["speed_std"] += info["speed_std"] * 3.6
            totals["r_flow"] += info["r_flow"]
            totals["r_var"] += info["r_var"]
            totals["r_brake"] += info["r_brake"]
            totals["r_penalty"] += info["r_penalty"]
            hard_brakes += info["num_hard_brakes"]

            state = next_state
            steps += 1
            pbar.set_postfix(
                r=f"{totals['reward']:.1f}",
                tp=f"{info['throughput']:.0f}",
                v=f"{info['mean_speed_kmh']:.1f}",
            )
            pbar.update(1)

    return totals, hard_brakes, steps, state, done


def _print_episode_report(episode, num_episodes, episode_reward, recent_rewards, averages, losses):
    """Print the periodic console summary for one finished episode."""
    print(f"\nEpisode {episode}/{num_episodes}")
    print(f" Reward: {episode_reward:.2f} (Avg: {np.mean(recent_rewards):.2f})")
    print(f" Throughput: {averages['throughput']:.1f} veh/h")
    print(f" Mean Speed: {averages['speed']:.1f} km/h")
    print(f" Speed Std: {averages['speed_std']:.2f} km/h")
    print(
        f" R(flow/var/brake/pen): "
        f"{averages['r_flow']:.3f} / {averages['r_var']:.3f} / "
        f"{averages['r_brake']:.3f} / {averages['r_penalty']:.3f}"
    )
    if losses:
        print(f" Policy Loss: {losses['policy_loss']:.4f}")
        print(f" Value Loss: {losses['value_loss']:.4f}")
        print(f" Entropy: {losses['entropy']:.4f}")


def _train_loop(env, agent, logger, log_dir, checkpoint_dir, train_config):
    """Run all training episodes, handling Ctrl-C by saving the partial model.

    Returns:
        ``(history, best_reward, interrupted)``. ``history`` maps metric names
        to per-episode lists, covering completed episodes only.
        ``interrupted`` is True when the loop was stopped by
        ``KeyboardInterrupt``. Any other exception propagates to the caller.
    """
    num_episodes = train_config["num_episodes"]
    # A falsy (0 / None) frequency disables that periodic action instead of
    # raising ZeroDivisionError in the modulo below.
    save_freq = train_config.get("save_freq", 50)
    log_freq = train_config.get("log_freq", 10)
    base_seed = train_config.get("random_seed", 42)

    history = {
        key: []
        for key in (
            "rewards",
            "throughputs",
            "mean_speeds",
            "speed_stds",
            "hard_brakes",
            "policy_losses",
            "value_losses",
            "entropies",
        )
    }
    best_reward = -float("inf")
    interrupted = False

    print("Starting training...\n")
    try:
        for episode in range(1, num_episodes + 1):
            totals, episode_brakes, steps, last_state, done = _run_episode(
                env,
                agent,
                seed=base_seed + episode,
                desc=f"Ep {episode}/{num_episodes}",
            )
            # _run_episode only returns once ``done`` is True, so this is
            # currently always 0.0. The bootstrap branch is kept in case the
            # rollout is ever cut short before ``done``.
            next_value = 0.0 if done else agent.get_value(last_state)
            train_stats = agent.update(next_value)

            denom = max(steps, 1)
            episode_reward = totals["reward"]
            averages = {key: totals[key] / denom for key in _AVERAGED_KEYS}
            losses = {key: train_stats[key] for key in _LOSS_KEYS} if train_stats else {}

            history["rewards"].append(episode_reward)
            history["throughputs"].append(averages["throughput"])
            history["mean_speeds"].append(averages["speed"])
            history["speed_stds"].append(averages["speed_std"])
            history["hard_brakes"].append(episode_brakes)
            if losses:
                history["policy_losses"].append(losses["policy_loss"])
                history["value_losses"].append(losses["value_loss"])
                history["entropies"].append(losses["entropy"])

            # Loss keywords are only passed when the update produced stats.
            logger.log(
                episode,
                episode_reward,
                averages["throughput"],
                averages["speed"],
                speed_std=averages["speed_std"],
                r_flow=averages["r_flow"],
                r_var=averages["r_var"],
                r_brake=averages["r_brake"],
                r_penalty=averages["r_penalty"],
                hard_brakes=episode_brakes,
                **losses,
            )

            episode_summary = {
                "episode": episode,
                "reward": float(episode_reward),
                "avg_throughput": float(averages["throughput"]),
                "avg_mean_speed_kmh": float(averages["speed"]),
                "avg_speed_std_kmh": float(averages["speed_std"]),
                "avg_r_flow": float(averages["r_flow"]),
                "avg_r_var": float(averages["r_var"]),
                "avg_r_brake": float(averages["r_brake"]),
                "avg_r_penalty": float(averages["r_penalty"]),
                "hard_brakes": int(episode_brakes),
            }
            episode_summary.update({key: float(val) for key, val in losses.items()})
            save_training_episode_artifacts(
                log_dir=log_dir,
                episode=episode,
                episode_metrics=env.episode_metrics,
                control_edges=env.control_edges,
                summary=episode_summary,
            )

            if episode_reward > best_reward:
                best_reward = episode_reward
                agent.save(os.path.join(checkpoint_dir, "model_best.pt"))

            if log_freq and episode % log_freq == 0:
                _print_episode_report(
                    episode,
                    num_episodes,
                    episode_reward,
                    history["rewards"][-log_freq:],
                    averages,
                    losses,
                )

            if save_freq and episode % save_freq == 0:
                agent.save(os.path.join(checkpoint_dir, f"model_ep{episode}.pt"))
    except KeyboardInterrupt:
        interrupted = True
        print("\nTraining interrupted, saving current model...")
        agent.save(os.path.join(checkpoint_dir, "model_interrupted.pt"))

    return history, best_reward, interrupted


def train_sumo_tcamappo(log_dir=None, checkpoint_dir=None, run_timestamp=None,
                        config_path=DEFAULT_CONFIG_PATH):
    """Train a TCA-MAPPO agent on the SUMO + TraCI edge-VSL environment.

    Args:
        log_dir: Optional log directory, forwarded to ``resolve_run_dirs``.
        checkpoint_dir: Optional checkpoint directory, forwarded to
            ``resolve_run_dirs``.
        run_timestamp: Optional run timestamp, forwarded to
            ``resolve_run_dirs``.
        config_path: YAML config to load. Defaults to
            ``config_sumo_vsl.yaml`` in the current working directory.

    Side effects:
        Writes ``config.yaml`` and model checkpoints into the checkpoint
        directory, and per-episode artifacts plus ``training_curves.png``
        into the log directory. On normal completion the final model is
        saved as ``model_ep{num_episodes}.pt``. On Ctrl-C the partial model
        is saved only as ``model_interrupted.pt``.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    agent_config = get_agent_config(config, "tcamappo")
    train_config = get_training_config(config)

    _, checkpoint_dir, log_dir = resolve_run_dirs(
        "tcamappo",
        log_dir=log_dir,
        checkpoint_dir=checkpoint_dir,
        run_timestamp=run_timestamp,
    )
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Inject this run's log dir as runtime.output_dir (presumably consumed by
    # the environment) and snapshot the effective config with the checkpoints.
    runtime_config = copy.deepcopy(config)
    runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
    with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
        yaml.dump(runtime_config, f)

    logger = TrainingLogger(log_dir, "tcamappo")
    num_episodes = train_config["num_episodes"]

    env = SUMOEdgeVSLEnvironment(runtime_config)
    # The env is closed even if agent construction or training fails.
    try:
        _print_run_header(env, agent_config)
        agent = _build_agent(env, agent_config, num_episodes)
        history, best_reward, interrupted = _train_loop(
            env, agent, logger, log_dir, checkpoint_dir, train_config
        )
    finally:
        env.close()

    # After Ctrl-C the partial model is already in model_interrupted.pt. Do
    # not also save it under the final-episode name.
    if not interrupted:
        agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}.pt"))

    # Nothing to plot if no episode completed.
    # NOTE(review): entropies are tracked but not passed to the plot, as in
    # the original call.
    if history["rewards"]:
        plot_training_curves(
            history["rewards"],
            history["throughputs"],
            history["mean_speeds"],
            history["speed_stds"],
            history["hard_brakes"],
            history["policy_losses"],
            history["value_losses"],
            save_path=os.path.join(log_dir, "training_curves.png"),
        )

    print("=" * 70)
    print("Training interrupted" if interrupted else "Training complete")
    print(f" Best reward: {best_reward:.2f}")
    print(f" Model dir: {checkpoint_dir}")
    print(f" Log dir: {log_dir}")
    print("=" * 70)