"""Shared-parameter multi-head DQN agent for corridor VSL control."""

from __future__ import annotations

import random
from collections import deque
from typing import Deque, Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


class DQNNetwork(nn.Module):
    """Shared Q-network with one discrete action head per controlled edge.

    A single MLP trunk emits ``num_agents * num_actions_per_agent`` values,
    which ``forward`` reshapes so each agent (edge) gets its own row of
    Q-values.
    """

    def __init__(
        self,
        state_dim: int,
        num_agents: int,
        num_actions_per_agent: int,
        hidden_dim: int = 256,
    ):
        super().__init__()
        self.num_agents = num_agents
        self.num_actions_per_agent = num_actions_per_agent
        self.network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, num_agents * num_actions_per_agent),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Q-values shaped ``(batch, num_agents, num_actions_per_agent)``.

        All leading dimensions of ``x`` are folded into the batch axis by the
        ``view(-1, ...)``; an unbatched ``(state_dim,)`` input yields a batch
        of one.
        """
        q_values = self.network(x)
        return q_values.view(-1, self.num_agents, self.num_actions_per_agent)


class DQNAgent:
    """Shared-parameter multi-head DQN with epsilon-greedy exploration.

    One ``DQNNetwork`` scores every controlled edge at once, and each edge
    takes the action chosen by its own head. A periodically synchronised
    target network supplies the bootstrap targets. A single scalar reward and
    done flag per transition are shared by all edges.
    """

    def __init__(
        self,
        state_dim: int,
        num_edges: int,
        num_actions_per_edge: int,
        hidden_dim: int = 256,
        learning_rate: float = 1e-3,
        gamma: float = 0.99,
        epsilon_start: float = 1.0,
        epsilon_end: float = 0.01,
        epsilon_decay: int = 200,
        buffer_size: int = 10000,
        batch_size: int = 64,
        target_update: int = 10,
        device: str = "cuda",
        max_grad_norm: float = 1.0,
    ):
        """Build the online/target networks, the optimizer and the replay buffer.

        Args:
            state_dim: Length of the flat observation vector.
            num_edges: Number of controlled edges (one Q-head each).
            num_actions_per_edge: Number of discrete actions available to each edge.
            hidden_dim: Width of the first two hidden layers; the third hidden
                layer is ``hidden_dim // 2`` wide.
            learning_rate: Adam learning rate.
            gamma: Discount factor used in the bootstrapped target.
            epsilon_start: Exploration rate when ``env_steps == 0``.
            epsilon_end: Asymptotic exploration rate.
            epsilon_decay: Time constant (in exploratory env steps) of the
                exponential epsilon schedule. Values below 1 are treated as 1.
            buffer_size: Replay capacity. The oldest transitions are evicted first.
            batch_size: Minibatch size, and also the minimum buffer fill
                before ``update`` trains.
            target_update: Copy the online weights into the target network
                every this many gradient updates (clamped to at least 1).
            device: Requested torch device. CPU is used whenever CUDA is
                unavailable, whatever was requested.
            max_grad_norm: Total gradient L2 norm is clipped to this value
                before each optimizer step.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.num_agents = num_edges
        self.num_actions_per_agent = num_actions_per_edge
        self.gamma = gamma
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.target_update = max(int(target_update), 1)
        self.max_grad_norm = float(max_grad_norm)

        self.q_net = self._build_network(state_dim, hidden_dim)
        self.target_q_net = self._build_network(state_dim, hidden_dim)
        self.target_q_net.load_state_dict(self.q_net.state_dict())
        # The target net is only queried under torch.no_grad() in `update`.
        self.target_q_net.eval()

        self.optimizer = optim.Adam(self.q_net.parameters(), lr=learning_rate)
        # Holds (state, action, reward, next_state, done) tuples. The deque
        # maxlen gives FIFO eviction once capacity is reached.
        self.replay_buffer: Deque[Tuple[np.ndarray, np.ndarray, float, np.ndarray, bool]] = deque(
            maxlen=buffer_size
        )
        self.env_steps = 0  # exploratory action selections; drives epsilon
        self.update_steps = 0  # gradient steps; drives target-network sync

    def _build_network(self, state_dim: int, hidden_dim: int) -> DQNNetwork:
        """Create a ``DQNNetwork`` sized for this agent and place it on ``self.device``."""
        return DQNNetwork(
            state_dim=state_dim,
            num_agents=self.num_agents,
            num_actions_per_agent=self.num_actions_per_agent,
            hidden_dim=hidden_dim,
        ).to(self.device)

    def _current_epsilon(self) -> float:
        """Return the exploration rate for the current ``env_steps``.

        The rate is ``end + (start - end) * exp(-env_steps / max(decay, 1))``.
        """
        decay = max(float(self.epsilon_decay), 1.0)
        # Cast so callers get a plain float, not np.float64.
        return float(
            self.epsilon_end
            + (self.epsilon_start - self.epsilon_end) * np.exp(-float(self.env_steps) / decay)
        )

    def _greedy_action(self, state: np.ndarray) -> np.ndarray:
        """Return the argmax action of every head for a single ``state``, as int64."""
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            q_values = self.q_net(state_tensor)
        return q_values.argmax(dim=2).squeeze(0).cpu().numpy().astype(np.int64)

    def select_action(
        self,
        state: np.ndarray,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, float, float]:
        """Choose one action per edge.

        With ``deterministic=True`` the greedy action is returned and
        ``env_steps`` is not advanced, so evaluation does not affect the
        exploration schedule. Otherwise a uniformly random joint action is
        drawn with probability epsilon, and ``env_steps`` is incremented.

        Returns:
            ``(actions, 0.0, 0.0)``, where ``actions`` has shape
            ``(num_agents,)``. NOTE(review): the two constant zeros are
            presumably placeholders kept for interface parity with other
            agents (e.g. log-prob / value). Confirm against the caller.
        """
        if deterministic:
            return self._greedy_action(state), 0.0, 0.0

        epsilon = self._current_epsilon()
        self.env_steps += 1
        if random.random() < epsilon:
            action = np.array(
                [random.randrange(self.num_actions_per_agent) for _ in range(self.num_agents)],
                dtype=np.int64,
            )
        else:
            action = self._greedy_action(state)
        return action, 0.0, 0.0

    def store_transition(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer.

        ``action`` is flattened to a 1-D int64 array. A scalar action for a
        single-edge agent therefore becomes shape ``(1,)``, which
        ``update``'s per-head ``gather`` requires.
        """
        transition = (
            np.asarray(state, dtype=np.float32),
            np.asarray(action, dtype=np.int64).reshape(-1),
            float(reward),
            np.asarray(next_state, dtype=np.float32),
            bool(done),
        )
        self.replay_buffer.append(transition)

    def _sync_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_q_net.load_state_dict(self.q_net.state_dict())

    def _sample_batch(self):
        """Sample ``batch_size`` transitions uniformly and convert them to tensors.

        Rewards and dones are returned as ``(batch, 1)`` columns so that they
        broadcast across the per-agent axis in ``update``.
        """
        batch = random.sample(self.replay_buffer, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states_t = torch.as_tensor(np.array(states), dtype=torch.float32, device=self.device)
        actions_t = torch.as_tensor(np.array(actions), dtype=torch.long, device=self.device)
        rewards_t = torch.as_tensor(rewards, dtype=torch.float32, device=self.device).unsqueeze(1)
        next_states_t = torch.as_tensor(np.array(next_states), dtype=torch.float32, device=self.device)
        dones_t = torch.as_tensor(dones, dtype=torch.float32, device=self.device).unsqueeze(1)
        return states_t, actions_t, rewards_t, next_states_t, dones_t

    def update(self) -> Dict[str, float]:
        """Run one DQN gradient step on a replay minibatch.

        Returns:
            ``{}`` while the buffer holds fewer than ``batch_size`` transitions.
            Otherwise ``{"loss": l, "value_loss": l}``, where ``l`` is the MSE
            between ``Q(s, a)`` and the target-network bootstrap, averaged over
            batch and agents.
        """
        if len(self.replay_buffer) < self.batch_size:
            return {}

        states_t, actions_t, rewards_t, next_states_t, dones_t = self._sample_batch()

        # Q-value of the taken action on every head -> (batch, num_agents).
        current_q = self.q_net(states_t).gather(2, actions_t.unsqueeze(2)).squeeze(2)

        with torch.no_grad():
            next_q = self.target_q_net(next_states_t).max(dim=2).values
            # The (batch, 1) reward/done columns broadcast over the agent axis.
            target_q = rewards_t + (1.0 - dones_t) * self.gamma * next_q

        loss = nn.functional.mse_loss(current_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm)
        self.optimizer.step()

        self.update_steps += 1
        if self.update_steps % self.target_update == 0:
            self._sync_target_network()

        value_loss = float(loss.item())
        return {"loss": value_loss, "value_loss": value_loss}

    def save(self, path: str):
        """Write both networks, the optimizer state and the step counters to ``path``."""
        torch.save(
            {
                "q_net": self.q_net.state_dict(),
                "target_q_net": self.target_q_net.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "env_steps": int(self.env_steps),
                "update_steps": int(self.update_steps),
            },
            path,
        )

    def load(self, path: str):
        """Restore a checkpoint written by ``save``.

        Both network state dicts are required. The optimizer state and the
        step counters are optional; missing counters reset to 0.
        """
        # NOTE(review): weights_only=False lets torch unpickle arbitrary
        # objects, so only load trusted checkpoints. `save` writes only state
        # dicts and ints, so weights_only=True should also work; consider
        # switching.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.q_net.load_state_dict(checkpoint["q_net"])
        self.target_q_net.load_state_dict(checkpoint["target_q_net"])

        optimizer_state = checkpoint.get("optimizer")
        if optimizer_state is not None:
            self.optimizer.load_state_dict(optimizer_state)

        self.env_steps = int(checkpoint.get("env_steps", 0))
        self.update_steps = int(checkpoint.get("update_steps", 0))