"""Directional Corridor QMIX for ordered motorway VSL control.""" from __future__ import annotations from typing import Tuple import numpy as np import torch import torch.nn as nn from agents.dcmappo_agent import DualPathCorridorBlock, SelectiveFeatureFusion from agents.qmix_agent import QMIXAgent class DirectionalUtilityNetwork(nn.Module): """Corridor-aware shared utility network with directional context mixing.""" def __init__( self, edge_token_dim: int, condition_dim: int, num_agents: int, num_actions: int, hidden_dim: int = 256, num_blocks: int = 2, kernel_size: int = 5, dropout: float = 0.05, ): super().__init__() self.fusion = SelectiveFeatureFusion( component_dim=edge_token_dim, condition_dim=condition_dim, hidden_dim=hidden_dim, num_components=5, ) self.position_embedding = nn.Parameter(torch.zeros(1, num_agents, hidden_dim)) self.blocks = nn.ModuleList( [DualPathCorridorBlock(hidden_dim, kernel_size=kernel_size, dropout=dropout) for _ in range(num_blocks)] ) self.head_norm = nn.LayerNorm(hidden_dim) self.head = nn.Linear(hidden_dim, num_actions) self._init_weights() def _init_weights(self): nn.init.normal_(self.position_embedding, mean=0.0, std=0.02) nn.init.orthogonal_(self.head.weight, gain=0.01) nn.init.constant_(self.head.bias, 0.0) def forward(self, components: torch.Tensor, condition: torch.Tensor) -> torch.Tensor: batch_size, num_agents, _, _ = components.shape fused = self.fusion( components.view(batch_size * num_agents, components.size(2), components.size(3)), condition.view(batch_size * num_agents, condition.size(-1)), ) x = fused.view(batch_size, num_agents, -1) + self.position_embedding for block in self.blocks: x = block(x) return self.head(self.head_norm(x)) class DirectionalMixerStateEncoder(nn.Module): """Encode ordered corridor state before hyper-network generation.""" def __init__( self, num_agents: int, edge_token_dim: int, global_feature_dim: int, hidden_dim: int = 256, num_blocks: int = 2, kernel_size: int = 5, dropout: float = 0.05, ): super().__init__() self.edge_proj = nn.Sequential( nn.Linear(edge_token_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(), ) self.global_film = nn.Sequential( nn.Linear(global_feature_dim, hidden_dim * 2), nn.GELU(), nn.Linear(hidden_dim * 2, hidden_dim * 2), ) self.position_embedding = nn.Parameter(torch.zeros(1, num_agents, hidden_dim)) self.blocks = nn.ModuleList( [DualPathCorridorBlock(hidden_dim, kernel_size=kernel_size, dropout=dropout) for _ in range(num_blocks)] ) self.output_proj = nn.Sequential( nn.Linear(hidden_dim * 2 + global_feature_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(), ) self._init_weights() def _init_weights(self): for module in self.modules(): if isinstance(module, nn.Linear): nn.init.orthogonal_(module.weight, gain=np.sqrt(2)) nn.init.constant_(module.bias, 0.0) nn.init.normal_(self.position_embedding, mean=0.0, std=0.02) def forward(self, edge_tokens: torch.Tensor, global_features: torch.Tensor) -> torch.Tensor: x = self.edge_proj(edge_tokens) gamma, beta = self.global_film(global_features).chunk(2, dim=-1) gamma = 0.1 * torch.tanh(gamma).unsqueeze(1) beta = 0.1 * beta.unsqueeze(1) x = x * (1.0 + gamma) + beta x = x + self.position_embedding for block in self.blocks: x = block(x) pooled_mean = x.mean(dim=1) pooled_max = x.max(dim=1).values return self.output_proj(torch.cat([pooled_mean, pooled_max, global_features], dim=-1)) class DirectionalQMixer(nn.Module): """QMIX mixer conditioned on a structured corridor state embedding.""" def __init__( self, num_agents: int, edge_feature_dim: int, time_feature_dim: int, total_edge_count: int, controlled_start_index: int, mixing_hidden_dim: int = 256, state_hidden_dim: int = 256, num_corridor_blocks: int = 2, corridor_kernel_size: int = 5, corridor_dropout: float = 0.05, ): super().__init__() self.num_agents = num_agents self.edge_feature_dim = edge_feature_dim self.time_feature_dim = time_feature_dim self.speed_feature_dim = 1 self.last_reward_dim = 1 self.total_edge_count = total_edge_count self.controlled_start_index = controlled_start_index self.controlled_end_index = self.controlled_start_index + self.num_agents self.global_feature_dim = self.time_feature_dim + self.last_reward_dim self.edge_token_dim = self.edge_feature_dim + self.speed_feature_dim self.state_encoder = DirectionalMixerStateEncoder( num_agents=self.num_agents, edge_token_dim=self.edge_token_dim, global_feature_dim=self.global_feature_dim, hidden_dim=state_hidden_dim, num_blocks=max(1, num_corridor_blocks), kernel_size=corridor_kernel_size, dropout=corridor_dropout, ) self.hyper_w1 = nn.Sequential( nn.Linear(state_hidden_dim, mixing_hidden_dim), nn.ReLU(), nn.Linear(mixing_hidden_dim, num_agents * mixing_hidden_dim), ) self.hyper_b1 = nn.Linear(state_hidden_dim, mixing_hidden_dim) self.hyper_w2 = nn.Sequential( nn.Linear(state_hidden_dim, mixing_hidden_dim), nn.ReLU(), nn.Linear(mixing_hidden_dim, mixing_hidden_dim), ) self.hyper_b2 = nn.Sequential( nn.Linear(state_hidden_dim, mixing_hidden_dim), nn.ReLU(), nn.Linear(mixing_hidden_dim, 1), ) self.activation = nn.ELU() self._init_weights() def _init_weights(self): for module in self.modules(): if isinstance(module, nn.Linear): nn.init.orthogonal_(module.weight, gain=np.sqrt(2)) nn.init.constant_(module.bias, 0.0) nn.init.orthogonal_(self.hyper_b2[-1].weight, gain=1.0) def _extract_state_tokens(self, state_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if state_tensor.dim() == 1: state_tensor = state_tensor.unsqueeze(0) batch_size = state_tensor.size(0) edge_block = self.total_edge_count * self.edge_feature_dim speed_block_start = edge_block speed_block_end = speed_block_start + self.total_edge_count global_block_start = speed_block_end global_block_end = global_block_start + self.global_feature_dim edge_features = state_tensor[:, :edge_block].view(batch_size, self.total_edge_count, self.edge_feature_dim) edge_features = edge_features[:, self.controlled_start_index:self.controlled_end_index, :] edge_speed_limits = state_tensor[:, speed_block_start:speed_block_end].view(batch_size, self.total_edge_count, 1) edge_speed_limits = edge_speed_limits[:, self.controlled_start_index:self.controlled_end_index, :] global_features = state_tensor[:, global_block_start:global_block_end] edge_tokens = torch.cat([edge_features, edge_speed_limits], dim=-1) return edge_tokens, global_features def forward(self, agent_qs: torch.Tensor, state: torch.Tensor) -> torch.Tensor: edge_tokens, global_features = self._extract_state_tokens(state) state_embed = self.state_encoder(edge_tokens, global_features) batch_size = agent_qs.size(0) w1 = torch.abs(self.hyper_w1(state_embed)).view(batch_size, self.num_agents, -1) b1 = self.hyper_b1(state_embed).view(batch_size, 1, -1) hidden = self.activation(torch.bmm(agent_qs.unsqueeze(1), w1) + b1) w2 = torch.abs(self.hyper_w2(state_embed)).view(batch_size, -1, 1) b2 = self.hyper_b2(state_embed).view(batch_size, 1, 1) return torch.bmm(hidden, w2).squeeze(1) + b2.squeeze(1) class DCQMIXAgent(QMIXAgent): """Directional corridor variant of QMIX with Double-Q target action selection.""" def __init__( self, state_dim: int, num_edges: int, num_actions_per_edge: int, hidden_dim: int = 256, mixing_hidden_dim: int = 256, learning_rate: float = 1e-3, gamma: float = 0.99, epsilon_start: float = 1.0, epsilon_end: float = 0.01, epsilon_decay: int = 200, buffer_size: int = 10000, batch_size: int = 64, target_update: int = 10, device: str = "cuda", edge_feature_dim: int = 3, time_feature_dim: int = 3, total_edge_count: int | None = None, controlled_start_index: int = 0, num_corridor_blocks: int = 2, corridor_kernel_size: int = 5, corridor_dropout: float = 0.05, ): self.num_corridor_blocks = int(num_corridor_blocks) self.corridor_kernel_size = int(corridor_kernel_size) self.corridor_dropout = float(corridor_dropout) super().__init__( state_dim=state_dim, num_edges=num_edges, num_actions_per_edge=num_actions_per_edge, hidden_dim=hidden_dim, mixing_hidden_dim=mixing_hidden_dim, learning_rate=learning_rate, gamma=gamma, epsilon_start=epsilon_start, epsilon_end=epsilon_end, epsilon_decay=epsilon_decay, buffer_size=buffer_size, batch_size=batch_size, target_update=target_update, device=device, edge_feature_dim=edge_feature_dim, time_feature_dim=time_feature_dim, total_edge_count=total_edge_count, controlled_start_index=controlled_start_index, ) def _build_utility_network(self, hidden_dim: int) -> nn.Module: return DirectionalUtilityNetwork( edge_token_dim=self.edge_feature_dim + self.speed_feature_dim, condition_dim=self.time_feature_dim + self.last_reward_dim + self.agent_id_dim, num_agents=self.num_agents, num_actions=self.num_actions_per_agent, hidden_dim=hidden_dim, num_blocks=self.num_corridor_blocks, kernel_size=self.corridor_kernel_size, dropout=self.corridor_dropout, ) def _build_mixer(self, mixing_hidden_dim: int) -> nn.Module: return DirectionalQMixer( num_agents=self.num_agents, edge_feature_dim=self.edge_feature_dim, time_feature_dim=self.time_feature_dim, total_edge_count=self.total_edge_count, controlled_start_index=self.controlled_start_index, mixing_hidden_dim=mixing_hidden_dim, state_hidden_dim=mixing_hidden_dim, num_corridor_blocks=self.num_corridor_blocks, corridor_kernel_size=self.corridor_kernel_size, corridor_dropout=self.corridor_dropout, ) @staticmethod def _shift_left(x: torch.Tensor) -> torch.Tensor: return torch.cat([x[:, :1, :], x[:, :-1, :]], dim=1) @staticmethod def _shift_right(x: torch.Tensor) -> torch.Tensor: return torch.cat([x[:, 1:, :], x[:, -1:, :]], dim=1) def _parse_state(self, state_tensor: torch.Tensor): if state_tensor.dim() == 1: state_tensor = state_tensor.unsqueeze(0) batch_size = state_tensor.size(0) edge_block = self.total_edge_count * self.edge_feature_dim speed_block_start = edge_block speed_block_end = speed_block_start + self.total_edge_count global_block_start = speed_block_end global_block_end = global_block_start + self.time_feature_dim + self.last_reward_dim edge_features = state_tensor[:, :edge_block].view(batch_size, self.total_edge_count, self.edge_feature_dim) edge_features = edge_features[:, self.controlled_start_index:self.controlled_end_index, :] local_speed_limits = state_tensor[:, speed_block_start:speed_block_end].view(batch_size, self.total_edge_count, 1) local_speed_limits = local_speed_limits[:, self.controlled_start_index:self.controlled_end_index, :] edge_tokens = torch.cat([edge_features, local_speed_limits], dim=-1) global_features = state_tensor[:, global_block_start:global_block_end] agent_ids = self.agent_id_features.expand(batch_size, -1, -1) condition = torch.cat( [global_features.unsqueeze(1).expand(-1, self.num_agents, -1), agent_ids], dim=-1, ) return edge_tokens, condition def _compute_agent_q_values_with_net( self, state_tensor: torch.Tensor, utility_net: nn.Module, ) -> torch.Tensor: edge_tokens, condition = self._parse_state(state_tensor) left_tokens = self._shift_left(edge_tokens) right_tokens = self._shift_right(edge_tokens) delta_left = edge_tokens - left_tokens delta_right = right_tokens - edge_tokens components = torch.stack( [edge_tokens, left_tokens, right_tokens, delta_left, delta_right], dim=2, ) return utility_net(components, condition) def _compute_target_next_agent_q(self, next_states_t: torch.Tensor) -> torch.Tensor: online_next_q_all = self._compute_agent_q_values(next_states_t) next_actions = online_next_q_all.argmax(dim=2, keepdim=True) target_next_q_all = self._compute_agent_q_values_with_net(next_states_t, self.target_utility_net) return target_next_q_all.gather(2, next_actions).squeeze(2)