ctm-dqn/agents/tcamappo_agent.py

412 lines
16 KiB
Python

"""
Temporal Credit Assignment MAPPO for SUMO VSL.
- Actor: same decentralized shared actor style as MAPPO
- Critic: current-state query attends over recent decision/outcome history
"""
from collections import deque
from typing import Deque, Dict, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
class SharedActor(nn.Module):
    """Policy network whose weights are shared by every agent.

    Maps a local observation vector (last axis ``local_obs_dim``) to
    unnormalized action logits (last axis ``num_actions``).
    """

    def __init__(self, local_obs_dim: int, num_actions: int, hidden_dim: int = 256):
        super().__init__()
        layers = []
        in_features = local_obs_dim
        # Two hidden Linear -> LayerNorm -> ReLU stages, then a linear logit head.
        for _ in range(2):
            layers += [nn.Linear(in_features, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU()]
            in_features = hidden_dim
        layers.append(nn.Linear(hidden_dim, num_actions))
        self.net = nn.Sequential(*layers)
        self._init_weights()

    def _init_weights(self):
        """Orthogonal init (gain sqrt(2)) with zero bias for every Linear layer."""
        hidden_gain = np.sqrt(2)
        linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
        for layer in linear_layers:
            nn.init.orthogonal_(layer.weight, gain=hidden_gain)
            nn.init.constant_(layer.bias, 0)
        # A tiny output gain keeps the initial logits near zero (close to a uniform policy).
        nn.init.orthogonal_(self.net[-1].weight, gain=0.01)

    def forward(self, local_obs: torch.Tensor) -> torch.Tensor:
        """Return action logits for ``local_obs``."""
        logits = self.net(local_obs)
        return logits
class HistoryAttentionBlock(nn.Module):
    """Post-norm transformer block in which a query attends over history tokens.

    Cross-attention (query -> history) and a 2x-wide GELU feed-forward layer,
    each wrapped in a residual connection followed by LayerNorm.
    """

    def __init__(self, hidden_dim: int, num_heads: int = 4, dropout: float = 0.05):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            hidden_dim,
            num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        expanded_dim = 2 * hidden_dim
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, expanded_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(expanded_dim, hidden_dim),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, query: torch.Tensor, history: torch.Tensor) -> torch.Tensor:
        """Refine ``query`` (batch, q_len, hidden) using ``history`` (batch, seq, hidden)."""
        context = self.cross_attn(query, history, history, need_weights=False)[0]
        attended = self.norm1(query + self.dropout(context))
        refined = self.ffn(attended)
        return self.norm2(attended + self.dropout(refined))
class TemporalCreditCritic(nn.Module):
    """Centralized value network conditioned on a window of history tokens.

    The encoded current state acts as a single query that is refined by a
    stack of ``HistoryAttentionBlock`` layers over the encoded history (plus
    a learned per-position embedding). The refined query, the mean-pooled
    history and the max-pooled history are concatenated and mapped to a
    scalar value.
    """

    def __init__(
        self,
        state_dim: int,
        history_token_dim: int,
        history_len: int,
        hidden_dim: int = 256,
        num_heads: int = 4,
        num_layers: int = 2,
        dropout: float = 0.05,
    ):
        super().__init__()
        self.history_len = history_len

        def encoder(in_dim: int) -> nn.Sequential:
            # Linear -> LayerNorm -> GELU projection into the shared hidden space.
            return nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU())

        self.state_encoder = encoder(state_dim)
        self.history_encoder = encoder(history_token_dim)
        # One learned vector per history position (index 0 = first slot of the window).
        self.age_embedding = nn.Parameter(torch.zeros(1, history_len, hidden_dim))
        self.blocks = nn.ModuleList(
            HistoryAttentionBlock(hidden_dim, num_heads=num_heads, dropout=dropout)
            for _ in range(num_layers)
        )
        half_dim = hidden_dim // 2
        self.head = nn.Sequential(
            nn.Linear(3 * hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, half_dim),
            nn.ReLU(),
            nn.Linear(half_dim, 1),
        )
        self._init_weights()

    def _init_weights(self):
        """Orthogonal Linear init, unit-gain value output, small random position embedding."""
        for layer in filter(lambda m: isinstance(m, nn.Linear), self.modules()):
            nn.init.orthogonal_(layer.weight, gain=np.sqrt(2))
            nn.init.constant_(layer.bias, 0)
        nn.init.orthogonal_(self.head[-1].weight, gain=1.0)
        nn.init.normal_(self.age_embedding, mean=0.0, std=0.02)

    def forward(self, current_state: torch.Tensor, history_tokens: torch.Tensor) -> torch.Tensor:
        """Estimate the value of ``current_state``.

        Args:
            current_state: ``(batch, state_dim)``.
            history_tokens: ``(batch, seq, history_token_dim)`` with
                ``seq <= history_len``.

        Returns:
            Value estimates of shape ``(batch, 1)``.
        """
        seq_len = history_tokens.size(1)
        memory = self.history_encoder(history_tokens) + self.age_embedding[:, :seq_len, :]
        summary = self.state_encoder(current_state).unsqueeze(1)
        for block in self.blocks:
            summary = block(summary, memory)
        features = torch.cat(
            [summary.squeeze(1), memory.mean(dim=1), memory.max(dim=1).values],
            dim=-1,
        )
        return self.head(features)
class TCAMAPPOAgent:
    """MAPPO with a temporal credit-assignment critic.

    A single ``SharedActor`` maps each agent's local observation (sliced out of
    the global state by ``_build_local_obs``) to action logits. A
    ``TemporalCreditCritic`` predicts one scalar value from the global state
    together with a fixed-length window of history tokens
    (``state | normalized joint action | 5 reward features``).

    ``_build_local_obs`` reads the leading fields of the global state as::

        [edge features (num_agents * edge_feature_dim)
         | per-agent speed limits (num_agents)
         | time features (time_feature_dim)
         | last reward (1)]
    """

    def __init__(
        self,
        state_dim: int,
        num_agents: int,
        num_actions: int,
        edge_feature_dim: int = 3,
        time_feature_dim: int = 3,
        hidden_dim: int = 256,
        critic_hidden_dim: int = 256,
        history_window: int = 6,
        critic_num_heads: int = 4,
        critic_num_layers: int = 2,
        critic_dropout: float = 0.05,
        learning_rate: float = 3e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_epsilon: float = 0.2,
        value_coef: float = 0.5,
        entropy_coef: float = 0.01,
        max_grad_norm: float = 0.5,
        ppo_epochs: int = 4,
        minibatch_size: int = 15,
        device: str = "cuda",
        lr_schedule: str = "cosine",
        total_episodes: int = 300,
    ):
        # Falls back to CPU whenever CUDA is unavailable, regardless of ``device``.
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.state_dim = state_dim
        self.num_agents = num_agents
        self.num_actions = num_actions
        self.edge_feature_dim = edge_feature_dim
        self.time_feature_dim = time_feature_dim
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        self.ppo_epochs = ppo_epochs
        self.minibatch_size = minibatch_size
        self.history_window = history_window
        self.speed_feature_dim = 1
        self.last_reward_dim = 1
        self.agent_id_dim = 1
        self.local_obs_dim = (
            edge_feature_dim
            + self.speed_feature_dim
            + time_feature_dim
            + self.last_reward_dim
            + self.agent_id_dim
        )
        # History token = state | one normalized action per agent | 5 reward features
        # (see _build_history_token).
        self.history_token_dim = state_dim + num_agents + 5
        self.actor = SharedActor(self.local_obs_dim, num_actions, hidden_dim).to(self.device)
        self.critic = TemporalCreditCritic(
            state_dim=state_dim,
            history_token_dim=self.history_token_dim,
            history_len=history_window,
            hidden_dim=critic_hidden_dim,
            num_heads=critic_num_heads,
            num_layers=critic_num_layers,
            dropout=critic_dropout,
        ).to(self.device)
        # One optimizer over actor and critic: they are trained with a single combined loss.
        self.optimizer = optim.Adam(
            list(self.actor.parameters()) + list(self.critic.parameters()),
            lr=learning_rate,
            eps=1e-5,
        )
        if lr_schedule == "cosine":
            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=total_episodes,
                eta_min=learning_rate * 0.1,
            )
        else:
            self.scheduler = None
        # Scalar agent identifiers evenly spaced in [0, 1], appended to each local observation.
        agent_ids = np.linspace(0.0, 1.0, num_agents, dtype=np.float32)
        self.agent_id_features = torch.tensor(agent_ids, device=self.device).view(1, num_agents, 1)
        self.reset_buffers()
        self.reset_episode()

    def reset_buffers(self):
        """Clear the on-policy rollout buffer."""
        self.states = []
        self.history_tokens = []
        self.actions = []
        self.rewards = []
        self.values = []
        self.log_probs = []
        self.dones = []

    def reset_episode(self):
        """Start a fresh history window (holding at most ``history_window`` tokens)."""
        self._history: Deque[np.ndarray] = deque(maxlen=self.history_window)
        self._zero_history_token = np.zeros(self.history_token_dim, dtype=np.float32)

    def _to_tensor(self, data) -> torch.Tensor:
        """Copy array-like ``data`` into a float32 tensor on the agent's device."""
        return torch.tensor(np.asarray(data, dtype=np.float32), device=self.device)

    def _critic_value(self, state_tensor: torch.Tensor, history_tensor: torch.Tensor) -> float:
        """Evaluate the critic for a single state with dropout disabled.

        The critic contains dropout layers; leaving them active during
        rollouts would make the value baselines (and hence GAE) random. The
        critic's previous train/eval mode is restored afterwards so ``update``
        still trains with dropout.
        """
        was_training = self.critic.training
        self.critic.eval()
        try:
            with torch.no_grad():
                value = self.critic(state_tensor, history_tensor)
        finally:
            self.critic.train(was_training)
        return value.item()

    def _build_local_obs(self, state_tensor: torch.Tensor) -> torch.Tensor:
        """Slice global state(s) into per-agent local observations.

        Args:
            state_tensor: ``(state_dim,)`` or ``(batch, state_dim)``.

        Returns:
            ``(batch, num_agents, local_obs_dim)`` whose last axis is
            ``[edge features | own speed limit | time features | last reward | agent id]``.
        """
        if state_tensor.dim() == 1:
            state_tensor = state_tensor.unsqueeze(0)
        batch_size = state_tensor.size(0)
        edge_end = self.num_agents * self.edge_feature_dim
        speed_end = edge_end + self.num_agents
        global_end = speed_end + self.time_feature_dim + self.last_reward_dim
        # reshape (not view) so the column slices are accepted for any memory layout.
        edge_features = state_tensor[:, :edge_end].reshape(
            batch_size, self.num_agents, self.edge_feature_dim
        )
        local_speed_limits = state_tensor[:, edge_end:speed_end].reshape(batch_size, self.num_agents, 1)
        # Time features and last reward are global: broadcast them to every agent.
        global_features = state_tensor[:, speed_end:global_end].unsqueeze(1).expand(-1, self.num_agents, -1)
        agent_ids = self.agent_id_features.expand(batch_size, -1, -1)
        return torch.cat([edge_features, local_speed_limits, global_features, agent_ids], dim=-1)

    def _get_history_stack(self) -> np.ndarray:
        """Return the history window as a ``(history_window, history_token_dim)`` array.

        Early in an episode the window is left-padded with zero tokens, so the
        most recent token is always the last row.
        """
        tokens = list(self._history)
        pad_count = self.history_window - len(tokens)
        if pad_count > 0:
            tokens = [self._zero_history_token.copy() for _ in range(pad_count)] + tokens
        return np.stack(tokens, axis=0).astype(np.float32)

    def _build_history_token(self, state, action, reward, info) -> np.ndarray:
        """Encode one decision/outcome step as a flat float32 token.

        Layout: ``state | action / (num_actions - 1) | [reward / 10, r_flow,
        r_var, r_brake, r_penalty]``; missing ``info`` entries default to 0.
        """
        action_norm = np.asarray(action, dtype=np.float32) / max(self.num_actions - 1, 1)
        reward_features = np.array(
            [
                float(reward) / 10.0,
                float(info.get("r_flow", 0.0)),
                float(info.get("r_var", 0.0)),
                float(info.get("r_brake", 0.0)),
                float(info.get("r_penalty", 0.0)),
            ],
            dtype=np.float32,
        )
        return np.concatenate([np.asarray(state, dtype=np.float32), action_norm, reward_features], axis=0)

    def update_temporal_context(self, state, action, reward, info):
        """Append one step to the history window (the oldest token drops out when full)."""
        self._history.append(self._build_history_token(state, action, reward, info))

    def select_action(
        self,
        state: np.ndarray,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray, float]:
        """Choose one action per agent and estimate the state value.

        Args:
            state: Global state vector.
            deterministic: Take the argmax action instead of sampling.

        Returns:
            ``(actions int64 (num_agents,), log_probs float32 (num_agents,), value)``.
        """
        state_tensor = self._to_tensor(state).unsqueeze(0)
        local_obs = self._build_local_obs(state_tensor)
        history_tensor = self._to_tensor(self._get_history_stack()).unsqueeze(0)
        with torch.no_grad():
            logits = self.actor(local_obs.reshape(self.num_agents, self.local_obs_dim))
            # One batched categorical over agents instead of a per-agent Python loop.
            dist = torch.distributions.Categorical(logits=logits)
            action_t = logits.argmax(dim=-1) if deterministic else dist.sample()
            log_prob_t = dist.log_prob(action_t)
        value = self._critic_value(state_tensor, history_tensor)
        actions = action_t.cpu().numpy().astype(np.int64)
        log_probs = log_prob_t.cpu().numpy().astype(np.float32)
        return actions, log_probs, value

    def get_value(self, state: np.ndarray) -> float:
        """Critic estimate for ``state`` given the current history window (dropout off)."""
        state_tensor = self._to_tensor(state).unsqueeze(0)
        history_tensor = self._to_tensor(self._get_history_stack()).unsqueeze(0)
        return self._critic_value(state_tensor, history_tensor)

    def store_transition(self, state, history_token_stack, action, reward, value, log_prob, done):
        """Append one rollout step to the buffer."""
        self.states.append(state)
        self.history_tokens.append(history_token_stack)
        self.actions.append(action)
        self.rewards.append(reward)
        self.values.append(value)
        self.log_probs.append(log_prob)
        self.dones.append(done)

    def compute_gae(self, next_value: float):
        """Generalized Advantage Estimation over the buffered rollout.

        Args:
            next_value: Value used to bootstrap after the last stored step.

        Returns:
            ``(advantages, returns)`` as float32 arrays of length ``len(rewards)``,
            with ``returns = advantages + values``.
        """
        num_steps = len(self.rewards)
        advantages = np.zeros(num_steps, dtype=np.float32)
        gae = 0.0
        next_val = float(next_value)
        for t in reversed(range(num_steps)):
            # A terminal step neither bootstraps nor propagates later advantages.
            not_done = 1.0 - float(self.dones[t])
            value_t = float(self.values[t])
            delta = float(self.rewards[t]) + self.gamma * next_val * not_done - value_t
            gae = delta + self.gamma * self.gae_lambda * not_done * gae
            advantages[t] = gae
            next_val = value_t
        returns = advantages + np.asarray(self.values, dtype=np.float32)
        return advantages, returns

    def update(self, next_value: float) -> Dict[str, float]:
        """Run PPO epochs over the buffered rollout, then clear the buffer.

        Args:
            next_value: Bootstrap value for the state after the last stored step.

        Returns:
            Mean ``policy_loss``, ``value_loss`` and ``entropy`` over all
            minibatch updates, or ``{}`` when the buffer is empty.
        """
        if not self.states:
            return {}
        advantages, returns = self.compute_gae(next_value)
        # Advantages are normalized; returns keep the raw scale as value targets.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        states = self._to_tensor(np.array(self.states))
        history_tokens = self._to_tensor(np.array(self.history_tokens))
        actions = torch.as_tensor(np.array(self.actions), dtype=torch.long, device=self.device)
        old_log_probs = self._to_tensor(np.array(self.log_probs))
        advantages_t = self._to_tensor(advantages)
        returns_t = self._to_tensor(returns)
        parameters = list(self.actor.parameters()) + list(self.critic.parameters())

        total_policy_loss = 0.0
        total_value_loss = 0.0
        total_entropy = 0.0
        update_count = 0
        dataset_size = len(self.states)
        for _ in range(self.ppo_epochs):
            indices = np.random.permutation(dataset_size)
            for start_idx in range(0, dataset_size, self.minibatch_size):
                batch_idx = indices[start_idx:start_idx + self.minibatch_size]
                batch_size = len(batch_idx)
                batch_states = states[batch_idx]
                batch_history = history_tokens[batch_idx]
                batch_actions = actions[batch_idx]
                batch_old_lp = old_log_probs[batch_idx]
                batch_adv = advantages_t[batch_idx]
                batch_ret = returns_t[batch_idx]

                batch_local_obs = self._build_local_obs(batch_states)
                logits = self.actor(
                    batch_local_obs.reshape(batch_size * self.num_agents, self.local_obs_dim)
                ).view(batch_size, self.num_agents, self.num_actions)
                dist = torch.distributions.Categorical(logits=logits)
                new_log_probs = dist.log_prob(batch_actions)
                entropy = dist.entropy().mean()

                # One team advantage per step, shared by every agent's probability ratio.
                expanded_adv = batch_adv.unsqueeze(1).expand(-1, self.num_agents)
                ratio = torch.exp(new_log_probs - batch_old_lp)
                surr1 = ratio * expanded_adv
                surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * expanded_adv
                policy_loss = -torch.min(surr1, surr2).mean()

                values = self.critic(batch_states, batch_history).squeeze(-1)
                value_loss = nn.functional.mse_loss(values, batch_ret)
                loss = policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(parameters, self.max_grad_norm)
                self.optimizer.step()

                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                total_entropy += entropy.item()
                update_count += 1
        # The scheduler advances once per call to update().
        if self.scheduler is not None:
            self.scheduler.step()
        self.reset_buffers()
        return {
            "policy_loss": total_policy_loss / max(update_count, 1),
            "value_loss": total_value_loss / max(update_count, 1),
            "entropy": total_entropy / max(update_count, 1),
        }

    def save(self, path: str):
        """Save actor, critic, optimizer and (if any) LR-scheduler state to ``path``."""
        checkpoint = {
            "actor_state_dict": self.actor.state_dict(),
            "critic_state_dict": self.critic.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
        }
        if self.scheduler is not None:
            # Without this, a resumed run restarts the cosine schedule from step 0.
            checkpoint["scheduler_state_dict"] = self.scheduler.state_dict()
        torch.save(checkpoint, path)

    def load(self, path: str):
        """Restore a checkpoint written by ``save``.

        Checkpoints without scheduler state (older files) still load; the
        scheduler is then left as is.
        """
        # NOTE(review): weights_only=False unpickles arbitrary objects -- only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.actor.load_state_dict(checkpoint["actor_state_dict"])
        self.critic.load_state_dict(checkpoint["critic_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        if self.scheduler is not None and "scheduler_state_dict" in checkpoint:
            self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])