"""Shared training loop for value-based VSL agents.""" from __future__ import annotations import copy import inspect import os from typing import Callable import matplotlib import numpy as np import yaml from tqdm import tqdm from envs.edge_vsl_env import SUMOEdgeVSLEnvironment from envs.reward_system import REWARD_COMPONENT_COLUMNS, average_reward_components, init_reward_component_totals from utils.config import get_agent_config, get_training_config from utils.episode_artifacts import save_training_episode_artifacts from utils.logger import TrainingLogger from utils.plot import plot_training_curves from utils.run_dirs import resolve_run_dirs, write_shared_run_config from utils.seeding import derive_seed, resolve_base_seed, set_global_seed matplotlib.use("Agg") def _build_value_based_agent(agent_builder: Callable[..., object], env, agent_config: dict): candidate_kwargs = { "state_dim": env.state_dim, "num_edges": env.num_controlled_edges, "num_actions_per_edge": env.action_dim, "hidden_dim": agent_config.get("hidden_dim", 256), "mixing_hidden_dim": agent_config.get( "mixing_hidden_dim", agent_config.get("hidden_dim", 256), ), "learning_rate": agent_config.get("learning_rate", 1e-3), "gamma": agent_config.get("gamma", 0.99), "epsilon_start": agent_config.get("epsilon_start", 1.0), "epsilon_end": agent_config.get("epsilon_end", 0.01), "epsilon_decay": agent_config.get("epsilon_decay", 200), "buffer_size": agent_config.get("buffer_size", 10000), "batch_size": agent_config.get("batch_size", 64), "target_update": agent_config.get("target_update", 10), "device": agent_config.get("device", "cuda"), "edge_feature_dim": env.features_per_edge, "time_feature_dim": 3, "total_edge_count": env.num_edges, "controlled_start_index": env.controlled_edge_start_index, "num_corridor_blocks": agent_config.get("num_corridor_blocks", 2), "corridor_kernel_size": agent_config.get("corridor_kernel_size", 5), "corridor_dropout": agent_config.get("corridor_dropout", 0.05), } accepted = inspect.signature(agent_builder).parameters filtered_kwargs = { key: value for key, value in candidate_kwargs.items() if key in accepted } return agent_builder(**filtered_kwargs) def train_sumo_value_based( model_key: str, model_label: str, agent_builder: Callable[..., object], *, log_dir=None, checkpoint_dir=None, run_timestamp=None, ): with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f: config = yaml.safe_load(f) agent_config = get_agent_config(config, model_key) train_config = get_training_config(config) base_seed = resolve_base_seed(train_config) set_global_seed(base_seed) resolved_run_timestamp, checkpoint_dir, log_dir = resolve_run_dirs( model_key, log_dir=log_dir, checkpoint_dir=checkpoint_dir, run_timestamp=run_timestamp, ) os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True) runtime_config = copy.deepcopy(config) runtime_config.setdefault("runtime", {})["output_dir"] = log_dir runtime_config["runtime"]["evaluation_mode"] = False runtime_config["runtime"]["run_timestamp"] = resolved_run_timestamp write_shared_run_config( runtime_config, log_dir=log_dir, checkpoint_dir=checkpoint_dir, run_timestamp=run_timestamp, ) logger = TrainingLogger(log_dir, model_key) env = SUMOEdgeVSLEnvironment(runtime_config) state_dim = env.state_dim num_edges = env.num_controlled_edges num_actions_per_edge = env.action_dim print("=" * 70) print(f"{model_label} training - SUMO VSL environment") print("=" * 70) print(f" State dim: {state_dim}") print(f" Controlled edges: {num_edges}") print(f" Actions per edge: {num_actions_per_edge}") print(f" Episode steps: {env.episode_length}") print(f" Control interval: {env.control_interval}s") print(f" Hidden dim: {agent_config.get('hidden_dim', 256)}") print(f" LR: {agent_config.get('learning_rate', 1e-3)}") print(f" Device: {agent_config.get('device', 'cuda')}") print() print(f" Global seed: {base_seed if base_seed is not None else 'None (random)'}") print() agent = _build_value_based_agent(agent_builder, env, agent_config) num_episodes = train_config["num_episodes"] save_freq = train_config.get("save_freq", 50) log_freq = train_config.get("log_freq", 10) episode_rewards = [] episode_throughputs = [] episode_mean_speeds = [] episode_speed_variance_norms = [] episode_ttc_risks = [] value_losses = [] best_reward = -float("inf") print("Starting training...\n") try: for episode in range(1, num_episodes + 1): seed = derive_seed(base_seed, episode) state = env.reset(seed=seed) episode_reward = 0.0 episode_throughput = 0.0 episode_speed = 0.0 episode_speed_variance_norm = 0.0 episode_reward_components = init_reward_component_totals() episode_ttc_risk = 0.0 done = False step = 0 pbar = tqdm(total=env.episode_length, desc=f"Ep {episode}/{num_episodes}", leave=False) while not done: action, _, _ = agent.select_action(state, deterministic=False) next_state, reward, done, info = env.step(action) agent.store_transition(state, action, reward, next_state, done) train_stats = agent.update() if train_stats: value_losses.append(train_stats["value_loss"]) episode_reward += reward episode_throughput += info["throughput"] episode_speed += info["mean_speed_kmh"] episode_speed_variance_norm += info["speed_variance_norm"] for column in REWARD_COMPONENT_COLUMNS: episode_reward_components[column] += float(info.get(column, 0.0)) episode_ttc_risk += float(info.get("ttc_risk_rate", 0.0)) state = next_state step += 1 pbar.set_postfix( r=f"{episode_reward:.1f}", tp=f"{info['throughput']:.0f}", v=f"{info['mean_speed_kmh']:.1f}", ) pbar.update(1) pbar.close() avg_tp = episode_throughput / max(step, 1) avg_speed = episode_speed / max(step, 1) avg_speed_variance_norm = episode_speed_variance_norm / max(step, 1) avg_reward_components = average_reward_components(episode_reward_components, step) episode_rewards.append(episode_reward) episode_throughputs.append(avg_tp) episode_mean_speeds.append(avg_speed) episode_speed_variance_norms.append(avg_speed_variance_norm) episode_ttc_risks.append(episode_ttc_risk) loss_val = np.mean(value_losses[-100:]) if value_losses else None logger.log( episode, episode_reward, avg_tp, avg_speed, speed_variance_norm=avg_speed_variance_norm, reward_components=avg_reward_components, ttc_risk=episode_ttc_risk, value_loss=loss_val, ) episode_summary = { "episode": episode, "reward": float(episode_reward), "avg_throughput": float(avg_tp), "avg_mean_speed_kmh": float(avg_speed), "avg_speed_variance_norm": float(avg_speed_variance_norm), "ttc_risk": float(episode_ttc_risk), "value_loss": float(loss_val) if loss_val is not None else None, } for column, value in avg_reward_components.items(): episode_summary[f"avg_{column}"] = float(value) save_training_episode_artifacts( log_dir=log_dir, episode=episode, episode_metrics=env.episode_metrics, control_edges=env.control_edges, summary=episode_summary, ) if episode_reward > best_reward: best_reward = episode_reward agent.save(os.path.join(checkpoint_dir, "model_best.pt")) if episode % log_freq == 0: recent_rewards = episode_rewards[-log_freq:] print(f"\nEpisode {episode}/{num_episodes}") print(f" Reward: {episode_reward:.2f} (Avg: {np.mean(recent_rewards):.2f})") print(f" Throughput: {avg_tp:.1f} veh/h") print(f" Mean Speed: {avg_speed:.1f} km/h") print(f" Normalized Speed Variance: {avg_speed_variance_norm:.4f}") print( " Reward Components: " + ", ".join( f"{column}={avg_reward_components[column]:.3f}" for column in REWARD_COMPONENT_COLUMNS ) ) if loss_val is not None: print(f" Value Loss: {loss_val:.4f}") if episode % save_freq == 0: agent.save(os.path.join(checkpoint_dir, f"model_ep{episode}.pt")) except KeyboardInterrupt: print("\nTraining interrupted, saving current model...") agent.save(os.path.join(checkpoint_dir, "model_interrupted.pt")) finally: env.close() agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}.pt")) plot_training_curves( episode_rewards, episode_throughputs, episode_mean_speeds, episode_speed_variance_norms, episode_ttc_risks, save_path=os.path.join(log_dir, "training_curves.png"), ) print("=" * 70) print("Training complete") print(f" Best reward: {best_reward:.2f}") print(f" Model dir: {checkpoint_dir}") print(f" Log dir: {log_dir}") print("=" * 70)