ctm-dqn/agents/dcqmix_agent.py

341 lines
14 KiB
Python

"""Directional Corridor QMIX for ordered motorway VSL control."""
from __future__ import annotations
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
from agents.dcmappo_agent import DualPathCorridorBlock, SelectiveFeatureFusion
from agents.qmix_agent import QMIXAgent
class DirectionalUtilityNetwork(nn.Module):
    """Shared per-agent utility network that mixes context along the corridor.

    Every agent's stacked component tokens are first fused (conditioned on a
    per-agent condition vector) into a single ``hidden_dim`` token. A learned
    per-position embedding is added, the sequence of agent tokens is refined by
    ``num_blocks`` corridor blocks, and a normalised linear head emits one value
    per discrete action.
    """

    def __init__(
        self,
        edge_token_dim: int,
        condition_dim: int,
        num_agents: int,
        num_actions: int,
        hidden_dim: int = 256,
        num_blocks: int = 2,
        kernel_size: int = 5,
        dropout: float = 0.05,
    ):
        super().__init__()
        # NOTE: submodules are created in a fixed order on purpose -- it fixes
        # the RNG draws made at construction time and the state_dict key order.
        # The fusion module is built for 5 component tokens per agent.
        self.fusion = SelectiveFeatureFusion(
            component_dim=edge_token_dim,
            condition_dim=condition_dim,
            hidden_dim=hidden_dim,
            num_components=5,
        )
        self.position_embedding = nn.Parameter(torch.zeros(1, num_agents, hidden_dim))
        corridor_blocks = []
        for _ in range(num_blocks):
            corridor_blocks.append(
                DualPathCorridorBlock(hidden_dim, kernel_size=kernel_size, dropout=dropout)
            )
        self.blocks = nn.ModuleList(corridor_blocks)
        self.head_norm = nn.LayerNorm(hidden_dim)
        self.head = nn.Linear(hidden_dim, num_actions)
        self._init_weights()

    def _init_weights(self):
        """Small random position embedding; near-zero orthogonal action head."""
        nn.init.normal_(self.position_embedding, mean=0.0, std=0.02)
        # A tiny gain keeps the initial Q-values close to zero.
        nn.init.orthogonal_(self.head.weight, gain=0.01)
        nn.init.constant_(self.head.bias, 0.0)

    def forward(self, components: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        """Return per-agent action values.

        Args:
            components: 4-D tensor ``(batch, num_agents, n_components, token_dim)``;
                presumably ``n_components == 5`` and ``token_dim == edge_token_dim``
                to match the fusion module (confirm with callers).
            condition: Per-agent condition vectors; must hold
                ``batch * num_agents`` rows of ``condition.size(-1)`` features,
                e.g. ``(batch, num_agents, condition_dim)``.

        Returns:
            ``(batch, num_agents, num_actions)``, assuming the corridor blocks
            preserve the ``(batch, num_agents, hidden_dim)`` shape.
        """
        n_batch, n_agents, n_components, token_dim = components.shape
        rows = n_batch * n_agents
        # Fold agents into the batch axis so the fusion runs once per agent.
        fused = self.fusion(
            components.view(rows, n_components, token_dim),
            condition.view(rows, condition.size(-1)),
        )
        tokens = fused.view(n_batch, n_agents, -1) + self.position_embedding
        for corridor_block in self.blocks:
            tokens = corridor_block(tokens)
        normalised = self.head_norm(tokens)
        return self.head(normalised)
class DirectionalMixerStateEncoder(nn.Module):
    """Summarise the controlled corridor into one vector for the mixer hyper-networks.

    Edge tokens are projected to ``hidden_dim`` and gently modulated by a FiLM
    transform computed from the global features. A learned per-position
    embedding is added and the tokens are refined by ``num_blocks`` corridor
    blocks. They are then mean- and max-pooled over the agent axis. The pooled
    statistics are concatenated with the raw global features and projected back
    to ``hidden_dim``.
    """

    def __init__(
        self,
        num_agents: int,
        edge_token_dim: int,
        global_feature_dim: int,
        hidden_dim: int = 256,
        num_blocks: int = 2,
        kernel_size: int = 5,
        dropout: float = 0.05,
    ):
        super().__init__()
        # NOTE: submodule creation order is kept stable; it determines both the
        # RNG draws made at construction time and the state_dict key order.
        self.edge_proj = nn.Sequential(
            nn.Linear(edge_token_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
        )
        film_dim = 2 * hidden_dim  # one scale and one shift per hidden unit
        self.global_film = nn.Sequential(
            nn.Linear(global_feature_dim, film_dim),
            nn.GELU(),
            nn.Linear(film_dim, film_dim),
        )
        self.position_embedding = nn.Parameter(torch.zeros(1, num_agents, hidden_dim))
        corridor_blocks = []
        for _ in range(num_blocks):
            corridor_blocks.append(
                DualPathCorridorBlock(hidden_dim, kernel_size=kernel_size, dropout=dropout)
            )
        self.blocks = nn.ModuleList(corridor_blocks)
        pooled_dim = 2 * hidden_dim + global_feature_dim  # mean + max + raw globals
        self.output_proj = nn.Sequential(
            nn.Linear(pooled_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
        )
        self._init_weights()

    def _init_weights(self):
        """Orthogonal (gain sqrt(2)) init for every Linear, including nested ones."""
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.orthogonal_(layer.weight, gain=np.sqrt(2))
            nn.init.constant_(layer.bias, 0.0)
        nn.init.normal_(self.position_embedding, mean=0.0, std=0.02)

    def forward(self, edge_tokens: torch.Tensor, global_features: torch.Tensor) -> torch.Tensor:
        """Encode the corridor state.

        Args:
            edge_tokens: ``(batch, num_agents, edge_token_dim)``.
            global_features: ``(batch, global_feature_dim)``.

        Returns:
            ``(batch, hidden_dim)`` state embedding.
        """
        h = self.edge_proj(edge_tokens)
        raw_scale, raw_shift = self.global_film(global_features).chunk(2, dim=-1)
        # Bounded FiLM: the multiplicative factor stays inside (0.9, 1.1) and
        # the shift is damped by 0.1, so the globals only nudge the tokens.
        scale = 1.0 + 0.1 * torch.tanh(raw_scale).unsqueeze(1)
        shift = 0.1 * raw_shift.unsqueeze(1)
        h = h * scale + shift + self.position_embedding
        for corridor_block in self.blocks:
            h = corridor_block(h)
        summary = torch.cat(
            [h.mean(dim=1), h.max(dim=1).values, global_features],
            dim=-1,
        )
        return self.output_proj(summary)
class DirectionalQMixer(nn.Module):
    """QMIX mixer conditioned on a structured corridor state embedding.

    The flat global state is parsed as::

        [edge features : total_edge_count * edge_feature_dim]
        [speed limits  : total_edge_count]
        [time features : time_feature_dim]
        [last reward   : 1]

    Only the controlled edge slice
    ``[controlled_start_index, controlled_start_index + num_agents)`` is fed to
    the state encoder. Columns past the global block are ignored. The mixing
    weights pass through ``abs()``, so Q_tot is monotonic in every agent's Q.
    """

    def __init__(
        self,
        num_agents: int,
        edge_feature_dim: int,
        time_feature_dim: int,
        total_edge_count: int,
        controlled_start_index: int,
        mixing_hidden_dim: int = 256,
        state_hidden_dim: int = 256,
        num_corridor_blocks: int = 2,
        corridor_kernel_size: int = 5,
        corridor_dropout: float = 0.05,
    ):
        super().__init__()
        self.num_agents = num_agents
        self.edge_feature_dim = edge_feature_dim
        self.time_feature_dim = time_feature_dim
        self.speed_feature_dim = 1
        self.last_reward_dim = 1
        self.total_edge_count = total_edge_count
        self.controlled_start_index = controlled_start_index
        self.controlled_end_index = self.controlled_start_index + self.num_agents
        self.global_feature_dim = self.time_feature_dim + self.last_reward_dim
        # Each controlled edge token = its edge features + its speed limit.
        self.edge_token_dim = self.edge_feature_dim + self.speed_feature_dim
        self.state_encoder = DirectionalMixerStateEncoder(
            num_agents=self.num_agents,
            edge_token_dim=self.edge_token_dim,
            global_feature_dim=self.global_feature_dim,
            hidden_dim=state_hidden_dim,
            num_blocks=max(1, num_corridor_blocks),
            kernel_size=corridor_kernel_size,
            dropout=corridor_dropout,
        )
        # Hyper-networks: map the state embedding to the mixing network's
        # weights and biases.
        self.hyper_w1 = nn.Sequential(
            nn.Linear(state_hidden_dim, mixing_hidden_dim),
            nn.ReLU(),
            nn.Linear(mixing_hidden_dim, num_agents * mixing_hidden_dim),
        )
        self.hyper_b1 = nn.Linear(state_hidden_dim, mixing_hidden_dim)
        self.hyper_w2 = nn.Sequential(
            nn.Linear(state_hidden_dim, mixing_hidden_dim),
            nn.ReLU(),
            nn.Linear(mixing_hidden_dim, mixing_hidden_dim),
        )
        self.hyper_b2 = nn.Sequential(
            nn.Linear(state_hidden_dim, mixing_hidden_dim),
            nn.ReLU(),
            nn.Linear(mixing_hidden_dim, 1),
        )
        self.activation = nn.ELU()
        self._init_weights()

    def _init_weights(self):
        """Orthogonal (gain sqrt(2)) init for every Linear, then gain 1 for the final b2 layer."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
                nn.init.constant_(module.bias, 0.0)
        nn.init.orthogonal_(self.hyper_b2[-1].weight, gain=1.0)

    def _extract_state_tokens(self, state_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split a flat state into controlled edge tokens and global features.

        Args:
            state_tensor: ``(batch, state_dim)`` or a single ``(state_dim,)`` state.

        Returns:
            ``(edge_tokens, global_features)`` with shapes
            ``(batch, num_agents, edge_feature_dim + 1)`` and
            ``(batch, time_feature_dim + 1)``.

        Raises:
            ValueError: if the state is narrower than the expected layout.
        """
        if state_tensor.dim() == 1:
            state_tensor = state_tensor.unsqueeze(0)
        batch_size = state_tensor.size(0)
        edge_block = self.total_edge_count * self.edge_feature_dim
        speed_block_start = edge_block
        speed_block_end = speed_block_start + self.total_edge_count
        global_block_start = speed_block_end
        global_block_end = global_block_start + self.global_feature_dim
        if state_tensor.size(-1) < global_block_end:
            # Without this check a short state silently yields a truncated
            # global slice (or a cryptic view/shape error further down).
            raise ValueError(
                f"state has {state_tensor.size(-1)} features but the mixer layout "
                f"requires at least {global_block_end}"
            )
        start, end = self.controlled_start_index, self.controlled_end_index
        # reshape (not view) so non-contiguous state batches are accepted too.
        edge_features = state_tensor[:, :edge_block].reshape(
            batch_size, self.total_edge_count, self.edge_feature_dim
        )[:, start:end, :]
        edge_speed_limits = state_tensor[:, speed_block_start:speed_block_end].reshape(
            batch_size, self.total_edge_count, 1
        )[:, start:end, :]
        global_features = state_tensor[:, global_block_start:global_block_end]
        edge_tokens = torch.cat([edge_features, edge_speed_limits], dim=-1)
        return edge_tokens, global_features

    def forward(self, agent_qs: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        """Mix per-agent Q-values into the joint Q_tot.

        Args:
            agent_qs: ``(batch, num_agents)`` chosen-action utilities; a single
                unbatched ``(num_agents,)`` sample is also accepted.
            state: ``(batch, state_dim)`` or ``(state_dim,)`` flat global state.

        Returns:
            ``(batch, 1)`` joint action value.
        """
        if agent_qs.dim() == 1:
            # Mirror _extract_state_tokens, which promotes a 1-D state to a
            # batch of one; otherwise batch_size below would be num_agents.
            agent_qs = agent_qs.unsqueeze(0)
        edge_tokens, global_features = self._extract_state_tokens(state)
        state_embed = self.state_encoder(edge_tokens, global_features)
        batch_size = agent_qs.size(0)
        # abs() keeps the mixing weights non-negative (QMIX monotonicity).
        w1 = torch.abs(self.hyper_w1(state_embed)).view(batch_size, self.num_agents, -1)
        b1 = self.hyper_b1(state_embed).view(batch_size, 1, -1)
        hidden = self.activation(torch.bmm(agent_qs.unsqueeze(1), w1) + b1)
        w2 = torch.abs(self.hyper_w2(state_embed)).view(batch_size, -1, 1)
        b2 = self.hyper_b2(state_embed).view(batch_size, 1, 1)
        return torch.bmm(hidden, w2).squeeze(1) + b2.squeeze(1)
class DCQMIXAgent(QMIXAgent):
    """Directional corridor variant of QMIX with Double-Q target action selection.

    Plugs a corridor-aware utility network (``DirectionalUtilityNetwork``) and
    mixer (``DirectionalQMixer``) into the ``QMIXAgent`` machinery. Each agent's
    utility input stacks five tokens:

    - its own edge token;
    - the tokens at indices ``i - 1`` and ``i + 1``, edge-padded at the corridor
      ends;
    - the two first differences between those tokens.
    """

    def __init__(
        self,
        state_dim: int,
        num_edges: int,
        num_actions_per_edge: int,
        hidden_dim: int = 256,
        mixing_hidden_dim: int = 256,
        learning_rate: float = 1e-3,
        gamma: float = 0.99,
        epsilon_start: float = 1.0,
        epsilon_end: float = 0.01,
        epsilon_decay: int = 200,
        buffer_size: int = 10000,
        batch_size: int = 64,
        target_update: int = 10,
        device: str = "cuda",
        edge_feature_dim: int = 3,
        time_feature_dim: int = 3,
        total_edge_count: int | None = None,
        controlled_start_index: int = 0,
        num_corridor_blocks: int = 2,
        corridor_kernel_size: int = 5,
        corridor_dropout: float = 0.05,
    ):
        # Stored before super().__init__(): the base constructor presumably
        # invokes the _build_* hooks below, which read these attributes.
        self.num_corridor_blocks = int(num_corridor_blocks)
        self.corridor_kernel_size = int(corridor_kernel_size)
        self.corridor_dropout = float(corridor_dropout)
        super().__init__(
            state_dim=state_dim,
            num_edges=num_edges,
            num_actions_per_edge=num_actions_per_edge,
            hidden_dim=hidden_dim,
            mixing_hidden_dim=mixing_hidden_dim,
            learning_rate=learning_rate,
            gamma=gamma,
            epsilon_start=epsilon_start,
            epsilon_end=epsilon_end,
            epsilon_decay=epsilon_decay,
            buffer_size=buffer_size,
            batch_size=batch_size,
            target_update=target_update,
            device=device,
            edge_feature_dim=edge_feature_dim,
            time_feature_dim=time_feature_dim,
            total_edge_count=total_edge_count,
            controlled_start_index=controlled_start_index,
        )

    def _build_utility_network(self, hidden_dim: int) -> nn.Module:
        """Build the directional per-agent utility network."""
        return DirectionalUtilityNetwork(
            edge_token_dim=self.edge_feature_dim + self.speed_feature_dim,
            # Must match the condition assembled in _parse_state:
            # global (time + last reward) features followed by agent ids.
            condition_dim=self.time_feature_dim + self.last_reward_dim + self.agent_id_dim,
            num_agents=self.num_agents,
            num_actions=self.num_actions_per_agent,
            hidden_dim=hidden_dim,
            num_blocks=self.num_corridor_blocks,
            kernel_size=self.corridor_kernel_size,
            dropout=self.corridor_dropout,
        )

    def _build_mixer(self, mixing_hidden_dim: int) -> nn.Module:
        """Build the directional QMIX mixer (state width == mixing width)."""
        return DirectionalQMixer(
            num_agents=self.num_agents,
            edge_feature_dim=self.edge_feature_dim,
            time_feature_dim=self.time_feature_dim,
            total_edge_count=self.total_edge_count,
            controlled_start_index=self.controlled_start_index,
            mixing_hidden_dim=mixing_hidden_dim,
            state_hidden_dim=mixing_hidden_dim,
            num_corridor_blocks=self.num_corridor_blocks,
            corridor_kernel_size=self.corridor_kernel_size,
            corridor_dropout=self.corridor_dropout,
        )

    @staticmethod
    def _shift_left(x: torch.Tensor) -> torch.Tensor:
        """Return, at each position i along dim 1, the token at i - 1 (position 0 repeats itself)."""
        return torch.cat([x[:, :1, :], x[:, :-1, :]], dim=1)

    @staticmethod
    def _shift_right(x: torch.Tensor) -> torch.Tensor:
        """Return, at each position i along dim 1, the token at i + 1 (the last position repeats itself)."""
        return torch.cat([x[:, 1:, :], x[:, -1:, :]], dim=1)

    def _parse_state(self, state_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split a flat state into controlled edge tokens and per-agent conditions.

        Args:
            state_tensor: ``(batch, state_dim)`` or a single ``(state_dim,)`` state,
                laid out as edge features, speed limits, then time features and
                last reward.

        Returns:
            ``edge_tokens`` of shape ``(batch, num_agents, edge_feature_dim + 1)``
            and ``condition`` = global features broadcast to every agent,
            concatenated with ``agent_id_features`` (assumed
            ``(1, num_agents, agent_id_dim)`` -- confirm in QMIXAgent).

        Raises:
            ValueError: if the state is narrower than the expected layout.
        """
        if state_tensor.dim() == 1:
            state_tensor = state_tensor.unsqueeze(0)
        batch_size = state_tensor.size(0)
        edge_block = self.total_edge_count * self.edge_feature_dim
        speed_block_start = edge_block
        speed_block_end = speed_block_start + self.total_edge_count
        global_block_start = speed_block_end
        global_block_end = global_block_start + self.time_feature_dim + self.last_reward_dim
        if state_tensor.size(-1) < global_block_end:
            # Fail loudly instead of silently producing a short condition vector.
            raise ValueError(
                f"state has {state_tensor.size(-1)} features but the agent layout "
                f"requires at least {global_block_end}"
            )
        start, end = self.controlled_start_index, self.controlled_end_index
        # reshape (not view) so non-contiguous state batches are accepted too.
        edge_features = state_tensor[:, :edge_block].reshape(
            batch_size, self.total_edge_count, self.edge_feature_dim
        )[:, start:end, :]
        local_speed_limits = state_tensor[:, speed_block_start:speed_block_end].reshape(
            batch_size, self.total_edge_count, 1
        )[:, start:end, :]
        edge_tokens = torch.cat([edge_features, local_speed_limits], dim=-1)
        global_features = state_tensor[:, global_block_start:global_block_end]
        agent_ids = self.agent_id_features.expand(batch_size, -1, -1)
        condition = torch.cat(
            [global_features.unsqueeze(1).expand(-1, self.num_agents, -1), agent_ids],
            dim=-1,
        )
        return edge_tokens, condition

    def _compute_agent_q_values_with_net(
        self,
        state_tensor: torch.Tensor,
        utility_net: nn.Module,
    ) -> torch.Tensor:
        """Evaluate ``utility_net`` on the directional component stack of a state.

        The five components stacked on dim 2 are:

        - ``self``: the agent's own edge token;
        - ``left``: the token at index ``i - 1``;
        - ``right``: the token at index ``i + 1``;
        - ``self - left``;
        - ``right - self``.
        """
        edge_tokens, condition = self._parse_state(state_tensor)
        left_tokens = self._shift_left(edge_tokens)
        right_tokens = self._shift_right(edge_tokens)
        delta_left = edge_tokens - left_tokens
        delta_right = right_tokens - edge_tokens
        components = torch.stack(
            [edge_tokens, left_tokens, right_tokens, delta_left, delta_right],
            dim=2,
        )
        return utility_net(components, condition)

    def _compute_target_next_agent_q(self, next_states_t: torch.Tensor) -> torch.Tensor:
        """Double-Q bootstrap utilities for the next state.

        The online utility network selects each agent's greedy next action and
        the target utility network evaluates that action.

        Returns:
            ``(batch, num_agents)`` target utilities of the selected actions.
        """
        # Bootstrap targets are constants for the TD loss: no_grad avoids
        # building (and retaining) autograd graphs through either network and
        # keeps gradients from leaking into the target network.
        with torch.no_grad():
            online_next_q = self._compute_agent_q_values(next_states_t)
            greedy_actions = online_next_q.argmax(dim=2, keepdim=True)
            target_next_q = self._compute_agent_q_values_with_net(
                next_states_t, self.target_utility_net
            )
            return target_next_q.gather(2, greedy_actions).squeeze(2)