341 lines
14 KiB
Python
341 lines
14 KiB
Python
"""Directional Corridor QMIX for ordered motorway VSL control."""
|
|
from __future__ import annotations
|
|
|
|
from typing import Tuple
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from agents.dcmappo_agent import DualPathCorridorBlock, SelectiveFeatureFusion
|
|
from agents.qmix_agent import QMIXAgent
|
|
|
|
|
|
class DirectionalUtilityNetwork(nn.Module):
    """Corridor-aware shared utility network with directional context mixing.

    Per-agent feature components are fused (conditioned on a per-agent
    condition vector) by ``SelectiveFeatureFusion``, offset by a learned
    per-position embedding, refined by a stack of ``DualPathCorridorBlock``
    layers and finally projected to one utility value per action.
    """

    def __init__(
        self,
        edge_token_dim: int,
        condition_dim: int,
        num_agents: int,
        num_actions: int,
        hidden_dim: int = 256,
        num_blocks: int = 2,
        kernel_size: int = 5,
        dropout: float = 0.05,
    ):
        """Build the fusion stem, corridor blocks and action head.

        Args:
            edge_token_dim: Feature size of each component token.
            condition_dim: Feature size of the per-agent condition vector.
            num_agents: Number of ordered agents; fixes the length of the
                learned position embedding.
            num_actions: Number of utilities produced per agent.
            hidden_dim: Width of the fused representation and the blocks.
            num_blocks: Number of ``DualPathCorridorBlock`` layers.
            kernel_size: Kernel size forwarded to every corridor block.
            dropout: Dropout rate forwarded to every corridor block.
        """
        super().__init__()
        # Five components per agent are fused; the caller stacks exactly
        # that many along dim 2 of ``components``.
        self.fusion = SelectiveFeatureFusion(
            component_dim=edge_token_dim,
            condition_dim=condition_dim,
            hidden_dim=hidden_dim,
            num_components=5,
        )
        self.position_embedding = nn.Parameter(torch.zeros(1, num_agents, hidden_dim))
        self.blocks = nn.ModuleList(
            [
                DualPathCorridorBlock(hidden_dim, kernel_size=kernel_size, dropout=dropout)
                for _ in range(num_blocks)
            ]
        )
        self.head_norm = nn.LayerNorm(hidden_dim)
        self.head = nn.Linear(hidden_dim, num_actions)
        self._init_weights()

    def _init_weights(self):
        """Initialise the position embedding and the action head."""
        nn.init.normal_(self.position_embedding, mean=0.0, std=0.02)
        # A very small gain keeps the initial utilities close to zero.
        nn.init.orthogonal_(self.head.weight, gain=0.01)
        nn.init.constant_(self.head.bias, 0.0)

    def forward(self, components: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        """Compute per-agent action utilities.

        Args:
            components: 4-D tensor ``(batch, agents, components, token_dim)``.
            condition: Tensor whose leading dims flatten to ``batch * agents``
                rows of ``condition_dim`` features.

        Returns:
            The action head applied to every agent position; the shape is
            ``(batch, agents, num_actions)`` provided the corridor blocks
            preserve the shape of their input.
        """
        batch_size, num_agents, num_components, token_dim = components.shape
        flat_rows = batch_size * num_agents

        # ``reshape`` instead of ``view``: identical for contiguous inputs,
        # but also accepts permuted / sliced tensors that ``view`` rejects.
        fused = self.fusion(
            components.reshape(flat_rows, num_components, token_dim),
            condition.reshape(flat_rows, condition.size(-1)),
        )

        x = fused.reshape(batch_size, num_agents, -1) + self.position_embedding
        for block in self.blocks:
            x = block(x)
        return self.head(self.head_norm(x))
|
|
|
|
|
|
class DirectionalMixerStateEncoder(nn.Module):
    """Encode ordered corridor state before hyper-network generation.

    Edge tokens are projected, modulated by the global features (FiLM-style
    scale and shift), offset by a learned position embedding and refined by
    ``DualPathCorridorBlock`` layers. The sequence is then mean- and
    max-pooled over the agent axis, concatenated with the raw global
    features and projected to a single embedding per sample.
    """

    def __init__(
        self,
        num_agents: int,
        edge_token_dim: int,
        global_feature_dim: int,
        hidden_dim: int = 256,
        num_blocks: int = 2,
        kernel_size: int = 5,
        dropout: float = 0.05,
    ):
        """Build the encoder.

        Args:
            num_agents: Number of ordered edge tokens; fixes the length of
                the learned position embedding.
            edge_token_dim: Feature size of each edge token.
            global_feature_dim: Feature size of the global feature vector.
            hidden_dim: Width of the token representation and the output.
            num_blocks: Number of ``DualPathCorridorBlock`` layers.
            kernel_size: Kernel size forwarded to every corridor block.
            dropout: Dropout rate forwarded to every corridor block.
        """
        super().__init__()
        self.edge_proj = nn.Sequential(
            nn.Linear(edge_token_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
        )
        # Produces 2 * hidden_dim values that are split into scale / shift.
        self.global_film = nn.Sequential(
            nn.Linear(global_feature_dim, hidden_dim * 2),
            nn.GELU(),
            nn.Linear(hidden_dim * 2, hidden_dim * 2),
        )
        self.position_embedding = nn.Parameter(torch.zeros(1, num_agents, hidden_dim))
        self.blocks = nn.ModuleList(
            [
                DualPathCorridorBlock(hidden_dim, kernel_size=kernel_size, dropout=dropout)
                for _ in range(num_blocks)
            ]
        )
        # Input: mean pool + max pool (hidden_dim each) + raw global features.
        self.output_proj = nn.Sequential(
            nn.Linear(hidden_dim * 2 + global_feature_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
        )
        self._init_weights()

    def _init_weights(self):
        """Orthogonally initialise every ``nn.Linear`` and the position embedding.

        ``self.modules()`` recurses, so linear layers nested inside the
        corridor blocks are re-initialised here as well.
        """
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
            # Nested layers may have been created with ``bias=False``.
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
        nn.init.normal_(self.position_embedding, mean=0.0, std=0.02)

    def forward(self, edge_tokens: torch.Tensor, global_features: torch.Tensor) -> torch.Tensor:
        """Encode a batch of corridor states.

        Args:
            edge_tokens: Expected as ``(batch, num_agents, edge_token_dim)``.
            global_features: Expected as ``(batch, global_feature_dim)``.

        Returns:
            The output of ``output_proj``, one ``hidden_dim``-sized row per
            sample.
        """
        x = self.edge_proj(edge_tokens)

        scale, shift = self.global_film(global_features).chunk(2, dim=-1)
        # tanh bounds the multiplicative factor to (0.9, 1.1); the shift is
        # damped by the same 0.1 factor. unsqueeze(1) broadcasts over agents.
        scale = 0.1 * torch.tanh(scale).unsqueeze(1)
        shift = 0.1 * shift.unsqueeze(1)
        x = x * (1.0 + scale) + shift

        x = x + self.position_embedding
        for block in self.blocks:
            x = block(x)

        # Pool over the agent axis.
        pooled_mean = x.mean(dim=1)
        pooled_max = x.max(dim=1).values
        summary = torch.cat([pooled_mean, pooled_max, global_features], dim=-1)
        return self.output_proj(summary)
|
|
|
|
|
|
class DirectionalQMixer(nn.Module):
    """QMIX mixer conditioned on a structured corridor state embedding.

    The flat state vector is split into per-edge tokens and global features,
    encoded by ``DirectionalMixerStateEncoder``, and the resulting embedding
    drives hyper-networks that generate the (non-negative) mixing weights.
    """

    def __init__(
        self,
        num_agents: int,
        edge_feature_dim: int,
        time_feature_dim: int,
        total_edge_count: int,
        controlled_start_index: int,
        mixing_hidden_dim: int = 256,
        state_hidden_dim: int = 256,
        num_corridor_blocks: int = 2,
        corridor_kernel_size: int = 5,
        corridor_dropout: float = 0.05,
    ):
        """Build the state encoder and the mixing hyper-networks.

        Args:
            num_agents: Number of agent Q-values that are mixed.
            edge_feature_dim: Features per edge in the edge block of the state.
            time_feature_dim: Number of time features in the global block.
            total_edge_count: Number of edges described by the flat state.
            controlled_start_index: Index of the first controlled edge; the
                ``num_agents`` edges starting there become the tokens.
            mixing_hidden_dim: Width of the mixing layer.
            state_hidden_dim: Width of the state embedding.
            num_corridor_blocks: Requested corridor blocks for the state
                encoder (at least one is always used).
            corridor_kernel_size: Kernel size for the corridor blocks.
            corridor_dropout: Dropout rate for the corridor blocks.
        """
        super().__init__()
        self.num_agents = num_agents
        self.edge_feature_dim = edge_feature_dim
        self.time_feature_dim = time_feature_dim
        self.speed_feature_dim = 1
        self.last_reward_dim = 1
        self.total_edge_count = total_edge_count
        self.controlled_start_index = controlled_start_index
        self.controlled_end_index = self.controlled_start_index + self.num_agents
        self.global_feature_dim = self.time_feature_dim + self.last_reward_dim
        self.edge_token_dim = self.edge_feature_dim + self.speed_feature_dim

        self.state_encoder = DirectionalMixerStateEncoder(
            num_agents=self.num_agents,
            edge_token_dim=self.edge_token_dim,
            global_feature_dim=self.global_feature_dim,
            hidden_dim=state_hidden_dim,
            num_blocks=max(1, num_corridor_blocks),
            kernel_size=corridor_kernel_size,
            dropout=corridor_dropout,
        )

        self.hyper_w1 = nn.Sequential(
            nn.Linear(state_hidden_dim, mixing_hidden_dim),
            nn.ReLU(),
            nn.Linear(mixing_hidden_dim, num_agents * mixing_hidden_dim),
        )
        self.hyper_b1 = nn.Linear(state_hidden_dim, mixing_hidden_dim)
        self.hyper_w2 = nn.Sequential(
            nn.Linear(state_hidden_dim, mixing_hidden_dim),
            nn.ReLU(),
            nn.Linear(mixing_hidden_dim, mixing_hidden_dim),
        )
        self.hyper_b2 = nn.Sequential(
            nn.Linear(state_hidden_dim, mixing_hidden_dim),
            nn.ReLU(),
            nn.Linear(mixing_hidden_dim, 1),
        )
        self.activation = nn.ELU()
        self._init_weights()

    def _init_weights(self):
        """Orthogonally initialise every ``nn.Linear`` in the mixer.

        ``self.modules()`` recurses, so the state encoder's linear layers are
        covered too. The final bias head is re-initialised with unit gain.
        """
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
            # Nested layers may have been created with ``bias=False``.
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
        nn.init.orthogonal_(self.hyper_b2[-1].weight, gain=1.0)

    def _extract_state_tokens(self, state_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split a flat state into controlled-edge tokens and global features.

        The state is read as three consecutive blocks: per-edge features
        (``total_edge_count * edge_feature_dim`` values), one speed limit per
        edge, then ``global_feature_dim`` global values. Trailing columns
        beyond these blocks are ignored.

        Args:
            state_tensor: ``(batch, state_dim)`` tensor, or a single 1-D
                state which is treated as a batch of one.

        Returns:
            ``(edge_tokens, global_features)`` where ``edge_tokens`` holds the
            features and speed limit of each controlled edge.

        Raises:
            ValueError: If the state has fewer columns than the three blocks
                require.
        """
        if state_tensor.dim() == 1:
            state_tensor = state_tensor.unsqueeze(0)

        batch_size = state_tensor.size(0)
        edge_block = self.total_edge_count * self.edge_feature_dim
        speed_block_end = edge_block + self.total_edge_count
        global_block_end = speed_block_end + self.global_feature_dim

        # Fail early with a readable message instead of a shape error deep
        # inside the encoder.
        if state_tensor.size(1) < global_block_end:
            raise ValueError(
                f"state has {state_tensor.size(1)} features, "
                f"expected at least {global_block_end}"
            )

        controlled = slice(self.controlled_start_index, self.controlled_end_index)

        edge_features = state_tensor[:, :edge_block].reshape(
            batch_size, self.total_edge_count, self.edge_feature_dim
        )[:, controlled, :]
        edge_speed_limits = state_tensor[:, edge_block:speed_block_end].reshape(
            batch_size, self.total_edge_count, 1
        )[:, controlled, :]
        global_features = state_tensor[:, speed_block_end:global_block_end]

        edge_tokens = torch.cat([edge_features, edge_speed_limits], dim=-1)
        return edge_tokens, global_features

    def forward(self, agent_qs: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        """Mix per-agent Q-values into a joint value.

        Args:
            agent_qs: ``(batch, num_agents)`` chosen-action Q-values, or a
                single 1-D vector which is treated as a batch of one (to
                match the 1-D handling of ``state``).
            state: Flat state accepted by ``_extract_state_tokens``.

        Returns:
            ``(batch, 1)`` tensor of mixed values.
        """
        if agent_qs.dim() == 1:
            agent_qs = agent_qs.unsqueeze(0)

        edge_tokens, global_features = self._extract_state_tokens(state)
        state_embed = self.state_encoder(edge_tokens, global_features)
        batch_size = agent_qs.size(0)

        # abs() keeps the mixing weights non-negative, so the mixed value is
        # monotonic in every agent's Q-value.
        w1 = torch.abs(self.hyper_w1(state_embed)).view(batch_size, self.num_agents, -1)
        b1 = self.hyper_b1(state_embed).view(batch_size, 1, -1)
        hidden = self.activation(torch.bmm(agent_qs.unsqueeze(1), w1) + b1)

        w2 = torch.abs(self.hyper_w2(state_embed)).view(batch_size, -1, 1)
        b2 = self.hyper_b2(state_embed).view(batch_size, 1, 1)
        return torch.bmm(hidden, w2).squeeze(1) + b2.squeeze(1)
|
|
|
|
|
|
class DCQMIXAgent(QMIXAgent):
    """Directional corridor variant of QMIX with Double-Q target action selection."""

    def __init__(
        self,
        state_dim: int,
        num_edges: int,
        num_actions_per_edge: int,
        hidden_dim: int = 256,
        mixing_hidden_dim: int = 256,
        learning_rate: float = 1e-3,
        gamma: float = 0.99,
        epsilon_start: float = 1.0,
        epsilon_end: float = 0.01,
        epsilon_decay: int = 200,
        buffer_size: int = 10000,
        batch_size: int = 64,
        target_update: int = 10,
        device: str = "cuda",
        edge_feature_dim: int = 3,
        time_feature_dim: int = 3,
        total_edge_count: int | None = None,
        controlled_start_index: int = 0,
        num_corridor_blocks: int = 2,
        corridor_kernel_size: int = 5,
        corridor_dropout: float = 0.05,
    ):
        """Create the agent.

        The three ``corridor_*`` / ``num_corridor_blocks`` arguments configure
        the corridor blocks of both the utility network and the mixer; every
        other argument is forwarded unchanged to ``QMIXAgent.__init__``.
        """
        # Stored before super().__init__() so the _build_* hooks below can
        # read them — presumably the base constructor calls those hooks;
        # confirm against QMIXAgent.
        self.num_corridor_blocks = int(num_corridor_blocks)
        self.corridor_kernel_size = int(corridor_kernel_size)
        self.corridor_dropout = float(corridor_dropout)
        super().__init__(
            state_dim=state_dim,
            num_edges=num_edges,
            num_actions_per_edge=num_actions_per_edge,
            hidden_dim=hidden_dim,
            mixing_hidden_dim=mixing_hidden_dim,
            learning_rate=learning_rate,
            gamma=gamma,
            epsilon_start=epsilon_start,
            epsilon_end=epsilon_end,
            epsilon_decay=epsilon_decay,
            buffer_size=buffer_size,
            batch_size=batch_size,
            target_update=target_update,
            device=device,
            edge_feature_dim=edge_feature_dim,
            time_feature_dim=time_feature_dim,
            total_edge_count=total_edge_count,
            controlled_start_index=controlled_start_index,
        )

    def _build_utility_network(self, hidden_dim: int) -> nn.Module:
        """Return the shared per-agent utility network."""
        return DirectionalUtilityNetwork(
            edge_token_dim=self.edge_feature_dim + self.speed_feature_dim,
            condition_dim=self.time_feature_dim + self.last_reward_dim + self.agent_id_dim,
            num_agents=self.num_agents,
            num_actions=self.num_actions_per_agent,
            hidden_dim=hidden_dim,
            num_blocks=self.num_corridor_blocks,
            kernel_size=self.corridor_kernel_size,
            dropout=self.corridor_dropout,
        )

    def _build_mixer(self, mixing_hidden_dim: int) -> nn.Module:
        """Return the corridor-aware mixer.

        The state embedding width is set equal to ``mixing_hidden_dim``.
        """
        return DirectionalQMixer(
            num_agents=self.num_agents,
            edge_feature_dim=self.edge_feature_dim,
            time_feature_dim=self.time_feature_dim,
            total_edge_count=self.total_edge_count,
            controlled_start_index=self.controlled_start_index,
            mixing_hidden_dim=mixing_hidden_dim,
            state_hidden_dim=mixing_hidden_dim,
            num_corridor_blocks=self.num_corridor_blocks,
            corridor_kernel_size=self.corridor_kernel_size,
            corridor_dropout=self.corridor_dropout,
        )

    @staticmethod
    def _shift_left(x: torch.Tensor) -> torch.Tensor:
        """Return, for each position ``i`` along dim 1, the token at ``i - 1``.

        The first position has no predecessor and keeps its own token.
        """
        return torch.cat([x[:, :1, :], x[:, :-1, :]], dim=1)

    @staticmethod
    def _shift_right(x: torch.Tensor) -> torch.Tensor:
        """Return, for each position ``i`` along dim 1, the token at ``i + 1``.

        The last position has no successor and keeps its own token.
        """
        return torch.cat([x[:, 1:, :], x[:, -1:, :]], dim=1)

    def _parse_state(self, state_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split a flat state into controlled-edge tokens and agent conditions.

        The state is read as three consecutive blocks: per-edge features,
        one speed limit per edge, then the time features and last reward.

        Args:
            state_tensor: ``(batch, state_dim)`` tensor, or a single 1-D
                state which is treated as a batch of one.

        Returns:
            ``(edge_tokens, condition)``: the features plus speed limit of
            every controlled edge, and for every agent the global features
            concatenated with that agent's id features.

        Raises:
            ValueError: If the state has fewer columns than the three blocks
                require.
        """
        if state_tensor.dim() == 1:
            state_tensor = state_tensor.unsqueeze(0)

        batch_size = state_tensor.size(0)
        edge_block = self.total_edge_count * self.edge_feature_dim
        speed_block_end = edge_block + self.total_edge_count
        global_block_end = speed_block_end + self.time_feature_dim + self.last_reward_dim

        # Fail early with a readable message instead of a shape error deep
        # inside the utility network.
        if state_tensor.size(1) < global_block_end:
            raise ValueError(
                f"state has {state_tensor.size(1)} features, "
                f"expected at least {global_block_end}"
            )

        controlled = slice(self.controlled_start_index, self.controlled_end_index)

        edge_features = state_tensor[:, :edge_block].reshape(
            batch_size, self.total_edge_count, self.edge_feature_dim
        )[:, controlled, :]
        local_speed_limits = state_tensor[:, edge_block:speed_block_end].reshape(
            batch_size, self.total_edge_count, 1
        )[:, controlled, :]
        edge_tokens = torch.cat([edge_features, local_speed_limits], dim=-1)

        global_features = state_tensor[:, speed_block_end:global_block_end]
        agent_ids = self.agent_id_features.expand(batch_size, -1, -1)
        # Every agent sees the same global features plus its own id features.
        condition = torch.cat(
            [global_features.unsqueeze(1).expand(-1, self.num_agents, -1), agent_ids],
            dim=-1,
        )
        return edge_tokens, condition

    def _compute_agent_q_values_with_net(
        self,
        state_tensor: torch.Tensor,
        utility_net: nn.Module,
    ) -> torch.Tensor:
        """Evaluate ``utility_net`` on the directional components of a state.

        For every controlled edge five components are stacked along dim 2:
        its own token, the left and right neighbour tokens, and the
        differences to each neighbour.
        """
        edge_tokens, condition = self._parse_state(state_tensor)
        left_tokens = self._shift_left(edge_tokens)
        right_tokens = self._shift_right(edge_tokens)
        delta_left = edge_tokens - left_tokens
        delta_right = right_tokens - edge_tokens
        components = torch.stack(
            [edge_tokens, left_tokens, right_tokens, delta_left, delta_right],
            dim=2,
        )
        return utility_net(components, condition)

    def _compute_target_next_agent_q(self, next_states_t: torch.Tensor) -> torch.Tensor:
        """Return Double-Q next-state values per agent.

        Actions are chosen greedily from ``_compute_agent_q_values`` and then
        evaluated with the target utility network.
        """
        # Only the argmax indices are used, so no autograd graph is needed
        # for the selection pass.
        with torch.no_grad():
            greedy_actions = self._compute_agent_q_values(next_states_t).argmax(dim=2, keepdim=True)
        target_next_q_all = self._compute_agent_q_values_with_net(next_states_t, self.target_utility_net)
        return target_next_q_all.gather(2, greedy_actions).squeeze(2)
|