"""Independent multi-agent DQN agent for corridor VSL control."""

from __future__ import annotations

import random
from collections import deque
from typing import Deque, Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


class MADQNNetwork(nn.Module):
    """Per-agent Q-network that maps the global state to local action values.

    A plain MLP: ``state_dim -> hidden_dim -> hidden_dim -> hidden_dim // 2
    -> num_actions`` with ReLU activations between the linear layers.
    """

    def __init__(self, state_dim: int, num_actions: int, hidden_dim: int = 256):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, num_actions),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` (last dimension ``state_dim``) to ``num_actions`` Q-values."""
        return self.network(x)


class MADQNAgent:
    """Independent DQN ensemble with one Q-network per controlled edge.

    Every agent reads the same global state and chooses its own discrete
    action. All agents share the scalar reward stored with each transition,
    one replay buffer, and one Adam optimizer. Each agent keeps its own
    online and target network.
    """

    def __init__(
        self,
        state_dim: int,
        num_edges: int,
        num_actions_per_edge: int,
        hidden_dim: int = 256,
        learning_rate: float = 1e-3,
        gamma: float = 0.99,
        epsilon_start: float = 1.0,
        epsilon_end: float = 0.01,
        epsilon_decay: int = 200,
        buffer_size: int = 10000,
        batch_size: int = 64,
        target_update: int = 10,
        device: str = "cuda",
    ):
        """Build the online/target ensembles, optimizer and replay buffer.

        Args:
            state_dim: Size of the state vector fed to every Q-network.
            num_edges: Number of controlled edges, i.e. independent agents.
            num_actions_per_edge: Number of discrete actions per agent.
            hidden_dim: Hidden width of each ``MADQNNetwork``.
            learning_rate: Adam learning rate shared by all agents.
            gamma: Discount factor used in the TD target.
            epsilon_start: Exploration rate before any environment step.
            epsilon_end: Asymptotic exploration rate.
            epsilon_decay: Time constant, in environment steps, of the
                exponential epsilon decay.
            buffer_size: Replay capacity; the oldest transitions are evicted
                first once it is full.
            batch_size: Minibatch size; ``update`` is a no-op until the buffer
                holds at least this many transitions.
            target_update: Sync target networks every this many gradient
                updates (clamped to at least 1).
            device: Requested torch device; CPU is used whenever CUDA is
                unavailable, regardless of this value.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.num_agents = num_edges
        self.num_actions_per_agent = num_actions_per_edge
        self.gamma = gamma
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.target_update = max(int(target_update), 1)

        self.q_nets = self._build_ensemble(state_dim, num_edges, num_actions_per_edge, hidden_dim)
        self.target_q_nets = self._build_ensemble(
            state_dim, num_edges, num_actions_per_edge, hidden_dim
        )
        self._sync_target_network()
        # Target networks are never optimized directly; keep them in eval mode.
        self.target_q_nets.eval()

        # A single optimizer over every agent's online-network parameters.
        self.optimizer = optim.Adam(self.q_nets.parameters(), lr=learning_rate)
        self.replay_buffer: Deque[Tuple[np.ndarray, np.ndarray, float, np.ndarray, bool]] = deque(
            maxlen=buffer_size
        )
        self.env_steps = 0
        self.update_steps = 0

    def _build_ensemble(
        self, state_dim: int, num_agents: int, num_actions: int, hidden_dim: int
    ) -> nn.ModuleList:
        """Create ``num_agents`` fresh Q-networks on ``self.device``."""
        return nn.ModuleList(
            [
                MADQNNetwork(
                    state_dim=state_dim,
                    num_actions=num_actions,
                    hidden_dim=hidden_dim,
                )
                for _ in range(num_agents)
            ]
        ).to(self.device)

    def _current_epsilon(self) -> float:
        """Return ``eps_end + (eps_start - eps_end) * exp(-env_steps / decay)``.

        The decay constant is clamped to at least 1 to avoid division by zero.
        """
        return self.epsilon_end + (self.epsilon_start - self.epsilon_end) * np.exp(
            -float(self.env_steps) / max(float(self.epsilon_decay), 1.0)
        )

    def _greedy_action(self, state: np.ndarray) -> np.ndarray:
        """Return each agent's argmax-Q action for ``state`` as an int64 array."""
        # Add a batch dimension of 1 so the networks see a minibatch.
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        actions = []
        with torch.no_grad():
            for q_net in self.q_nets:
                q_values = q_net(state_tensor)
                actions.append(int(q_values.argmax(dim=1).item()))
        return np.asarray(actions, dtype=np.int64)

    def select_action(
        self,
        state: np.ndarray,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, float, float]:
        """Choose one action per agent.

        With ``deterministic=True`` the greedy joint action is returned and the
        step counter is left untouched. Otherwise epsilon-greedy exploration is
        used: with probability epsilon, every agent samples uniformly at random.
        Each such call advances ``env_steps`` and therefore the epsilon schedule.

        Returns:
            ``(actions, 0.0, 0.0)``: an int64 array with one action per agent,
            followed by two constant ``0.0`` placeholders.
        """
        if deterministic:
            return self._greedy_action(state), 0.0, 0.0

        epsilon = self._current_epsilon()
        self.env_steps += 1
        if random.random() < epsilon:
            action = np.array(
                [random.randrange(self.num_actions_per_agent) for _ in range(self.num_agents)],
                dtype=np.int64,
            )
        else:
            action = self._greedy_action(state)
        return action, 0.0, 0.0

    def store_transition(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer, normalizing dtypes."""
        transition = (
            np.asarray(state, dtype=np.float32),
            np.asarray(action, dtype=np.int64),
            float(reward),
            np.asarray(next_state, dtype=np.float32),
            bool(done),
        )
        self.replay_buffer.append(transition)

    def _sync_target_network(self):
        """Copy every online network's weights into its target network."""
        for q_net, target_q_net in zip(self.q_nets, self.target_q_nets):
            target_q_net.load_state_dict(q_net.state_dict())

    def _sample_batch(self):
        """Sample ``batch_size`` transitions uniformly and stack them as tensors.

        Returns:
            ``(states, actions, rewards, next_states, dones)``. ``rewards`` and
            ``dones`` have shape ``(batch, 1)``. ``actions`` is expected to be
            ``(batch, num_agents)``, since ``update`` indexes it per agent.
        """
        batch = random.sample(self.replay_buffer, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states_t = torch.as_tensor(np.array(states), dtype=torch.float32, device=self.device)
        actions_t = torch.as_tensor(np.array(actions), dtype=torch.long, device=self.device)
        rewards_t = torch.as_tensor(rewards, dtype=torch.float32, device=self.device).unsqueeze(1)
        next_states_t = torch.as_tensor(
            np.array(next_states), dtype=torch.float32, device=self.device
        )
        dones_t = torch.as_tensor(dones, dtype=torch.float32, device=self.device).unsqueeze(1)
        return states_t, actions_t, rewards_t, next_states_t, dones_t

    def update(self) -> Dict[str, float]:
        """Run one gradient step on all agents jointly.

        Each agent's TD target is ``r + gamma * (1 - done) * max_a Q_target(s', a)``,
        using the shared reward. The loss is the MSE over all agents' TD
        errors, and the gradient norm is clipped to 1.0 across all agents'
        parameters.

        Returns:
            ``{}`` while the buffer holds fewer than ``batch_size`` transitions;
            otherwise ``{"loss": l, "value_loss": l}``.
        """
        if len(self.replay_buffer) < self.batch_size:
            return {}

        states_t, actions_t, rewards_t, next_states_t, dones_t = self._sample_batch()

        current_q_values = []
        target_q_values = []
        for agent_idx, (q_net, target_q_net) in enumerate(zip(self.q_nets, self.target_q_nets)):
            current_q = q_net(states_t).gather(1, actions_t[:, agent_idx].unsqueeze(1))
            with torch.no_grad():
                next_q = target_q_net(next_states_t).max(dim=1, keepdim=True)[0]
                target_q = rewards_t + (1.0 - dones_t) * self.gamma * next_q
            current_q_values.append(current_q)
            target_q_values.append(target_q)

        # Column i holds agent i's values: shape (batch, num_agents).
        current_q_tensor = torch.cat(current_q_values, dim=1)
        target_q_tensor = torch.cat(target_q_values, dim=1)
        loss = nn.MSELoss()(current_q_tensor, target_q_tensor)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q_nets.parameters(), 1.0)
        self.optimizer.step()

        self.update_steps += 1
        if self.update_steps % self.target_update == 0:
            self._sync_target_network()

        value_loss = float(loss.item())
        return {"loss": value_loss, "value_loss": value_loss}

    def save(self, path: str):
        """Write networks, optimizer state and step counters to ``path``."""
        torch.save(
            {
                "q_nets": [q_net.state_dict() for q_net in self.q_nets],
                "target_q_nets": [target_q_net.state_dict() for target_q_net in self.target_q_nets],
                "optimizer": self.optimizer.state_dict(),
                "env_steps": int(self.env_steps),
                "update_steps": int(self.update_steps),
            },
            path,
        )

    def _check_agent_count(self, key: str, state_dicts) -> None:
        """Raise ``ValueError`` unless ``state_dicts`` has one entry per agent."""
        if len(state_dicts) != self.num_agents:
            raise ValueError(
                f"Checkpoint entry {key!r} holds {len(state_dicts)} networks, "
                f"but this agent has {self.num_agents}."
            )

    def load(self, path: str):
        """Restore a checkpoint written by ``save``.

        The agent count is validated before any weights are touched, so a
        mismatched checkpoint leaves this instance unchanged. If the checkpoint
        lacks ``"target_q_nets"``, the target networks are synced from the
        loaded online networks. The optimizer state is restored only if
        present. Missing step counters default to 0.

        Raises:
            ValueError: if the checkpoint holds a different number of agents
                than this instance.
        """
        # NOTE(security): weights_only=False unpickles arbitrary Python objects.
        # Only load checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)

        q_states = checkpoint["q_nets"]
        target_states = checkpoint.get("target_q_nets")
        # Validate everything first: zip() would otherwise silently truncate and
        # leave a partially overwritten ensemble behind.
        self._check_agent_count("q_nets", q_states)
        if target_states is not None:
            self._check_agent_count("target_q_nets", target_states)

        for q_net, state_dict in zip(self.q_nets, q_states):
            q_net.load_state_dict(state_dict)
        if target_states is None:
            self._sync_target_network()
        else:
            for target_q_net, state_dict in zip(self.target_q_nets, target_states):
                target_q_net.load_state_dict(state_dict)

        optimizer_state = checkpoint.get("optimizer")
        if optimizer_state is not None:
            self.optimizer.load_state_dict(optimizer_state)
        self.env_steps = int(checkpoint.get("env_steps", 0))
        self.update_steps = int(checkpoint.get("update_steps", 0))