"""QMIX baseline agent for cooperative corridor VSL control."""

from __future__ import annotations

import random
import warnings
from collections import deque
from typing import Deque, Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


class SharedUtilityNetwork(nn.Module):
    """Shared per-agent utility network for homogeneous corridor controllers.

    Maps one agent's local observation to one Q-value per discrete action. The
    same weights are applied to every agent (parameter sharing).

    Args:
        local_obs_dim: Length of a single agent's local observation vector.
        num_actions: Number of discrete actions available to each agent.
        hidden_dim: Width of the two hidden layers.
    """

    def __init__(self, local_obs_dim: int, num_actions: int, hidden_dim: int = 256):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(local_obs_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_actions),
        )
        self._init_weights()

    def _init_weights(self):
        """Orthogonally initialise every linear layer and zero its bias."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
                nn.init.constant_(module.bias, 0.0)
        # A tiny gain on the output head keeps initial Q-values near zero, so no
        # action is strongly preferred before any training has happened.
        nn.init.orthogonal_(self.network[-1].weight, gain=0.01)

    def forward(self, local_obs: torch.Tensor) -> torch.Tensor:
        """Return Q-values of shape ``(..., num_actions)`` for ``local_obs``."""
        return self.network(local_obs)


class QMixer(nn.Module):
    """Monotonic mixing network conditioned on the global state.

    Hypernetworks map the flattened global state to the weights and biases of a
    two-layer mixing MLP. The mixing weights pass through ``abs`` and the hidden
    activation (ELU) is increasing, so the joint value ``Q_tot`` is
    non-decreasing in every per-agent Q-value.

    Args:
        num_agents: Number of per-agent Q-values being mixed.
        state_dim: Length of the flattened global state.
        mixing_hidden_dim: Width of the mixing layer and of the hypernetworks.
    """

    def __init__(
        self,
        num_agents: int,
        state_dim: int,
        mixing_hidden_dim: int = 256,
    ):
        super().__init__()
        self.num_agents = num_agents
        self.state_dim = state_dim
        self.mixing_hidden_dim = mixing_hidden_dim

        # Produces the (num_agents x mixing_hidden_dim) first-layer weight matrix.
        self.hyper_w1 = nn.Sequential(
            nn.Linear(state_dim, mixing_hidden_dim),
            nn.ReLU(),
            nn.Linear(mixing_hidden_dim, num_agents * mixing_hidden_dim),
        )
        self.hyper_b1 = nn.Linear(state_dim, mixing_hidden_dim)
        # Produces the (mixing_hidden_dim x 1) second-layer weight vector.
        self.hyper_w2 = nn.Sequential(
            nn.Linear(state_dim, mixing_hidden_dim),
            nn.ReLU(),
            nn.Linear(mixing_hidden_dim, mixing_hidden_dim),
        )
        self.hyper_b2 = nn.Sequential(
            nn.Linear(state_dim, mixing_hidden_dim),
            nn.ReLU(),
            nn.Linear(mixing_hidden_dim, 1),
        )
        self.activation = nn.ELU()
        self._init_weights()

    def _init_weights(self):
        """Orthogonally initialise all linear layers; unit gain on the final bias head."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
                nn.init.constant_(module.bias, 0.0)
        nn.init.orthogonal_(self.hyper_b2[-1].weight, gain=1.0)

    def forward(self, agent_qs: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        """Mix per-agent Q-values into the joint value ``Q_tot``.

        Args:
            agent_qs: Per-agent Q-values of shape ``(batch, num_agents)``.
            state: Global state whose elements flatten to ``state_dim`` values
                per batch row. Non-contiguous tensors are accepted.

        Returns:
            ``Q_tot`` with shape ``(batch, 1)``.
        """
        batch_size = agent_qs.size(0)
        # reshape (rather than view) also works for non-contiguous state tensors
        # and returns the same values whenever view would have succeeded.
        flat_state = state.reshape(batch_size, -1)

        # abs() keeps the mixing weights non-negative -> monotonic mixing.
        w1 = torch.abs(self.hyper_w1(flat_state)).view(
            batch_size,
            self.num_agents,
            self.mixing_hidden_dim,
        )
        b1 = self.hyper_b1(flat_state).view(batch_size, 1, self.mixing_hidden_dim)
        hidden = self.activation(torch.bmm(agent_qs.unsqueeze(1), w1) + b1)

        w2 = torch.abs(self.hyper_w2(flat_state)).view(
            batch_size,
            self.mixing_hidden_dim,
            1,
        )
        b2 = self.hyper_b2(flat_state).view(batch_size, 1, 1)
        return torch.bmm(hidden, w2).squeeze(1) + b2.squeeze(1)


class QMIXAgent:
    """Standard QMIX with parameter-shared utilities for cooperative discrete VSL.

    Each controlled edge is one agent. A single SharedUtilityNetwork scores every
    agent's local observation, and a QMixer combines the chosen per-agent
    Q-values into ``Q_tot`` conditioned on the full flat state.

    The flat state must begin with these blocks; any trailing features are seen
    only by the mixer:

    1. ``total_edge_count * edge_feature_dim`` per-edge features,
    2. ``total_edge_count`` per-edge speed-limit values,
    3. ``time_feature_dim + last_reward_dim`` global features, broadcast to
       every agent.

    Agents correspond to edges ``controlled_start_index`` (inclusive) through
    ``controlled_end_index`` (exclusive).
    """

    def __init__(
        self,
        state_dim: int,
        num_edges: int,
        num_actions_per_edge: int,
        hidden_dim: int = 256,
        mixing_hidden_dim: int = 256,
        learning_rate: float = 1e-3,
        gamma: float = 0.99,
        epsilon_start: float = 1.0,
        epsilon_end: float = 0.01,
        epsilon_decay: int = 200,
        buffer_size: int = 10000,
        batch_size: int = 64,
        target_update: int = 10,
        device: str = "cuda",
        edge_feature_dim: int = 3,
        time_feature_dim: int = 3,
        total_edge_count: int | None = None,
        controlled_start_index: int = 0,
        max_grad_norm: float = 1.0,
    ):
        """Build networks, optimizer and replay buffer.

        Args:
            state_dim: Length of the flat global state fed to the mixer.
            num_edges: Number of controlled edges (= number of agents).
            num_actions_per_edge: Discrete actions per agent.
            hidden_dim: Hidden width of the shared utility network.
            mixing_hidden_dim: Hidden width of the mixer and its hypernetworks.
            learning_rate: Adam learning rate.
            gamma: Discount factor.
            epsilon_start: Initial exploration rate.
            epsilon_end: Asymptotic exploration rate.
            epsilon_decay: Time constant (in environment steps) of the
                exponential epsilon decay.
            buffer_size: Replay buffer capacity (``deque`` maxlen).
            batch_size: Transitions sampled per update.
            target_update: Hard-copy target networks every this many updates
                (clamped to at least 1).
            device: Requested torch device. CPU is used when CUDA is unavailable.
            edge_feature_dim: Features per edge in the leading state block.
            time_feature_dim: Number of time features in the global block.
            total_edge_count: Edges described in the state. Defaults to
                ``num_edges``.
            controlled_start_index: Index of the first controlled edge.
            max_grad_norm: Gradient-norm clip applied in ``update``.

        Raises:
            ValueError: If the controlled slice exceeds ``total_edge_count``,
                or if ``state_dim`` is too small for the local-observation
                layout.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.state_dim = state_dim
        self.num_agents = int(num_edges)
        self.num_actions_per_agent = int(num_actions_per_edge)
        self.gamma = float(gamma)
        self.epsilon_start = float(epsilon_start)
        self.epsilon_end = float(epsilon_end)
        self.epsilon_decay = int(epsilon_decay)
        self.batch_size = int(batch_size)
        self.target_update = max(int(target_update), 1)
        self.max_grad_norm = float(max_grad_norm)
        self.edge_feature_dim = int(edge_feature_dim)
        self.time_feature_dim = int(time_feature_dim)
        self.speed_feature_dim = 1
        self.last_reward_dim = 1
        self.agent_id_dim = 1
        self.total_edge_count = int(total_edge_count if total_edge_count is not None else num_edges)
        self.controlled_start_index = int(controlled_start_index)
        self.controlled_end_index = self.controlled_start_index + self.num_agents
        if self.controlled_end_index > self.total_edge_count:
            raise ValueError("controlled action slice exceeds total edge count")

        # Fail fast instead of hitting an opaque view/shape error on first use.
        required_state_dim = self._state_layout()[2]
        if self.state_dim < required_state_dim:
            raise ValueError(
                f"state_dim={self.state_dim} is smaller than the {required_state_dim} "
                "features required by the local-observation layout"
            )

        self.local_obs_dim = (
            self.edge_feature_dim
            + self.speed_feature_dim
            + self.time_feature_dim
            + self.last_reward_dim
            + self.agent_id_dim
        )

        self.utility_net = self._build_utility_network(hidden_dim).to(self.device)
        self.target_utility_net = self._build_utility_network(hidden_dim).to(self.device)
        self.mixer = self._build_mixer(mixing_hidden_dim).to(self.device)
        self.target_mixer = self._build_mixer(mixing_hidden_dim).to(self.device)
        self._sync_target_networks()
        self.target_utility_net.eval()
        self.target_mixer.eval()

        self.optimizer = optim.Adam(self._trainable_parameters(), lr=learning_rate)

        # Each entry: (state, joint action, team reward, next_state, done).
        self.replay_buffer: Deque[Tuple[np.ndarray, np.ndarray, float, np.ndarray, bool]] = deque(
            maxlen=buffer_size
        )
        if buffer_size is not None and self.batch_size > buffer_size:
            # The buffer can never hold a full batch, so update() would silently
            # return {} forever. Warn rather than raise so evaluation-only setups
            # with a tiny buffer keep working.
            warnings.warn(
                f"batch_size={self.batch_size} exceeds buffer_size={buffer_size}; "
                "update() will never train because the replay buffer cannot fill a batch",
                RuntimeWarning,
                stacklevel=2,
            )
        self.env_steps = 0
        self.update_steps = 0

        # One scalar id per agent, evenly spaced in [0, 1], appended to its
        # local observation so the shared network can tell agents apart.
        agent_ids = np.linspace(0.0, 1.0, self.num_agents, dtype=np.float32)
        self.agent_id_features = torch.tensor(agent_ids, device=self.device).view(1, self.num_agents, 1)

    def _state_layout(self) -> Tuple[int, int, int]:
        """Return the end offsets of the edge, speed and global blocks in the flat state."""
        edge_end = self.total_edge_count * self.edge_feature_dim
        speed_end = edge_end + self.total_edge_count
        global_end = speed_end + self.time_feature_dim + self.last_reward_dim
        return edge_end, speed_end, global_end

    def _trainable_parameters(self) -> List[nn.Parameter]:
        """Return the online utility-network and mixer parameters, in optimizer order."""
        return list(self.utility_net.parameters()) + list(self.mixer.parameters())

    def _build_utility_network(self, hidden_dim: int) -> nn.Module:
        """Create a fresh shared utility network for this agent's dimensions."""
        return SharedUtilityNetwork(
            local_obs_dim=self.local_obs_dim,
            num_actions=self.num_actions_per_agent,
            hidden_dim=hidden_dim,
        )

    def _build_mixer(self, mixing_hidden_dim: int) -> nn.Module:
        """Create a fresh QMIX mixer for this agent's dimensions."""
        return QMixer(
            num_agents=self.num_agents,
            state_dim=self.state_dim,
            mixing_hidden_dim=mixing_hidden_dim,
        )

    def _current_epsilon(self) -> float:
        """Exponentially decay epsilon from ``epsilon_start`` toward ``epsilon_end``."""
        return self.epsilon_end + (self.epsilon_start - self.epsilon_end) * np.exp(
            -float(self.env_steps) / max(float(self.epsilon_decay), 1.0)
        )

    def _build_local_obs(self, state_tensor: torch.Tensor) -> torch.Tensor:
        """Slice each controlled agent's local observation out of flat states.

        Args:
            state_tensor: A single flat state ``(state_dim,)`` or a batch
                ``(batch, state_dim)``.

        Returns:
            Tensor of shape ``(batch, num_agents, local_obs_dim)``. Per agent, it
            concatenates: its edge features, its speed-limit value, the shared
            global features and its agent id.
        """
        if state_tensor.dim() == 1:
            state_tensor = state_tensor.unsqueeze(0)
        batch_size = state_tensor.size(0)
        edge_end, speed_end, global_end = self._state_layout()
        controlled = slice(self.controlled_start_index, self.controlled_end_index)

        edge_features = state_tensor[:, :edge_end].reshape(
            batch_size, self.total_edge_count, self.edge_feature_dim
        )[:, controlled, :]
        local_speed_limits = state_tensor[:, edge_end:speed_end].reshape(
            batch_size, self.total_edge_count, 1
        )[:, controlled, :]
        # Global features are identical for every agent: broadcast along agents.
        global_features = state_tensor[:, speed_end:global_end].unsqueeze(1).expand(-1, self.num_agents, -1)
        agent_ids = self.agent_id_features.expand(batch_size, -1, -1)
        return torch.cat([edge_features, local_speed_limits, global_features, agent_ids], dim=-1)

    def _compute_agent_q_values_with_net(
        self,
        state_tensor: torch.Tensor,
        utility_net: nn.Module,
    ) -> torch.Tensor:
        """Return per-agent Q-values ``(batch, num_agents, num_actions)`` from ``utility_net``."""
        local_obs = self._build_local_obs(state_tensor)
        batch_size = local_obs.size(0)
        # Fold agents into the batch so one forward pass scores every agent.
        q_values = utility_net(local_obs.view(batch_size * self.num_agents, self.local_obs_dim))
        return q_values.view(batch_size, self.num_agents, self.num_actions_per_agent)

    def _compute_agent_q_values(self, state_tensor: torch.Tensor) -> torch.Tensor:
        """Per-agent Q-values from the online utility network."""
        return self._compute_agent_q_values_with_net(state_tensor, self.utility_net)

    def _compute_target_next_agent_q(self, next_states_t: torch.Tensor) -> torch.Tensor:
        """Per-agent max over actions of the target utility network: ``(batch, num_agents)``."""
        target_next_q_all = self._compute_agent_q_values_with_net(next_states_t, self.target_utility_net)
        return target_next_q_all.max(dim=2)[0]

    def _greedy_action(self, state: np.ndarray) -> np.ndarray:
        """Return the per-agent argmax action for a single flat state as int64."""
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            q_values = self._compute_agent_q_values(state_tensor)
        return q_values.argmax(dim=2).squeeze(0).cpu().numpy().astype(np.int64)

    def select_action(
        self,
        state: np.ndarray,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, float, float]:
        """Choose a joint action (one discrete action per agent).

        With ``deterministic=True`` the greedy action is returned, and the
        exploration step counter is not advanced. Otherwise epsilon-greedy is
        used: one shared coin flip picks either fully random or fully greedy
        actions for all agents, and ``env_steps`` is incremented.

        Returns:
            ``(actions, 0.0, 0.0)``. The two trailing floats are always zero
            here — presumably placeholders for a shared agent interface;
            confirm with callers.
        """
        if deterministic:
            return self._greedy_action(state), 0.0, 0.0

        epsilon = self._current_epsilon()
        self.env_steps += 1
        if random.random() < epsilon:
            action = np.array(
                [random.randrange(self.num_actions_per_agent) for _ in range(self.num_agents)],
                dtype=np.int64,
            )
        else:
            action = self._greedy_action(state)
        return action, 0.0, 0.0

    def store_transition(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer; the oldest one is evicted when full."""
        transition = (
            np.asarray(state, dtype=np.float32),
            np.asarray(action, dtype=np.int64),
            float(reward),
            np.asarray(next_state, dtype=np.float32),
            bool(done),
        )
        self.replay_buffer.append(transition)

    def _sync_target_networks(self):
        """Hard-copy the online utility network and mixer weights into the targets."""
        self.target_utility_net.load_state_dict(self.utility_net.state_dict())
        self.target_mixer.load_state_dict(self.mixer.state_dict())

    def _sample_batch(self):
        """Uniformly sample ``batch_size`` transitions and return them as device tensors.

        Rewards and dones come back as ``(batch, 1)`` float tensors.
        """
        batch = random.sample(self.replay_buffer, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        states_t = torch.as_tensor(np.array(states), dtype=torch.float32, device=self.device)
        actions_t = torch.as_tensor(np.array(actions), dtype=torch.long, device=self.device)
        rewards_t = torch.as_tensor(rewards, dtype=torch.float32, device=self.device).unsqueeze(1)
        next_states_t = torch.as_tensor(np.array(next_states), dtype=torch.float32, device=self.device)
        dones_t = torch.as_tensor(dones, dtype=torch.float32, device=self.device).unsqueeze(1)
        return states_t, actions_t, rewards_t, next_states_t, dones_t

    def update(self) -> Dict[str, float]:
        """Run one QMIX gradient step on a replay minibatch.

        The TD target is
        ``r + (1 - done) * gamma * target_mixer(max_a target_Q(s', a), s')``.

        Returns:
            ``{}`` if the buffer holds fewer than ``batch_size`` transitions.
            Otherwise ``{"loss": ..., "value_loss": ...}``, both equal to the
            MSE loss value.
        """
        if len(self.replay_buffer) < self.batch_size:
            return {}

        states_t, actions_t, rewards_t, next_states_t, dones_t = self._sample_batch()

        current_agent_q_all = self._compute_agent_q_values(states_t)
        # Pick each agent's Q-value for the action it actually took.
        current_agent_q = current_agent_q_all.gather(2, actions_t.unsqueeze(2)).squeeze(2)
        current_q_tot = self.mixer(current_agent_q, states_t)

        with torch.no_grad():
            target_next_agent_q = self._compute_target_next_agent_q(next_states_t)
            target_q_tot = rewards_t + (1.0 - dones_t) * self.gamma * self.target_mixer(
                target_next_agent_q,
                next_states_t,
            )

        loss = nn.functional.mse_loss(current_q_tot, target_q_tot)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self._trainable_parameters(), self.max_grad_norm)
        self.optimizer.step()

        self.update_steps += 1
        if self.update_steps % self.target_update == 0:
            self._sync_target_networks()

        value_loss = float(loss.item())
        return {"loss": value_loss, "value_loss": value_loss}

    def save(self, path: str):
        """Write online/target networks, optimizer state and step counters to ``path``."""
        torch.save(
            {
                "utility_net": self.utility_net.state_dict(),
                "target_utility_net": self.target_utility_net.state_dict(),
                "mixer": self.mixer.state_dict(),
                "target_mixer": self.target_mixer.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "env_steps": int(self.env_steps),
                "update_steps": int(self.update_steps),
            },
            path,
        )

    def load(self, path: str):
        """Restore a checkpoint written by ``save``.

        The optimizer state and step counters are optional in the checkpoint;
        missing counters default to 0.
        """
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects,
        # so only load checkpoints from trusted sources. Checkpoints produced by
        # save() contain only state_dicts and ints, which would also load with
        # weights_only=True; kept False for compatibility with other checkpoints.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.utility_net.load_state_dict(checkpoint["utility_net"])
        self.target_utility_net.load_state_dict(checkpoint["target_utility_net"])
        self.mixer.load_state_dict(checkpoint["mixer"])
        self.target_mixer.load_state_dict(checkpoint["target_mixer"])
        optimizer_state = checkpoint.get("optimizer")
        if optimizer_state is not None:
            self.optimizer.load_state_dict(optimizer_state)
        self.env_steps = int(checkpoint.get("env_steps", 0))
        self.update_steps = int(checkpoint.get("update_steps", 0))