"""Training script for DQN-based speed limit control."""

import os
import random

import numpy as np
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt

from utils import load_config, create_directories
from environment import TrafficEnvironment
from parallel_env import ParallelEnvironment
from dqn_agent import DQNAgent


def set_random_seed(seed: int) -> None:
    """Seed Python, NumPy and PyTorch RNGs for reproducibility.

    When CUDA is available, also seeds every CUDA device and switches cuDNN
    to deterministic (non-benchmarking) mode.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def _build_agent(config, env):
    """Construct a DQNAgent from the ``agent`` config section and the env's dims."""
    agent_cfg = config["agent"]
    return DQNAgent(
        state_dim=env.state_dim,
        action_dim=env.action_dim,
        hidden_layers=agent_cfg["hidden_layers"],
        learning_rate=agent_cfg["learning_rate"],
        gamma=agent_cfg["gamma"],
        epsilon_start=agent_cfg["epsilon_start"],
        epsilon_end=agent_cfg["epsilon_end"],
        epsilon_decay=agent_cfg["epsilon_decay"],
        buffer_size=agent_cfg["buffer_size"],
        batch_size=agent_cfg["batch_size"],
        target_update_freq=agent_cfg["target_update_freq"],
        device=agent_cfg["device"],
    )


def _maybe_train(agent, step_count, train_freq):
    """Run one agent update on every ``train_freq``-th step.

    Returns the loss when an update ran and reported a positive loss,
    otherwise None (non-training step, or no loss to record).
    """
    if step_count % train_freq != 0:
        return None
    loss = agent.train()
    # Only positive losses are averaged. A None return (if agent.train ever
    # yields one, e.g. before the buffer is warm -- unverified) is skipped
    # instead of raising TypeError on the comparison.
    if loss is not None and loss > 0:
        return loss
    return None


def _run_vectorized_episode(env, agent, num_envs, train_freq):
    """Roll out one episode over all parallel environments.

    Returns ``(mean total reward across envs, summed loss, loss count)``.
    """
    states = env.reset()
    total_rewards = np.zeros(num_envs)
    done_flags = np.zeros(num_envs, dtype=bool)
    loss_sum = 0.0
    loss_count = 0
    step_count = 0

    while not np.all(done_flags):
        # Select actions for all environments.
        actions = np.array(
            [agent.select_action(states[i], training=True) for i in range(num_envs)]
        )
        next_states, rewards, dones, _ = env.step(actions)

        for i in range(num_envs):
            if done_flags[i]:
                # This env already finished; ignore its outputs until all are done.
                continue
            agent.store_transition(
                states[i], actions[i], rewards[i], next_states[i], dones[i]
            )
            total_rewards[i] += rewards[i]
            if dones[i]:
                done_flags[i] = True

        # One "step" here advances every env, so train_freq counts vector steps.
        step_count += 1
        loss = _maybe_train(agent, step_count, train_freq)
        if loss is not None:
            loss_sum += loss
            loss_count += 1

        states = next_states

    return np.mean(total_rewards), loss_sum, loss_count


def _run_single_episode(env, agent, train_freq):
    """Roll out one episode in a single environment.

    Returns ``(total reward, summed loss, loss count)``.
    """
    state = env.reset()
    episode_reward = 0
    loss_sum = 0.0
    loss_count = 0
    step_count = 0
    done = False

    while not done:
        action = agent.select_action(state, training=True)
        next_state, reward, done, _ = env.step(action)
        agent.store_transition(state, action, reward, next_state, done)

        step_count += 1
        loss = _maybe_train(agent, step_count, train_freq)
        if loss is not None:
            loss_sum += loss
            loss_count += 1

        episode_reward += reward
        state = next_state

    return episode_reward, loss_sum, loss_count


def _safe_mean(values):
    """Mean of *values*, or NaN when empty (avoids NumPy's empty-slice warning)."""
    values = list(values)
    return float(np.mean(values)) if values else float("nan")


def _average_throughput(env, is_vectorized):
    """Average per-step ``throughput`` over the episode just finished.

    For parallel envs, averages each env's mean and skips envs with no metrics.
    """
    if is_vectorized:
        per_env_means = [
            _safe_mean(m["throughput"] for m in metrics)
            for metrics in env.get_episode_metrics()
            if len(metrics) > 0
        ]
        return _safe_mean(per_env_means)
    return _safe_mean(m["throughput"] for m in env.episode_metrics)


def _close_env(env):
    """Release env resources (e.g. worker processes) if the env exposes close()."""
    close = getattr(env, "close", None)
    if callable(close):
        close()


def train(config_path: str = "config.yaml"):
    """Train a DQN agent using the settings in *config_path*.

    Runs ``training.num_episodes`` episodes, over several parallel
    environments when ``training.num_parallel_envs > 1``. The agent is
    checkpointed every ``save_freq`` episodes and whenever the episode reward
    improves; a final model is saved at the end. Reward/loss/throughput plots
    are written to ``training.log_dir``. The environment is closed (when it
    supports ``close()``) even if training fails.
    """
    config = load_config(config_path)
    create_directories(config)
    train_cfg = config["training"]

    # Set random seed for reproducibility
    random_seed = train_cfg.get("random_seed", 42)
    set_random_seed(random_seed)
    print(f"Random seed set to: {random_seed}")

    # Create environment (parallel or single)
    num_parallel_envs = train_cfg.get("num_parallel_envs", 1)
    is_vectorized = num_parallel_envs > 1
    if is_vectorized:
        env = ParallelEnvironment(config, num_envs=num_parallel_envs)
        print(f"Using {num_parallel_envs} parallel environments with multiprocessing")
    else:
        env = TrafficEnvironment(config)
        print("Using single environment")

    try:
        agent = _build_agent(config, env)

        num_episodes = train_cfg["num_episodes"]
        save_freq = train_cfg["save_freq"]
        log_freq = train_cfg["log_freq"]
        checkpoint_dir = train_cfg["checkpoint_dir"]
        train_freq = train_cfg.get("train_freq", 1)  # Train every N steps

        episode_rewards = []
        episode_losses = []
        episode_throughputs = []

        # Track best model
        best_reward = float("-inf")
        best_model_path = os.path.join(checkpoint_dir, "model_best.pt")

        print(f"Starting training for {num_episodes} episodes...")
        print(f"State dim: {env.state_dim}, Action dim: {env.action_dim}")
        print(f"Device: {agent.device}")
        print(f"Train frequency: every {train_freq} steps")

        for episode in tqdm(range(num_episodes), desc="Training"):
            if is_vectorized:
                episode_reward, loss_sum, loss_count = _run_vectorized_episode(
                    env, agent, num_parallel_envs, train_freq
                )
            else:
                episode_reward, loss_sum, loss_count = _run_single_episode(
                    env, agent, train_freq
                )
            agent.end_episode()

            episode_rewards.append(episode_reward)
            episode_losses.append(loss_sum / max(1, loss_count))
            episode_throughputs.append(_average_throughput(env, is_vectorized))

            # Save best model based on episode reward
            if episode_reward > best_reward:
                best_reward = episode_reward
                agent.save(best_model_path)
                print(f"\n*** New best model saved! Episode {episode + 1}, Reward: {episode_reward:.2f} ***")

            if (episode + 1) % log_freq == 0:
                avg_reward = np.mean(episode_rewards[-log_freq:])
                avg_loss = np.mean(episode_losses[-log_freq:])
                avg_tp = np.mean(episode_throughputs[-log_freq:])
                print(
                    f"\nEpisode {episode + 1}/{num_episodes} | "
                    f"Avg Reward: {avg_reward:.2f} | "
                    f"Avg Loss: {avg_loss:.4f} | "
                    f"Avg Throughput: {avg_tp:.2f} | "
                    f"Epsilon: {agent.epsilon:.3f}"
                )

            if (episode + 1) % save_freq == 0:
                model_path = os.path.join(checkpoint_dir, f"model_episode_{episode + 1}.pt")
                agent.save(model_path)
                print(f"Model saved to {model_path}")

        final_model_path = os.path.join(checkpoint_dir, "model_final.pt")
        agent.save(final_model_path)
        print("\nTraining completed!")
        print(f"Final model saved to {final_model_path}")
        print(f"Best model saved to {best_model_path} (Best Reward: {best_reward:.2f})")
    finally:
        # Shut down env resources (worker processes for the parallel env) even on error.
        _close_env(env)

    plot_training_results(
        episode_rewards,
        episode_losses,
        episode_throughputs,
        train_cfg["log_dir"],
    )


def plot_training_results(rewards, losses, throughputs, log_dir):
    """Plot per-episode reward, loss and throughput and save them as one PNG.

    The figure is written to ``<log_dir>/training_results.png``; *log_dir* is
    created if it is non-empty and missing. The figure is always closed.
    """
    panels = (
        (rewards, "Total Reward", "Episode Rewards"),
        (losses, "Average Loss", "Training Loss"),
        (throughputs, "Average Throughput (veh/h)", "Traffic Throughput"),
    )
    fig, axes = plt.subplots(len(panels), 1, figsize=(10, 12))
    try:
        for ax, (series, ylabel, title) in zip(axes, panels):
            ax.plot(series)
            ax.set_xlabel("Episode")
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.grid(True)
        fig.tight_layout()

        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        plot_path = os.path.join(log_dir, "training_results.png")
        fig.savefig(plot_path)
    finally:
        # Release the figure even if saving fails, so repeated calls don't leak figures.
        plt.close(fig)
    print(f"Training plots saved to {plot_path}")


if __name__ == "__main__":
    train()