""" Deep Q-Network (DQN) Agent for speed limit control. """ import torch import torch.nn as nn import torch.optim as optim import numpy as np from collections import deque import random from typing import Tuple, List class QNetwork(nn.Module): """Q-Network for DQN.""" def __init__(self, state_dim: int, action_dim: int, hidden_layers: List[int]): """ Initialize Q-Network. Args: state_dim: Dimension of state space action_dim: Dimension of action space hidden_layers: List of hidden layer sizes """ super(QNetwork, self).__init__() layers = [] input_dim = state_dim for hidden_dim in hidden_layers: layers.append(nn.Linear(input_dim, hidden_dim)) layers.append(nn.ReLU()) input_dim = hidden_dim layers.append(nn.Linear(input_dim, action_dim)) self.network = nn.Sequential(*layers) def forward(self, state: torch.Tensor) -> torch.Tensor: """Forward pass.""" return self.network(state) class ReplayBuffer: """Experience replay buffer.""" def __init__(self, capacity: int): """Initialize replay buffer.""" self.buffer = deque(maxlen=capacity) def push( self, state: np.ndarray, action: int, reward: float, next_state: np.ndarray, done: bool, ): """Add experience to buffer.""" self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size: int) -> Tuple: """Sample a batch of experiences.""" batch = random.sample(self.buffer, batch_size) states, actions, rewards, next_states, dones = zip(*batch) return ( np.array(states), np.array(actions), np.array(rewards), np.array(next_states), np.array(dones), ) def __len__(self) -> int: """Return buffer size.""" return len(self.buffer) class DQNAgent: """DQN Agent for speed limit control.""" def __init__( self, state_dim: int, action_dim: int, hidden_layers: List[int] = [128, 128], learning_rate: float = 0.0001, gamma: float = 0.99, epsilon_start: float = 1.0, epsilon_end: float = 0.01, epsilon_decay: float = 0.995, buffer_size: int = 50000, batch_size: int = 64, target_update_freq: int = 10, device: str = "cuda", ): """Initialize DQN agent.""" self.state_dim = state_dim self.action_dim = action_dim self.gamma = gamma self.epsilon = epsilon_start self.epsilon_end = epsilon_end self.epsilon_decay = epsilon_decay self.batch_size = batch_size self.target_update_freq = target_update_freq self.device = torch.device(device if torch.cuda.is_available() else "cpu") self.q_network = QNetwork(state_dim, action_dim, hidden_layers).to( self.device ) self.target_network = QNetwork(state_dim, action_dim, hidden_layers).to( self.device ) self.target_network.load_state_dict(self.q_network.state_dict()) self.optimizer = optim.Adam(self.q_network.parameters(), lr=learning_rate) self.replay_buffer = ReplayBuffer(buffer_size) self.episode_count = 0 def select_action(self, state: np.ndarray, training: bool = True) -> int: """Select action using epsilon-greedy policy.""" if training and random.random() < self.epsilon: return random.randint(0, self.action_dim - 1) with torch.no_grad(): state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device) q_values = self.q_network(state_tensor) return q_values.argmax().item() def store_transition( self, state: np.ndarray, action: int, reward: float, next_state: np.ndarray, done: bool, ): """Store transition in replay buffer.""" self.replay_buffer.push(state, action, reward, next_state, done) def train(self) -> float: """Train the agent on a batch of experiences.""" if len(self.replay_buffer) < self.batch_size: return 0.0 states, actions, rewards, next_states, dones = self.replay_buffer.sample( self.batch_size ) states = torch.FloatTensor(states).to(self.device) actions = torch.LongTensor(actions).to(self.device) rewards = torch.FloatTensor(rewards).to(self.device) next_states = torch.FloatTensor(next_states).to(self.device) dones = torch.FloatTensor(dones).to(self.device) current_q_values = self.q_network(states).gather(1, actions.unsqueeze(1)) with torch.no_grad(): next_q_values = self.target_network(next_states).max(1)[0] target_q_values = rewards + (1 - dones) * self.gamma * next_q_values loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values) self.optimizer.zero_grad() loss.backward() self.optimizer.step() return loss.item() def update_target_network(self): """Update target network with current Q-network weights.""" self.target_network.load_state_dict(self.q_network.state_dict()) def update_epsilon(self): """Decay epsilon for exploration.""" self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay) def end_episode(self): """Called at the end of each episode.""" self.episode_count += 1 if self.episode_count % self.target_update_freq == 0: self.update_target_network() self.update_epsilon() def save(self, path: str): """Save model checkpoint.""" torch.save( { "q_network": self.q_network.state_dict(), "target_network": self.target_network.state_dict(), "optimizer": self.optimizer.state_dict(), "epsilon": self.epsilon, "episode_count": self.episode_count, }, path, ) def load(self, path: str): """Load model checkpoint.""" checkpoint = torch.load(path, map_location=self.device) self.q_network.load_state_dict(checkpoint["q_network"]) self.target_network.load_state_dict(checkpoint["target_network"]) self.optimizer.load_state_dict(checkpoint["optimizer"]) self.epsilon = checkpoint["epsilon"] self.episode_count = checkpoint["episode_count"]