# ctm-dqn/train.py
#
# 232 lines
# 8.0 KiB
# Python

"""
Training script for DQN-based speed limit control.
"""
import os
import random
import numpy as np
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from utils import load_config, create_directories
from environment import TrafficEnvironment
from parallel_env import ParallelEnvironment
from dqn_agent import DQNAgent
def set_random_seed(seed: int) -> None:
    """Seed every random number generator used during training.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch. When CUDA is available, the CUDA generators are seeded as well
    and cuDNN is switched to deterministic mode with autotuning disabled.

    Args:
        seed: Integer seed. NumPy's legacy seeding only accepts values in
            ``[0, 2**32 - 1]``, so the value handed to NumPy is reduced
            modulo ``2**32``; seeds already inside that range are passed
            through unchanged.
    """
    random.seed(seed)
    # np.random.seed raises ValueError for negative or >32-bit seeds; fold the
    # seed into the accepted range instead of failing half-way through seeding.
    np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Trade cuDNN autotuning speed for run-to-run reproducibility.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
def _require_positive(name, value):
    """Raise ``ValueError`` unless the training setting *value* is at least 1."""
    if value < 1:
        raise ValueError(
            f"config['training']['{name}'] must be a positive integer, got {value!r}"
        )


def _train_if_due(agent, step_count, train_freq):
    """Call ``agent.train()`` on every ``train_freq``-th step.

    Returns the loss when an update ran and reported a positive loss,
    otherwise ``None``.
    """
    if step_count % train_freq != 0:
        return None
    loss = agent.train()
    # Only positive losses are counted towards the episode average.
    # NOTE(review): the None check is defensive; DQNAgent.train's return
    # contract is not visible from this file.
    if loss is None or not loss > 0:
        return None
    return loss


def _run_vectorized_episode(env, agent, states, num_envs, train_freq):
    """Run one episode on a vectorised env; return (mean reward, mean loss)."""
    rewards_per_env = np.zeros(num_envs)
    finished = np.zeros(num_envs, dtype=bool)
    loss_sum = 0
    loss_count = 0
    step_count = 0

    while not np.all(finished):
        # Actions are selected for every environment, finished ones included,
        # because env.step() receives one action per environment.
        actions = np.array(
            [agent.select_action(states[i], training=True) for i in range(num_envs)]
        )
        next_states, rewards, dones, _ = env.step(actions)

        for i in range(num_envs):
            if finished[i]:
                # Transitions produced after an env reported done are ignored.
                continue
            agent.store_transition(
                states[i], actions[i], rewards[i], next_states[i], dones[i]
            )
            rewards_per_env[i] += rewards[i]
            if dones[i]:
                finished[i] = True

        step_count += 1
        loss = _train_if_due(agent, step_count, train_freq)
        if loss is not None:
            loss_sum += loss
            loss_count += 1

        states = next_states

    return np.mean(rewards_per_env), loss_sum / max(1, loss_count)


def _run_single_episode(env, agent, state, train_freq):
    """Run one episode on a single env; return (total reward, mean loss)."""
    total_reward = 0
    loss_sum = 0
    loss_count = 0
    step_count = 0

    while True:
        action = agent.select_action(state, training=True)
        next_state, reward, done, _ = env.step(action)
        agent.store_transition(state, action, reward, next_state, done)

        step_count += 1
        loss = _train_if_due(agent, step_count, train_freq)
        if loss is not None:
            loss_sum += loss
            loss_count += 1

        total_reward += reward
        state = next_state
        if done:
            break

    return total_reward, loss_sum / max(1, loss_count)


def _average_throughput(env, is_vectorized):
    """Return the episode's mean ``"throughput"`` metric, or NaN if none exist."""
    if is_vectorized:
        # Average within each environment first, skipping environments that
        # recorded no metrics, then average across environments.
        values = [
            np.mean([m["throughput"] for m in metrics])
            for metrics in env.get_episode_metrics()
            if len(metrics) > 0
        ]
    else:
        values = [m["throughput"] for m in env.episode_metrics]
    # np.mean([]) would also yield NaN, but with a RuntimeWarning.
    return float(np.mean(values)) if values else float("nan")


def _close_env(env):
    """Release the environment's resources if it exposes a ``close()`` method."""
    # NOTE(review): whether the environment classes define close() is not
    # visible here, so this is a no-op when they do not.
    close = getattr(env, "close", None)
    if callable(close):
        close()


def train(config_path: str = "config.yaml"):
    """Train a DQN agent for speed limit control.

    Loads the configuration, seeds the random number generators, builds the
    environment (``ParallelEnvironment`` when ``training.num_parallel_envs``
    is greater than 1, otherwise ``TrafficEnvironment``) and a ``DQNAgent``,
    then runs ``training.num_episodes`` episodes.

    During training the function writes into ``training.checkpoint_dir``:
    ``model_best.pt`` whenever an episode reward exceeds the best so far,
    ``model_episode_<n>.pt`` every ``training.save_freq`` episodes, and
    ``model_final.pt`` at the end. Progress is printed every
    ``training.log_freq`` episodes, and the reward, loss and throughput
    curves are plotted into ``training.log_dir``.

    Args:
        config_path: Path of the configuration file passed to ``load_config``.

    Raises:
        ValueError: If ``save_freq``, ``log_freq`` or ``train_freq`` is
            smaller than 1.
    """
    config = load_config(config_path)
    create_directories(config)
    training_cfg = config["training"]
    agent_cfg = config["agent"]

    random_seed = training_cfg.get("random_seed", 42)
    set_random_seed(random_seed)
    print(f"Random seed set to: {random_seed}")

    num_episodes = training_cfg["num_episodes"]
    save_freq = training_cfg["save_freq"]
    log_freq = training_cfg["log_freq"]
    checkpoint_dir = training_cfg["checkpoint_dir"]
    train_freq = training_cfg.get("train_freq", 1)  # Train every N steps

    # All three values are used as modulo divisors below; reject bad values
    # before any environment or worker process is created.
    for name, value in (
        ("save_freq", save_freq),
        ("log_freq", log_freq),
        ("train_freq", train_freq),
    ):
        _require_positive(name, value)

    num_parallel_envs = training_cfg.get("num_parallel_envs", 1)
    is_vectorized = num_parallel_envs > 1
    if is_vectorized:
        env = ParallelEnvironment(config, num_envs=num_parallel_envs)
        print(f"Using {num_parallel_envs} parallel environments with multiprocessing")
    else:
        env = TrafficEnvironment(config)
        print("Using single environment")

    episode_rewards = []
    episode_losses = []
    episode_throughputs = []

    # Track best model
    best_reward = float("-inf")
    best_model_path = os.path.join(checkpoint_dir, "model_best.pt")

    try:
        agent = DQNAgent(
            state_dim=env.state_dim,
            action_dim=env.action_dim,
            hidden_layers=agent_cfg["hidden_layers"],
            learning_rate=agent_cfg["learning_rate"],
            gamma=agent_cfg["gamma"],
            epsilon_start=agent_cfg["epsilon_start"],
            epsilon_end=agent_cfg["epsilon_end"],
            epsilon_decay=agent_cfg["epsilon_decay"],
            buffer_size=agent_cfg["buffer_size"],
            batch_size=agent_cfg["batch_size"],
            target_update_freq=agent_cfg["target_update_freq"],
            device=agent_cfg["device"],
        )

        print(f"Starting training for {num_episodes} episodes...")
        print(f"State dim: {env.state_dim}, Action dim: {env.action_dim}")
        print(f"Device: {agent.device}")
        print(f"Train frequency: every {train_freq} steps")

        for episode in tqdm(range(num_episodes), desc="Training"):
            states = env.reset()
            if is_vectorized:
                episode_reward, mean_loss = _run_vectorized_episode(
                    env, agent, states, num_parallel_envs, train_freq
                )
            else:
                episode_reward, mean_loss = _run_single_episode(
                    env, agent, states, train_freq
                )

            agent.end_episode()
            episode_rewards.append(episode_reward)
            episode_losses.append(mean_loss)
            episode_throughputs.append(_average_throughput(env, is_vectorized))

            # Save best model based on episode reward
            if episode_reward > best_reward:
                best_reward = episode_reward
                agent.save(best_model_path)
                print(f"\n*** New best model saved! Episode {episode + 1}, Reward: {episode_reward:.2f} ***")

            if (episode + 1) % log_freq == 0:
                avg_reward = np.mean(episode_rewards[-log_freq:])
                avg_loss = np.mean(episode_losses[-log_freq:])
                avg_tp = np.mean(episode_throughputs[-log_freq:])
                print(
                    f"\nEpisode {episode + 1}/{num_episodes} | "
                    f"Avg Reward: {avg_reward:.2f} | "
                    f"Avg Loss: {avg_loss:.4f} | "
                    f"Avg Throughput: {avg_tp:.2f} | "
                    f"Epsilon: {agent.epsilon:.3f}"
                )

            if (episode + 1) % save_freq == 0:
                model_path = os.path.join(
                    checkpoint_dir, f"model_episode_{episode + 1}.pt"
                )
                agent.save(model_path)
                print(f"Model saved to {model_path}")
    finally:
        # Release the environment even if training is interrupted or fails.
        _close_env(env)

    final_model_path = os.path.join(checkpoint_dir, "model_final.pt")
    agent.save(final_model_path)
    print("\nTraining completed!")
    print(f"Final model saved to {final_model_path}")
    print(f"Best model saved to {best_model_path} (Best Reward: {best_reward:.2f})")

    plot_training_results(
        episode_rewards, episode_losses, episode_throughputs, training_cfg["log_dir"]
    )
def plot_training_results(rewards, losses, throughputs, log_dir):
    """Plot the per-episode training curves and save them as a PNG.

    Draws three stacked line plots (rewards, losses, throughputs) against the
    episode index and writes them to ``<log_dir>/training_results.png``.

    Args:
        rewards: Sequence with one reward value per episode.
        losses: Sequence with one average-loss value per episode.
        throughputs: Sequence with one average-throughput value per episode.
        log_dir: Directory that receives ``training_results.png``; it is
            created if it does not exist yet.
    """
    panels = (
        (rewards, "Total Reward", "Episode Rewards"),
        (losses, "Average Loss", "Training Loss"),
        (throughputs, "Average Throughput (veh/h)", "Traffic Throughput"),
    )

    # savefig does not create missing directories; an empty log_dir means the
    # current working directory and needs no creation.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    plot_path = os.path.join(log_dir, "training_results.png")

    fig, axes = plt.subplots(len(panels), 1, figsize=(10, 12))
    try:
        for ax, (series, ylabel, title) in zip(axes, panels):
            ax.plot(series)
            ax.set_xlabel("Episode")
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.grid(True)
        fig.tight_layout()
        fig.savefig(plot_path)
    finally:
        # Close this specific figure even when plotting or saving fails, so
        # it does not stay registered with pyplot.
        plt.close(fig)
    print(f"Training plots saved to {plot_path}")
# Script entry point: when executed directly (not imported), run training
# with train()'s default configuration path ("config.yaml").
if __name__ == "__main__":
    train()