""" Temporal Credit Assignment MAPPO for SUMO VSL. - Actor: same decentralized shared actor style as MAPPO - Critic: current-state query attends over recent decision/outcome history """ from collections import deque from typing import Deque, Dict, Tuple import numpy as np import torch import torch.nn as nn import torch.optim as optim from envs.reward_system import REWARD_COMPONENT_COLUMNS class SharedActor(nn.Module): def __init__(self, local_obs_dim: int, num_actions: int, hidden_dim: int = 256): super().__init__() self.net = nn.Sequential( nn.Linear(local_obs_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, num_actions), ) self._init_weights() def _init_weights(self): for module in self.modules(): if isinstance(module, nn.Linear): nn.init.orthogonal_(module.weight, gain=np.sqrt(2)) nn.init.constant_(module.bias, 0) nn.init.orthogonal_(self.net[-1].weight, gain=0.01) def forward(self, local_obs: torch.Tensor) -> torch.Tensor: return self.net(local_obs) class HistoryAttentionBlock(nn.Module): def __init__(self, hidden_dim: int, num_heads: int = 4, dropout: float = 0.05): super().__init__() self.cross_attn = nn.MultiheadAttention( embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout, batch_first=True, ) self.norm1 = nn.LayerNorm(hidden_dim) self.norm2 = nn.LayerNorm(hidden_dim) self.ffn = nn.Sequential( nn.Linear(hidden_dim, hidden_dim * 2), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim * 2, hidden_dim), ) self.dropout = nn.Dropout(dropout) def forward(self, query: torch.Tensor, history: torch.Tensor) -> torch.Tensor: attn_out, _ = self.cross_attn(query=query, key=history, value=history, need_weights=False) query = self.norm1(query + self.dropout(attn_out)) ffn_out = self.ffn(query) return self.norm2(query + self.dropout(ffn_out)) class TemporalCreditCritic(nn.Module): def __init__( self, state_dim: int, history_token_dim: int, history_len: int, hidden_dim: int = 256, num_heads: int = 4, num_layers: int = 2, dropout: float = 0.05, ): super().__init__() self.history_len = history_len self.state_encoder = nn.Sequential( nn.Linear(state_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(), ) self.history_encoder = nn.Sequential( nn.Linear(history_token_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(), ) self.age_embedding = nn.Parameter(torch.zeros(1, history_len, hidden_dim)) self.blocks = nn.ModuleList( [HistoryAttentionBlock(hidden_dim, num_heads=num_heads, dropout=dropout) for _ in range(num_layers)] ) self.head = nn.Sequential( nn.Linear(hidden_dim * 3, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim // 2), nn.ReLU(), nn.Linear(hidden_dim // 2, 1), ) self._init_weights() def _init_weights(self): for module in self.modules(): if isinstance(module, nn.Linear): nn.init.orthogonal_(module.weight, gain=np.sqrt(2)) nn.init.constant_(module.bias, 0) nn.init.orthogonal_(self.head[-1].weight, gain=1.0) nn.init.normal_(self.age_embedding, mean=0.0, std=0.02) def forward(self, current_state: torch.Tensor, history_tokens: torch.Tensor) -> torch.Tensor: query = self.state_encoder(current_state).unsqueeze(1) history = self.history_encoder(history_tokens) + self.age_embedding[:, : history_tokens.size(1), :] for block in self.blocks: query = block(query, history) pooled_history = history.mean(dim=1) max_history = history.max(dim=1).values fused = torch.cat([query.squeeze(1), pooled_history, max_history], dim=-1) return self.head(fused) class TCAMAPPOAgent: """MAPPO with a temporal credit-assignment critic.""" def __init__( self, state_dim: int, num_agents: int, num_actions: int, edge_feature_dim: int = 3, time_feature_dim: int = 3, total_edge_count: int | None = None, controlled_start_index: int = 0, hidden_dim: int = 256, critic_hidden_dim: int = 256, history_window: int = 6, critic_num_heads: int = 4, critic_num_layers: int = 2, critic_dropout: float = 0.05, learning_rate: float = 3e-4, gamma: float = 0.99, gae_lambda: float = 0.95, clip_epsilon: float = 0.2, value_coef: float = 0.5, entropy_coef: float = 0.01, max_grad_norm: float = 0.5, ppo_epochs: int = 4, minibatch_size: int = 15, device: str = "cuda", lr_schedule: str = "cosine", total_episodes: int = 300, ): self.device = torch.device(device if torch.cuda.is_available() else "cpu") self.state_dim = state_dim self.num_agents = num_agents self.num_actions = num_actions self.edge_feature_dim = edge_feature_dim self.time_feature_dim = time_feature_dim self.total_edge_count = int(total_edge_count if total_edge_count is not None else num_agents) self.controlled_start_index = int(controlled_start_index) self.controlled_end_index = self.controlled_start_index + self.num_agents if self.controlled_end_index > self.total_edge_count: raise ValueError("controlled action slice exceeds total edge count") self.gamma = gamma self.gae_lambda = gae_lambda self.clip_epsilon = clip_epsilon self.value_coef = value_coef self.entropy_coef = entropy_coef self.max_grad_norm = max_grad_norm self.ppo_epochs = ppo_epochs self.minibatch_size = minibatch_size self.history_window = history_window self.speed_feature_dim = 1 self.last_reward_dim = 1 self.agent_id_dim = 1 self.local_obs_dim = ( edge_feature_dim + self.speed_feature_dim + time_feature_dim + self.last_reward_dim + self.agent_id_dim ) self.reward_feature_dim = 1 + len(REWARD_COMPONENT_COLUMNS) self.history_token_dim = state_dim + num_agents + self.reward_feature_dim self.actor = SharedActor(self.local_obs_dim, num_actions, hidden_dim).to(self.device) self.critic = TemporalCreditCritic( state_dim=state_dim, history_token_dim=self.history_token_dim, history_len=history_window, hidden_dim=critic_hidden_dim, num_heads=critic_num_heads, num_layers=critic_num_layers, dropout=critic_dropout, ).to(self.device) self.optimizer = optim.Adam( list(self.actor.parameters()) + list(self.critic.parameters()), lr=learning_rate, eps=1e-5, ) if lr_schedule == "cosine": self.scheduler = optim.lr_scheduler.CosineAnnealingLR( self.optimizer, T_max=total_episodes, eta_min=learning_rate * 0.1, ) else: self.scheduler = None agent_ids = np.linspace(0.0, 1.0, num_agents, dtype=np.float32) self.agent_id_features = torch.tensor(agent_ids, device=self.device).view(1, num_agents, 1) self.reset_buffers() self.reset_episode() def reset_buffers(self): self.states = [] self.history_tokens = [] self.actions = [] self.rewards = [] self.values = [] self.log_probs = [] self.dones = [] def reset_episode(self): self._history: Deque[np.ndarray] = deque(maxlen=self.history_window) self._zero_history_token = np.zeros(self.history_token_dim, dtype=np.float32) def _build_local_obs(self, state_tensor: torch.Tensor) -> torch.Tensor: if state_tensor.dim() == 1: state_tensor = state_tensor.unsqueeze(0) batch_size = state_tensor.size(0) edge_block = self.total_edge_count * self.edge_feature_dim speed_block_start = edge_block speed_block_end = speed_block_start + self.total_edge_count global_block_start = speed_block_end global_block_end = global_block_start + self.time_feature_dim + self.last_reward_dim edge_features = state_tensor[:, :edge_block].view(batch_size, self.total_edge_count, self.edge_feature_dim) edge_features = edge_features[:, self.controlled_start_index:self.controlled_end_index, :] local_speed_limits = state_tensor[:, speed_block_start:speed_block_end].view(batch_size, self.total_edge_count, 1) local_speed_limits = local_speed_limits[:, self.controlled_start_index:self.controlled_end_index, :] global_features = state_tensor[:, global_block_start:global_block_end].unsqueeze(1) global_features = global_features.expand(-1, self.num_agents, -1) agent_ids = self.agent_id_features.expand(batch_size, -1, -1) return torch.cat([edge_features, local_speed_limits, global_features, agent_ids], dim=-1) def _get_history_stack(self) -> np.ndarray: tokens = list(self._history) if len(tokens) < self.history_window: pad_count = self.history_window - len(tokens) tokens = [self._zero_history_token.copy() for _ in range(pad_count)] + tokens return np.stack(tokens, axis=0).astype(np.float32) def _build_history_token(self, state, action, reward, info) -> np.ndarray: action_norm = np.asarray(action, dtype=np.float32) / max(self.num_actions - 1, 1) reward_features = np.array( [ float(reward) / 10.0, *[float(info.get(column, 0.0)) for column in REWARD_COMPONENT_COLUMNS], ], dtype=np.float32, ) return np.concatenate([np.asarray(state, dtype=np.float32), action_norm, reward_features], axis=0) def update_temporal_context(self, state, action, reward, info): self._history.append(self._build_history_token(state, action, reward, info)) def select_action( self, state: np.ndarray, deterministic: bool = False, ) -> Tuple[np.ndarray, np.ndarray, float]: state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device) local_obs = self._build_local_obs(state_tensor) history_tensor = torch.FloatTensor(self._get_history_stack()).unsqueeze(0).to(self.device) with torch.no_grad(): logits = self.actor(local_obs.view(self.num_agents, self.local_obs_dim)) logits = logits.view(1, self.num_agents, self.num_actions) value = self.critic(state_tensor, history_tensor) actions = [] log_probs = [] for agent_idx in range(self.num_agents): dist = torch.distributions.Categorical(logits=logits[0, agent_idx]) if deterministic: action = torch.argmax(logits[0, agent_idx], dim=-1).item() else: action = dist.sample().item() actions.append(action) log_probs.append(dist.log_prob(torch.tensor(action, device=self.device)).item()) return np.array(actions, dtype=np.int64), np.array(log_probs, dtype=np.float32), value.item() def get_value(self, state: np.ndarray) -> float: state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device) history_tensor = torch.FloatTensor(self._get_history_stack()).unsqueeze(0).to(self.device) with torch.no_grad(): value = self.critic(state_tensor, history_tensor) return value.item() def store_transition(self, state, history_token_stack, action, reward, value, log_prob, done): self.states.append(state) self.history_tokens.append(history_token_stack) self.actions.append(action) self.rewards.append(reward) self.values.append(value) self.log_probs.append(log_prob) self.dones.append(done) def compute_gae(self, next_value: float): advantages = [] gae = 0.0 for t in reversed(range(len(self.rewards))): next_val = next_value if t == len(self.rewards) - 1 else self.values[t + 1] delta = self.rewards[t] + self.gamma * next_val * (1 - self.dones[t]) - self.values[t] gae = delta + self.gamma * self.gae_lambda * (1 - self.dones[t]) * gae advantages.insert(0, gae) advantages = np.array(advantages, dtype=np.float32) returns = advantages + np.array(self.values, dtype=np.float32) return advantages, returns def update(self, next_value: float) -> Dict[str, float]: if len(self.states) == 0: return {} advantages, returns = self.compute_gae(next_value) advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) states = torch.FloatTensor(np.array(self.states)).to(self.device) history_tokens = torch.FloatTensor(np.array(self.history_tokens)).to(self.device) actions = torch.LongTensor(np.array(self.actions)).to(self.device) old_log_probs = torch.FloatTensor(np.array(self.log_probs)).to(self.device) advantages_t = torch.FloatTensor(advantages).to(self.device) returns_t = torch.FloatTensor(returns).to(self.device) total_policy_loss = 0.0 total_value_loss = 0.0 total_entropy = 0.0 update_count = 0 dataset_size = len(self.states) for _ in range(self.ppo_epochs): indices = np.random.permutation(dataset_size) for start_idx in range(0, dataset_size, self.minibatch_size): end_idx = min(start_idx + self.minibatch_size, dataset_size) batch_idx = indices[start_idx:end_idx] batch_states = states[batch_idx] batch_history = history_tokens[batch_idx] batch_actions = actions[batch_idx] batch_old_lp = old_log_probs[batch_idx] batch_adv = advantages_t[batch_idx] batch_ret = returns_t[batch_idx] batch_local_obs = self._build_local_obs(batch_states) logits = self.actor( batch_local_obs.view(len(batch_idx) * self.num_agents, self.local_obs_dim) ).view(len(batch_idx), self.num_agents, self.num_actions) dist = torch.distributions.Categorical(logits=logits) new_log_probs = dist.log_prob(batch_actions) entropy = dist.entropy().mean() expanded_adv = batch_adv.unsqueeze(1).expand(-1, self.num_agents) ratio = torch.exp(new_log_probs - batch_old_lp) surr1 = ratio * expanded_adv surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * expanded_adv policy_loss = -torch.min(surr1, surr2).mean() values = self.critic(batch_states, batch_history).squeeze(-1) value_loss = nn.functional.mse_loss(values, batch_ret) loss = policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy self.optimizer.zero_grad() loss.backward() nn.utils.clip_grad_norm_( list(self.actor.parameters()) + list(self.critic.parameters()), self.max_grad_norm, ) self.optimizer.step() total_policy_loss += policy_loss.item() total_value_loss += value_loss.item() total_entropy += entropy.item() update_count += 1 if self.scheduler is not None: self.scheduler.step() self.reset_buffers() return { "policy_loss": total_policy_loss / max(update_count, 1), "value_loss": total_value_loss / max(update_count, 1), "entropy": total_entropy / max(update_count, 1), } def save(self, path: str): torch.save( { "actor_state_dict": self.actor.state_dict(), "critic_state_dict": self.critic.state_dict(), "optimizer_state_dict": self.optimizer.state_dict(), }, path, ) def load(self, path: str): checkpoint = torch.load(path, map_location=self.device, weights_only=False) self.actor.load_state_dict(checkpoint["actor_state_dict"]) self.critic.load_state_dict(checkpoint["critic_state_dict"]) self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])