ctm-dqn/agents/dcmappo_agent.py

555 lines
22 KiB
Python

"""
Directional Corridor MAPPO for variable speed limit (VSL) control on an ordered motorway corridor.
This revision borrows three ideas common in recent sequence/control
architectures from other domains:
- Variable-selection-style gating to suppress noisy neighbor features
- Conditional (FiLM-like) modulation so global state can reshape local features
- Dual-path corridor mixing that keeps both local smoothing and directional flow
"""
import contextlib
from typing import Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
class SelectiveFeatureFusion(nn.Module):
    """Fuse self/neighbor/delta components with content-aware gates.

    Every component is projected to ``hidden_dim`` by a shared MLP. A softmax
    over the component axis, whose logits come from both the projected content
    and the conditioning vector, decides how much each component contributes.
    The conditioning vector then applies a bounded FiLM-style scale/shift to
    the fused embedding.

    Args:
        component_dim: feature size of each component.
        condition_dim: feature size of the conditioning vector.
        hidden_dim: width of the fused embedding.
        num_components: number of stacked components scored per sample.
    """

    def __init__(self, component_dim: int, condition_dim: int, hidden_dim: int, num_components: int):
        super().__init__()
        # Shared projection applied independently to every component.
        self.component_proj = nn.Sequential(
            nn.Linear(component_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
        )
        # Content score: a single logit per projected component.
        self.component_score = nn.Linear(hidden_dim, 1)
        # Condition score: one logit per component slot, driven by the condition only.
        self.condition_score = nn.Sequential(
            nn.Linear(condition_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, num_components),
        )
        # Produces (scale, shift) concatenated along the last axis.
        self.film = nn.Sequential(
            nn.Linear(condition_dim, hidden_dim * 2),
            nn.GELU(),
            nn.Linear(hidden_dim * 2, hidden_dim * 2),
        )
        self._init_weights()

    def _init_weights(self):
        """Give every Linear layer orthogonal weights (gain sqrt(2)) and zero bias."""
        linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
        for layer in linear_layers:
            nn.init.orthogonal_(layer.weight, gain=np.sqrt(2))
            nn.init.constant_(layer.bias, 0)

    def forward(self, components: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        """Fuse ``components`` under ``condition``.

        Args:
            components: ``(N, num_components, component_dim)``.
            condition: ``(N, condition_dim)``.

        Returns:
            ``(N, hidden_dim)`` gated, FiLM-modulated embedding.
        """
        projected = self.component_proj(components)
        logits = self.component_score(projected).squeeze(-1) + self.condition_score(condition)
        mix = torch.softmax(logits, dim=1).unsqueeze(-1)
        pooled = (mix * projected).sum(dim=1)

        film_scale, film_shift = self.film(condition).chunk(2, dim=-1)
        # Keep the modulation small so the condition nudges, rather than overrides, the fused features.
        film_scale = 0.1 * torch.tanh(film_scale)
        film_shift = 0.1 * film_shift
        return pooled * (1.0 + film_scale) + film_shift
class DualPathCorridorBlock(nn.Module):
    """Combine local convolutional mixing and directional recurrent propagation.

    Three pre-normalized residual sub-layers are applied in order: a gated
    depthwise/pointwise convolution along the corridor, a gated bidirectional
    GRU, and a feed-forward network. Input and output are both
    ``(batch, num_edges, hidden_dim)``.

    Args:
        hidden_dim: token width. Must be even, because the bidirectional GRU
            uses ``hidden_dim // 2`` units per direction and its concatenated
            output is added back to the residual stream.
        kernel_size: width of the depthwise convolution along the corridor.
            Any positive size keeps the corridor length unchanged.
        dropout: dropout rate on each residual branch and inside the FFN.

    Raises:
        ValueError: if ``hidden_dim`` is not a positive even integer or
            ``kernel_size`` is not positive.
    """

    def __init__(self, hidden_dim: int, kernel_size: int = 5, dropout: float = 0.05):
        super().__init__()
        if hidden_dim <= 0 or hidden_dim % 2 != 0:
            raise ValueError(f"hidden_dim must be a positive even integer, got {hidden_dim}")
        if kernel_size < 1:
            raise ValueError(f"kernel_size must be positive, got {kernel_size}")
        self.local_norm = nn.LayerNorm(hidden_dim)
        # padding="same" keeps the corridor length for odd AND even kernels. A fixed
        # kernel_size // 2 padding yields length L + 1 for even kernels, which breaks
        # the residual combination in forward(). For odd kernels it is identical.
        self.depthwise = nn.Conv1d(
            hidden_dim,
            hidden_dim,
            kernel_size=kernel_size,
            padding="same",
            groups=hidden_dim,
        )
        self.pointwise = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1)
        # Gate is computed from the residual stream and the branch output together.
        self.local_gate = nn.Linear(hidden_dim * 2, hidden_dim)
        self.seq_norm = nn.LayerNorm(hidden_dim)
        # Two directions of hidden_dim // 2 units concatenate back to hidden_dim.
        self.seq_mixer = nn.GRU(
            input_size=hidden_dim,
            hidden_size=hidden_dim // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.seq_gate = nn.Linear(hidden_dim * 2, hidden_dim)
        self.ffn_norm = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim),
        )
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()
        self._init_weights()

    def _init_weights(self):
        """Apply orthogonal init to Linear and GRU weights, Kaiming-uniform to convs, and zero all biases."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Conv1d):
                nn.init.kaiming_uniform_(module.weight, a=np.sqrt(5))
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        for name, param in self.seq_mixer.named_parameters():
            if "weight" in name:
                nn.init.orthogonal_(param)
            elif "bias" in name:
                nn.init.constant_(param, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mix ``x`` of shape ``(batch, num_edges, hidden_dim)``; the output has the same shape."""
        # 1) Local mixing: Conv1d wants (batch, channels, length), so transpose around it.
        residual = x
        local = self.local_norm(x).transpose(1, 2)
        local = self.activation(self.pointwise(self.depthwise(local))).transpose(1, 2)
        local_gate = torch.sigmoid(self.local_gate(torch.cat([residual, local], dim=-1)))
        x = residual + self.dropout(local_gate * local)
        # 2) Directional propagation along the corridor with a bidirectional GRU.
        residual = x
        seq_input = self.seq_norm(x)
        seq_out, _ = self.seq_mixer(seq_input)
        seq_gate = torch.sigmoid(self.seq_gate(torch.cat([residual, seq_out], dim=-1)))
        x = residual + self.dropout(seq_gate * seq_out)
        # 3) Position-wise feed-forward.
        residual = x
        x = residual + self.dropout(self.ffn(self.ffn_norm(x)))
        return x
class CorridorActor(nn.Module):
    """Per-edge policy network over an ordered corridor of controlled edges.

    Each edge's stacked components are fused by ``SelectiveFeatureFusion``
    under that edge's condition vector. A learned position embedding is added,
    the corridor sequence is mixed by ``num_blocks`` ``DualPathCorridorBlock``
    layers, and a linear head emits one row of action logits per edge.

    Args:
        edge_token_dim: feature size of each component token.
        condition_dim: feature size of each edge's condition vector.
        num_agents: number of controlled edges (fixes the position-embedding length).
        num_actions: number of discrete actions per edge.
        hidden_dim: width of the corridor representation.
        num_blocks: number of ``DualPathCorridorBlock`` layers.
        kernel_size: convolution width passed to each block.
        dropout: dropout rate passed to each block.
    """

    def __init__(
        self,
        edge_token_dim: int,
        condition_dim: int,
        num_agents: int,
        num_actions: int,
        hidden_dim: int = 256,
        num_blocks: int = 2,
        kernel_size: int = 5,
        dropout: float = 0.05,
    ):
        super().__init__()
        # NOTE(review): 5 must match the number of components the caller stacks per edge.
        self.fusion = SelectiveFeatureFusion(
            component_dim=edge_token_dim,
            condition_dim=condition_dim,
            hidden_dim=hidden_dim,
            num_components=5,
        )
        self.position_embedding = nn.Parameter(torch.zeros(1, num_agents, hidden_dim))
        self.blocks = nn.ModuleList(
            [DualPathCorridorBlock(hidden_dim, kernel_size=kernel_size, dropout=dropout) for _ in range(num_blocks)]
        )
        self.head_norm = nn.LayerNorm(hidden_dim)
        self.head = nn.Linear(hidden_dim, num_actions)
        self._init_weights()

    def _init_weights(self):
        """Initialize the position embedding with small noise and the head with near-zero weights.

        The 0.01 gain and zero bias keep the initial logits small, so the
        initial policy is close to uniform.
        """
        nn.init.normal_(self.position_embedding, mean=0.0, std=0.02)
        nn.init.orthogonal_(self.head.weight, gain=0.01)
        nn.init.constant_(self.head.bias, 0)

    def forward(self, components: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        """Return action logits of shape ``(batch, num_agents, num_actions)``.

        Args:
            components: ``(batch, num_agents, num_components, edge_token_dim)``.
            condition: ``(batch, num_agents, condition_dim)``.
        """
        batch_size, num_agents, num_components, token_dim = components.shape
        # reshape (not view) so non-contiguous inputs such as permuted or sliced
        # tensors are accepted. For contiguous inputs it is the same zero-copy view.
        fused = self.fusion(
            components.reshape(batch_size * num_agents, num_components, token_dim),
            condition.reshape(batch_size * num_agents, condition.size(-1)),
        )
        x = fused.reshape(batch_size, num_agents, -1) + self.position_embedding
        for block in self.blocks:
            x = block(x)
        return self.head(self.head_norm(x))
class StructuredCorridorCritic(nn.Module):
    """Structured critic to reduce value underfitting on ordered corridor states.

    Edge tokens are projected, then modulated by a bounded FiLM transform of
    the global features. A learned position embedding is added and the tokens
    are mixed by ``DualPathCorridorBlock`` layers. The mean- and max-pooled
    corridor summaries, concatenated with the raw global features, are mapped
    to one scalar value per sample.

    Args:
        num_agents: number of corridor edges (fixes the position-embedding length).
        edge_feature_dim: per-edge feature size before the speed entry is appended.
        time_feature_dim: number of time features among the global features.
        hidden_dim: width of the corridor representation.
        num_blocks: number of ``DualPathCorridorBlock`` layers.
        kernel_size: convolution width passed to each block.
        dropout: dropout rate passed to each block.
    """

    def __init__(
        self,
        num_agents: int,
        edge_feature_dim: int,
        time_feature_dim: int,
        hidden_dim: int = 256,
        num_blocks: int = 2,
        kernel_size: int = 5,
        dropout: float = 0.05,
    ):
        super().__init__()
        self.num_agents = num_agents
        self.edge_feature_dim = edge_feature_dim
        self.time_feature_dim = time_feature_dim
        # One speed entry is appended to every edge token; one last-reward entry to the globals.
        self.speed_feature_dim = 1
        self.last_reward_dim = 1
        self.global_feature_dim = time_feature_dim + self.last_reward_dim
        self.edge_token_dim = edge_feature_dim + self.speed_feature_dim
        token_dim = self.edge_token_dim
        global_dim = self.global_feature_dim

        self.edge_proj = nn.Sequential(
            nn.Linear(token_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
        )
        # Produces (scale, shift) concatenated along the last axis.
        self.global_film = nn.Sequential(
            nn.Linear(global_dim, hidden_dim * 2),
            nn.GELU(),
            nn.Linear(hidden_dim * 2, hidden_dim * 2),
        )
        self.position_embedding = nn.Parameter(torch.zeros(1, num_agents, hidden_dim))
        self.blocks = nn.ModuleList(
            DualPathCorridorBlock(hidden_dim, kernel_size=kernel_size, dropout=dropout)
            for _ in range(num_blocks)
        )
        # Summary = mean-pool (hidden_dim) + max-pool (hidden_dim) + raw globals.
        summary_dim = hidden_dim * 2 + global_dim
        self.head = nn.Sequential(
            nn.Linear(summary_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
        )
        self._init_weights()

    def _init_weights(self):
        """Orthogonal-init the value head (unit gain on its output layer) and the position embedding."""
        head_linears = [m for m in self.head.modules() if isinstance(m, nn.Linear)]
        for layer in head_linears:
            nn.init.orthogonal_(layer.weight, gain=np.sqrt(2))
            nn.init.constant_(layer.bias, 0)
        # Re-draw the scalar output layer with gain 1.0 instead of sqrt(2).
        nn.init.orthogonal_(head_linears[-1].weight, gain=1.0)
        nn.init.normal_(self.position_embedding, mean=0.0, std=0.02)

    def forward(self, edge_tokens: torch.Tensor, global_features: torch.Tensor) -> torch.Tensor:
        """Return values of shape ``(batch, 1)``.

        Args:
            edge_tokens: ``(batch, num_agents, edge_token_dim)``.
            global_features: ``(batch, global_feature_dim)``.
        """
        film_scale, film_shift = self.global_film(global_features).chunk(2, dim=-1)
        # Bounded modulation, broadcast over the corridor axis.
        film_scale = 0.1 * torch.tanh(film_scale).unsqueeze(1)
        film_shift = 0.1 * film_shift.unsqueeze(1)
        tokens = self.edge_proj(edge_tokens) * (1.0 + film_scale) + film_shift
        tokens = tokens + self.position_embedding
        for block in self.blocks:
            tokens = block(tokens)
        summary = torch.cat(
            [tokens.mean(dim=1), tokens.max(dim=1).values, global_features],
            dim=-1,
        )
        return self.head(summary)
class DCMAPPOAgent:
"""Directional Corridor MAPPO with gated actor and structured critic."""
def __init__(
self,
state_dim: int,
num_agents: int,
num_actions: int,
edge_feature_dim: int = 3,
time_feature_dim: int = 3,
total_edge_count: int | None = None,
controlled_start_index: int = 0,
hidden_dim: int = 256,
critic_hidden_dim: int = 256,
num_corridor_blocks: int = 2,
corridor_kernel_size: int = 5,
corridor_dropout: float = 0.05,
learning_rate: float = 3e-4,
gamma: float = 0.99,
gae_lambda: float = 0.95,
clip_epsilon: float = 0.2,
value_coef: float = 0.5,
entropy_coef: float = 0.01,
max_grad_norm: float = 0.5,
ppo_epochs: int = 4,
minibatch_size: int = 15,
device: str = "cuda",
lr_schedule: str = "cosine",
total_episodes: int = 300,
):
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
self.state_dim = state_dim
self.num_agents = num_agents
self.num_actions = num_actions
self.edge_feature_dim = edge_feature_dim
self.time_feature_dim = time_feature_dim
self.total_edge_count = int(total_edge_count if total_edge_count is not None else num_agents)
self.controlled_start_index = int(controlled_start_index)
self.controlled_end_index = self.controlled_start_index + self.num_agents
if self.controlled_end_index > self.total_edge_count:
raise ValueError("controlled action slice exceeds total edge count")
self.gamma = gamma
self.gae_lambda = gae_lambda
self.clip_epsilon = clip_epsilon
self.value_coef = value_coef
self.entropy_coef = entropy_coef
self.max_grad_norm = max_grad_norm
self.ppo_epochs = ppo_epochs
self.minibatch_size = minibatch_size
self.speed_feature_dim = 1
self.last_reward_dim = 1
self.global_feature_dim = self.time_feature_dim + self.last_reward_dim
self.agent_id_dim = 1
self.edge_token_dim = self.edge_feature_dim + self.speed_feature_dim
self.condition_dim = self.global_feature_dim + self.agent_id_dim
self.actor = CorridorActor(
edge_token_dim=self.edge_token_dim,
condition_dim=self.condition_dim,
num_agents=num_agents,
num_actions=num_actions,
hidden_dim=hidden_dim,
num_blocks=num_corridor_blocks,
kernel_size=corridor_kernel_size,
dropout=corridor_dropout,
).to(self.device)
self.critic = StructuredCorridorCritic(
num_agents=num_agents,
edge_feature_dim=edge_feature_dim,
time_feature_dim=time_feature_dim,
hidden_dim=critic_hidden_dim,
num_blocks=max(1, num_corridor_blocks),
kernel_size=corridor_kernel_size,
dropout=corridor_dropout,
).to(self.device)
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=learning_rate, eps=1e-5)
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=learning_rate, eps=1e-5)
if lr_schedule == "cosine":
self.actor_scheduler = optim.lr_scheduler.CosineAnnealingLR(
self.actor_optimizer,
T_max=total_episodes,
eta_min=learning_rate * 0.1,
)
self.critic_scheduler = optim.lr_scheduler.CosineAnnealingLR(
self.critic_optimizer,
T_max=total_episodes,
eta_min=learning_rate * 0.1,
)
else:
self.actor_scheduler = None
self.critic_scheduler = None
agent_ids = np.linspace(0.0, 1.0, num_agents, dtype=np.float32)
self.agent_id_features = torch.tensor(agent_ids, device=self.device).view(1, num_agents, 1)
self.reset_buffers()
def reset_buffers(self):
self.states = []
self.actions = []
self.rewards = []
self.values = []
self.log_probs = []
self.dones = []
@staticmethod
def _shift_left(x: torch.Tensor) -> torch.Tensor:
return torch.cat([x[:, :1, :], x[:, :-1, :]], dim=1)
@staticmethod
def _shift_right(x: torch.Tensor) -> torch.Tensor:
return torch.cat([x[:, 1:, :], x[:, -1:, :]], dim=1)
def _parse_state(self, state_tensor: torch.Tensor):
if state_tensor.dim() == 1:
state_tensor = state_tensor.unsqueeze(0)
batch_size = state_tensor.size(0)
edge_block = self.total_edge_count * self.edge_feature_dim
speed_block_start = edge_block
speed_block_end = speed_block_start + self.total_edge_count
global_block_start = speed_block_end
global_block_end = global_block_start + self.global_feature_dim
edge_features = state_tensor[:, :edge_block].view(batch_size, self.total_edge_count, self.edge_feature_dim)
edge_features = edge_features[:, self.controlled_start_index:self.controlled_end_index, :]
local_speed_limits = state_tensor[:, speed_block_start:speed_block_end].view(batch_size, self.total_edge_count, 1)
local_speed_limits = local_speed_limits[:, self.controlled_start_index:self.controlled_end_index, :]
edge_tokens = torch.cat([edge_features, local_speed_limits], dim=-1)
global_features = state_tensor[:, global_block_start:global_block_end]
agent_ids = self.agent_id_features.expand(batch_size, -1, -1)
condition = torch.cat(
[global_features.unsqueeze(1).expand(-1, self.num_agents, -1), agent_ids],
dim=-1,
)
return edge_tokens, global_features, condition
def _build_actor_inputs(self, state_tensor: torch.Tensor):
edge_tokens, global_features, condition = self._parse_state(state_tensor)
left_tokens = self._shift_left(edge_tokens)
right_tokens = self._shift_right(edge_tokens)
delta_left = edge_tokens - left_tokens
delta_right = right_tokens - edge_tokens
components = torch.stack(
[edge_tokens, left_tokens, right_tokens, delta_left, delta_right],
dim=2,
)
return components, condition, edge_tokens, global_features
def select_action(
self,
state: np.ndarray,
deterministic: bool = False,
) -> Tuple[np.ndarray, np.ndarray, float]:
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
components, condition, edge_tokens, global_features = self._build_actor_inputs(state_tensor)
with torch.no_grad():
logits = self.actor(components, condition)
value = self.critic(edge_tokens, global_features)
actions = []
log_probs = []
for agent_idx in range(self.num_agents):
dist = torch.distributions.Categorical(logits=logits[0, agent_idx])
if deterministic:
action = torch.argmax(logits[0, agent_idx], dim=-1).item()
else:
action = dist.sample().item()
actions.append(action)
log_probs.append(dist.log_prob(torch.tensor(action, device=self.device)).item())
return np.array(actions, dtype=np.int64), np.array(log_probs, dtype=np.float32), value.item()
def get_value(self, state: np.ndarray) -> float:
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
_, _, edge_tokens, global_features = self._build_actor_inputs(state_tensor)
with torch.no_grad():
value = self.critic(edge_tokens, global_features)
return value.item()
def store_transition(self, state, action, reward, value, log_prob, done):
self.states.append(state)
self.actions.append(action)
self.rewards.append(reward)
self.values.append(value)
self.log_probs.append(log_prob)
self.dones.append(done)
def compute_gae(self, next_value: float):
advantages = []
gae = 0.0
for t in reversed(range(len(self.rewards))):
next_val = next_value if t == len(self.rewards) - 1 else self.values[t + 1]
delta = self.rewards[t] + self.gamma * next_val * (1 - self.dones[t]) - self.values[t]
gae = delta + self.gamma * self.gae_lambda * (1 - self.dones[t]) * gae
advantages.insert(0, gae)
advantages = np.array(advantages, dtype=np.float32)
returns = advantages + np.array(self.values, dtype=np.float32)
return advantages, returns
def update(self, next_value: float) -> Dict[str, float]:
if len(self.states) == 0:
return {}
advantages, returns = self.compute_gae(next_value)
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
states = torch.FloatTensor(np.array(self.states)).to(self.device)
actions = torch.LongTensor(np.array(self.actions)).to(self.device)
old_log_probs = torch.FloatTensor(np.array(self.log_probs)).to(self.device)
advantages_t = torch.FloatTensor(advantages).to(self.device)
returns_t = torch.FloatTensor(returns).to(self.device)
total_policy_loss = 0.0
total_value_loss = 0.0
total_entropy = 0.0
update_count = 0
dataset_size = len(self.states)
for _ in range(self.ppo_epochs):
indices = np.random.permutation(dataset_size)
for start_idx in range(0, dataset_size, self.minibatch_size):
end_idx = min(start_idx + self.minibatch_size, dataset_size)
batch_idx = indices[start_idx:end_idx]
batch_states = states[batch_idx]
batch_actions = actions[batch_idx]
batch_old_lp = old_log_probs[batch_idx]
batch_adv = advantages_t[batch_idx]
batch_ret = returns_t[batch_idx]
components, condition, edge_tokens, global_features = self._build_actor_inputs(batch_states)
logits = self.actor(components, condition)
dist = torch.distributions.Categorical(logits=logits)
new_log_probs = dist.log_prob(batch_actions)
entropy = dist.entropy().mean()
expanded_adv = batch_adv.unsqueeze(1).expand(-1, self.num_agents)
ratio = torch.exp(new_log_probs - batch_old_lp)
surr1 = ratio * expanded_adv
surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * expanded_adv
policy_loss = -torch.min(surr1, surr2).mean()
values = self.critic(edge_tokens, global_features).squeeze(-1)
value_loss = nn.functional.mse_loss(values, batch_ret)
actor_loss = policy_loss - self.entropy_coef * entropy
critic_loss = self.value_coef * value_loss
self.actor_optimizer.zero_grad()
actor_loss.backward()
nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
self.actor_optimizer.step()
self.critic_optimizer.zero_grad()
critic_loss.backward()
nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
self.critic_optimizer.step()
total_policy_loss += policy_loss.item()
total_value_loss += value_loss.item()
total_entropy += entropy.item()
update_count += 1
if self.actor_scheduler is not None:
self.actor_scheduler.step()
if self.critic_scheduler is not None:
self.critic_scheduler.step()
self.reset_buffers()
return {
"policy_loss": total_policy_loss / max(update_count, 1),
"value_loss": total_value_loss / max(update_count, 1),
"entropy": total_entropy / max(update_count, 1),
}
def save(self, path: str):
torch.save(
{
"actor_state_dict": self.actor.state_dict(),
"critic_state_dict": self.critic.state_dict(),
"actor_optimizer_state_dict": self.actor_optimizer.state_dict(),
"critic_optimizer_state_dict": self.critic_optimizer.state_dict(),
},
path,
)
def load(self, path: str):
checkpoint = torch.load(path, map_location=self.device, weights_only=False)
self.actor.load_state_dict(checkpoint["actor_state_dict"])
self.critic.load_state_dict(checkpoint["critic_state_dict"])
actor_optim_state = checkpoint.get("actor_optimizer_state_dict")
critic_optim_state = checkpoint.get("critic_optimizer_state_dict")
legacy_optim_state = checkpoint.get("optimizer_state_dict")
if actor_optim_state is not None:
self.actor_optimizer.load_state_dict(actor_optim_state)
elif legacy_optim_state is not None:
self.actor_optimizer.load_state_dict(legacy_optim_state)
if critic_optim_state is not None:
self.critic_optimizer.load_state_dict(critic_optim_state)
elif legacy_optim_state is not None:
self.critic_optimizer.load_state_dict(legacy_optim_state)