"""PPO-style GPRO agent with group-relative advantages for corridor control.""" from __future__ import annotations from typing import Dict, List, Tuple import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim class MultiDiscreteActorCritic(nn.Module): """Shared actor-critic backbone with one categorical head per control segment.""" def __init__( self, state_dim: int, action_dims: List[int], hidden_layers: List[int] = [256, 256], ): super().__init__() self.action_dims = action_dims layers = [] prev_dim = state_dim for hidden_dim in hidden_layers: layers.append(nn.Linear(prev_dim, hidden_dim)) layers.append(nn.LayerNorm(hidden_dim)) layers.append(nn.ReLU()) prev_dim = hidden_dim self.feature_extractor = nn.Sequential(*layers) self.actor_heads = nn.ModuleList( [nn.Linear(prev_dim, action_dim) for action_dim in action_dims] ) self.critic = nn.Sequential( nn.Linear(prev_dim, 128), nn.ReLU(), nn.Linear(128, 1), ) self._init_weights() def _init_weights(self): for module in self.modules(): if isinstance(module, nn.Linear): nn.init.orthogonal_(module.weight, gain=np.sqrt(2)) nn.init.constant_(module.bias, 0) for head in self.actor_heads: nn.init.orthogonal_(head.weight, gain=0.01) nn.init.orthogonal_(self.critic[-1].weight, gain=1.0) def forward(self, state: torch.Tensor) -> Tuple[List[torch.Tensor], torch.Tensor]: features = self.feature_extractor(state) logits_list = [head(features) for head in self.actor_heads] value = self.critic(features) return logits_list, value def get_action_probs(self, state: torch.Tensor) -> Tuple[List[torch.Tensor], torch.Tensor]: logits_list, value = self.forward(state) probs_list = [F.softmax(logits, dim=-1) for logits in logits_list] return probs_list, value def get_value(self, state: torch.Tensor) -> torch.Tensor: features = self.feature_extractor(state) return self.critic(features) class GPROAgent: """PPO actor-critic with group-relative trajectory ranking.""" def __init__( self, state_dim: int, action_dims: List[int], hidden_layers: List[int] = [256, 256], learning_rate: float = 3e-4, gamma: float = 0.99, gae_lambda: float = 0.95, clip_epsilon: float = 0.2, value_coef: float = 0.5, entropy_coef: float = 0.01, max_grad_norm: float = 0.5, ppo_epochs: int = 4, minibatch_size: int = 64, group_size: int = 4, group_advantage_coef: float = 0.35, advantage_epsilon: float = 1e-8, device: str = "cuda", lr_schedule: str = "cosine", total_episodes: int = 300, ): self.device = torch.device(device if torch.cuda.is_available() else "cpu") self.gamma = gamma self.gae_lambda = gae_lambda self.clip_epsilon = clip_epsilon self.value_coef = value_coef self.entropy_coef = entropy_coef self.max_grad_norm = max_grad_norm self.ppo_epochs = ppo_epochs self.minibatch_size = minibatch_size self.group_size = max(int(group_size), 2) self.group_advantage_coef = float(np.clip(group_advantage_coef, 0.0, 1.0)) self.advantage_epsilon = advantage_epsilon self.action_dims = action_dims self.num_heads = len(action_dims) self.policy = MultiDiscreteActorCritic(state_dim, action_dims, hidden_layers).to( self.device ) self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate, eps=1e-5) self.lr_schedule = lr_schedule if lr_schedule == "cosine": self.scheduler = optim.lr_scheduler.CosineAnnealingLR( self.optimizer, T_max=total_episodes, eta_min=learning_rate * 0.1 ) else: self.scheduler = None self.reset_group_buffers() self._reset_episode_buffer() def _reset_episode_buffer(self): self.current_states: List[np.ndarray] = [] self.current_actions: List[np.ndarray] = [] self.current_rewards: List[float] = [] self.current_values: List[float] = [] self.current_log_probs: List[float] = [] self.current_dones: List[float] = [] def reset_group_buffers(self): self.trajectories: List[Dict[str, np.ndarray | float]] = [] self._reset_episode_buffer() def select_action( self, state: np.ndarray, deterministic: bool = False ) -> Tuple[np.ndarray, float, float]: state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device) with torch.no_grad(): probs_list, value = self.policy.get_action_probs(state_tensor) actions = [] total_log_prob = 0.0 for probs in probs_list: if deterministic: action = torch.argmax(probs, dim=-1).item() else: dist = torch.distributions.Categorical(probs) action = dist.sample().item() log_prob = torch.log(probs[0, action] + 1e-10).item() actions.append(action) total_log_prob += log_prob return np.array(actions, dtype=np.int64), total_log_prob, value.item() def store_transition(self, state, action, reward, value, log_prob, done): self.current_states.append(np.asarray(state, dtype=np.float32)) self.current_actions.append(np.asarray(action, dtype=np.int64)) self.current_rewards.append(float(reward)) self.current_values.append(float(value)) self.current_log_probs.append(float(log_prob)) self.current_dones.append(float(done)) def finish_episode(self, score: float): if not self.current_states: return self.trajectories.append( { "states": np.asarray(self.current_states, dtype=np.float32), "actions": np.asarray(self.current_actions, dtype=np.int64), "rewards": np.asarray(self.current_rewards, dtype=np.float32), "values": np.asarray(self.current_values, dtype=np.float32), "log_probs": np.asarray(self.current_log_probs, dtype=np.float32), "dones": np.asarray(self.current_dones, dtype=np.float32), "score": float(score), } ) self._reset_episode_buffer() def _compute_gae( self, rewards: np.ndarray, values: np.ndarray, dones: np.ndarray, next_value: float = 0.0, ) -> Tuple[np.ndarray, np.ndarray]: advantages = np.zeros_like(rewards, dtype=np.float32) gae = 0.0 for t in reversed(range(len(rewards))): if t == len(rewards) - 1: next_val = next_value else: next_val = values[t + 1] delta = rewards[t] + self.gamma * next_val * (1.0 - dones[t]) - values[t] gae = delta + self.gamma * self.gae_lambda * (1.0 - dones[t]) * gae advantages[t] = gae returns = advantages + values return advantages, returns def _normalize(self, values: np.ndarray) -> np.ndarray: if values.size == 0: return values mean = float(values.mean()) std = float(values.std()) if std < self.advantage_epsilon: return np.zeros_like(values, dtype=np.float32) return ((values - mean) / (std + self.advantage_epsilon)).astype(np.float32) def _build_training_targets(self) -> Tuple[np.ndarray, np.ndarray]: if not self.trajectories: return np.array([], dtype=np.float32), np.array([], dtype=np.float32) gae_advantages = [] returns = [] for trajectory in self.trajectories: trajectory_adv, trajectory_ret = self._compute_gae( rewards=trajectory["rewards"], values=trajectory["values"], dones=trajectory["dones"], next_value=0.0, ) gae_advantages.append(trajectory_adv) returns.append(trajectory_ret) normalized_gae = self._normalize(np.concatenate(gae_advantages, axis=0)) group_scores = np.asarray( [trajectory["score"] for trajectory in self.trajectories], dtype=np.float32 ) normalized_group_scores = self._normalize(group_scores) repeated_group_advantages = np.concatenate( [ np.full(len(trajectory["states"]), normalized_group_scores[idx], dtype=np.float32) for idx, trajectory in enumerate(self.trajectories) ], axis=0, ) combined_advantages = ( (1.0 - self.group_advantage_coef) * normalized_gae + self.group_advantage_coef * repeated_group_advantages ) combined_advantages = self._normalize(combined_advantages) return combined_advantages, np.concatenate(returns, axis=0).astype(np.float32) def update(self) -> Dict[str, float]: if not self.trajectories: return {} advantages, returns = self._build_training_targets() states = torch.FloatTensor( np.concatenate([trajectory["states"] for trajectory in self.trajectories], axis=0) ).to(self.device) actions = torch.LongTensor( np.concatenate([trajectory["actions"] for trajectory in self.trajectories], axis=0) ).to(self.device) old_log_probs = torch.FloatTensor( np.concatenate([trajectory["log_probs"] for trajectory in self.trajectories], axis=0) ).to(self.device) advantages_t = torch.FloatTensor(advantages).to(self.device) returns_t = torch.FloatTensor(returns).to(self.device) total_loss = 0.0 total_policy_loss = 0.0 total_value_loss = 0.0 total_entropy_value = 0.0 update_count = 0 dataset_size = states.shape[0] for _ in range(self.ppo_epochs): indices = np.random.permutation(dataset_size) for start_idx in range(0, dataset_size, self.minibatch_size): end_idx = min(start_idx + self.minibatch_size, dataset_size) batch_idx = indices[start_idx:end_idx] batch_states = states[batch_idx] batch_actions = actions[batch_idx] batch_old_log_probs = old_log_probs[batch_idx] batch_advantages = advantages_t[batch_idx] batch_returns = returns_t[batch_idx] logits_list, values = self.policy(batch_states) total_new_log_probs = torch.zeros(len(batch_idx), device=self.device) entropy_terms = torch.zeros(len(batch_idx), device=self.device) for head_idx in range(self.num_heads): probs = F.softmax(logits_list[head_idx], dim=-1) dist = torch.distributions.Categorical(probs) total_new_log_probs += dist.log_prob(batch_actions[:, head_idx]) entropy_terms += dist.entropy() entropy = entropy_terms.mean() ratio = torch.exp(total_new_log_probs - batch_old_log_probs) surr1 = ratio * batch_advantages surr2 = ( torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * batch_advantages ) policy_loss = -torch.min(surr1, surr2).mean() value_loss = F.mse_loss(values.squeeze(-1), batch_returns) loss = ( policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy ) self.optimizer.zero_grad() loss.backward() nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) self.optimizer.step() total_loss += float(loss.item()) total_policy_loss += float(policy_loss.item()) total_value_loss += float(value_loss.item()) total_entropy_value += float(entropy.item()) update_count += 1 if self.scheduler is not None: self.scheduler.step() scores = np.asarray( [trajectory["score"] for trajectory in self.trajectories], dtype=np.float32 ) stats = { "loss": total_loss / max(update_count, 1), "policy_loss": total_policy_loss / max(update_count, 1), "value_loss": total_value_loss / max(update_count, 1), "entropy": total_entropy_value / max(update_count, 1), "lr": self.optimizer.param_groups[0]["lr"], "group_score_mean": float(scores.mean()) if scores.size else 0.0, "group_score_std": float(scores.std()) if scores.size else 0.0, "group_size": float(len(self.trajectories)), "group_advantage_coef": float(self.group_advantage_coef), } self.reset_group_buffers() return stats def save(self, path: str): torch.save( { "policy_state_dict": self.policy.state_dict(), "optimizer_state_dict": self.optimizer.state_dict(), }, path, ) def load(self, path: str): checkpoint = torch.load(path, map_location=self.device, weights_only=False) self.policy.load_state_dict(checkpoint["policy_state_dict"]) self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])