diff --git a/agents/gpro_agent.py b/agents/gpro_agent.py index 6d2f0d9..94ee518 100644 --- a/agents/gpro_agent.py +++ b/agents/gpro_agent.py @@ -1,4 +1,4 @@ -"""GRPO-inspired PPO agent for grouped corridor control rollouts.""" +"""PPO-style GPRO agent with group-relative advantages for corridor control.""" from __future__ import annotations from typing import Dict, List, Tuple @@ -10,8 +10,8 @@ import torch.nn.functional as F import torch.optim as optim -class MultiDiscreteActor(nn.Module): - """Shared trunk plus one categorical head per controlled edge.""" +class MultiDiscreteActorCritic(nn.Module): + """Shared actor-critic backbone with one categorical head per control segment.""" def __init__( self, @@ -20,9 +20,7 @@ class MultiDiscreteActor(nn.Module): hidden_layers: List[int] = [256, 256], ): super().__init__() - self.state_dim = state_dim self.action_dims = action_dims - self.num_heads = len(action_dims) layers = [] prev_dim = state_dim @@ -32,9 +30,15 @@ class MultiDiscreteActor(nn.Module): layers.append(nn.ReLU()) prev_dim = hidden_dim self.feature_extractor = nn.Sequential(*layers) + self.actor_heads = nn.ModuleList( [nn.Linear(prev_dim, action_dim) for action_dim in action_dims] ) + self.critic = nn.Sequential( + nn.Linear(prev_dim, 128), + nn.ReLU(), + nn.Linear(128, 1), + ) self._init_weights() def _init_weights(self): @@ -44,17 +48,26 @@ class MultiDiscreteActor(nn.Module): nn.init.constant_(module.bias, 0) for head in self.actor_heads: nn.init.orthogonal_(head.weight, gain=0.01) + nn.init.orthogonal_(self.critic[-1].weight, gain=1.0) - def forward(self, state: torch.Tensor) -> List[torch.Tensor]: + def forward(self, state: torch.Tensor) -> Tuple[List[torch.Tensor], torch.Tensor]: features = self.feature_extractor(state) - return [head(features) for head in self.actor_heads] + logits_list = [head(features) for head in self.actor_heads] + value = self.critic(features) + return logits_list, value - def get_action_probs(self, state: torch.Tensor) -> List[torch.Tensor]: - return [F.softmax(logits, dim=-1) for logits in self.forward(state)] + def get_action_probs(self, state: torch.Tensor) -> Tuple[List[torch.Tensor], torch.Tensor]: + logits_list, value = self.forward(state) + probs_list = [F.softmax(logits, dim=-1) for logits in logits_list] + return probs_list, value + + def get_value(self, state: torch.Tensor) -> torch.Tensor: + features = self.feature_extractor(state) + return self.critic(features) class GPROAgent: - """Grouped relative PPO without value critic.""" + """PPO actor-critic with group-relative trajectory ranking.""" def __init__( self, @@ -62,29 +75,39 @@ class GPROAgent: action_dims: List[int], hidden_layers: List[int] = [256, 256], learning_rate: float = 3e-4, + gamma: float = 0.99, + gae_lambda: float = 0.95, clip_epsilon: float = 0.2, + value_coef: float = 0.5, entropy_coef: float = 0.01, max_grad_norm: float = 0.5, ppo_epochs: int = 4, minibatch_size: int = 64, group_size: int = 4, + group_advantage_coef: float = 0.35, advantage_epsilon: float = 1e-8, device: str = "cuda", lr_schedule: str = "cosine", total_episodes: int = 300, ): self.device = torch.device(device if torch.cuda.is_available() else "cpu") + self.gamma = gamma + self.gae_lambda = gae_lambda self.clip_epsilon = clip_epsilon + self.value_coef = value_coef self.entropy_coef = entropy_coef self.max_grad_norm = max_grad_norm self.ppo_epochs = ppo_epochs self.minibatch_size = minibatch_size self.group_size = max(int(group_size), 2) + self.group_advantage_coef = float(np.clip(group_advantage_coef, 0.0, 1.0)) self.advantage_epsilon = advantage_epsilon self.action_dims = action_dims self.num_heads = len(action_dims) - self.policy = MultiDiscreteActor(state_dim, action_dims, hidden_layers).to(self.device) + self.policy = MultiDiscreteActorCritic(state_dim, action_dims, hidden_layers).to( + self.device + ) self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate, eps=1e-5) self.lr_schedule = lr_schedule @@ -101,10 +124,13 @@ class GPROAgent: def _reset_episode_buffer(self): self.current_states: List[np.ndarray] = [] self.current_actions: List[np.ndarray] = [] + self.current_rewards: List[float] = [] + self.current_values: List[float] = [] self.current_log_probs: List[float] = [] + self.current_dones: List[float] = [] def reset_group_buffers(self): - self.trajectories: List[Dict] = [] + self.trajectories: List[Dict[str, np.ndarray | float]] = [] self._reset_episode_buffer() def select_action( @@ -112,7 +138,7 @@ class GPROAgent: ) -> Tuple[np.ndarray, float, float]: state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device) with torch.no_grad(): - probs_list = self.policy.get_action_probs(state_tensor) + probs_list, value = self.policy.get_action_probs(state_tensor) actions = [] total_log_prob = 0.0 @@ -126,13 +152,15 @@ class GPROAgent: actions.append(action) total_log_prob += log_prob - return np.array(actions, dtype=np.int64), total_log_prob, 0.0 + return np.array(actions, dtype=np.int64), total_log_prob, value.item() def store_transition(self, state, action, reward, value, log_prob, done): - del reward, value, done self.current_states.append(np.asarray(state, dtype=np.float32)) self.current_actions.append(np.asarray(action, dtype=np.int64)) + self.current_rewards.append(float(reward)) + self.current_values.append(float(value)) self.current_log_probs.append(float(log_prob)) + self.current_dones.append(float(done)) def finish_episode(self, score: float): if not self.current_states: @@ -141,47 +169,102 @@ class GPROAgent: { "states": np.asarray(self.current_states, dtype=np.float32), "actions": np.asarray(self.current_actions, dtype=np.int64), + "rewards": np.asarray(self.current_rewards, dtype=np.float32), + "values": np.asarray(self.current_values, dtype=np.float32), "log_probs": np.asarray(self.current_log_probs, dtype=np.float32), + "dones": np.asarray(self.current_dones, dtype=np.float32), "score": float(score), } ) self._reset_episode_buffer() - def _build_group_advantages(self) -> List[np.ndarray]: - scores = np.asarray([traj["score"] for traj in self.trajectories], dtype=np.float32) - if scores.size == 0: - return [] - score_mean = float(scores.mean()) - score_std = float(scores.std()) - if score_std < self.advantage_epsilon: - normalized_scores = np.zeros_like(scores, dtype=np.float32) - else: - normalized_scores = (scores - score_mean) / (score_std + self.advantage_epsilon) - return [ - np.full(len(traj["states"]), normalized_scores[idx], dtype=np.float32) - for idx, traj in enumerate(self.trajectories) - ] + def _compute_gae( + self, + rewards: np.ndarray, + values: np.ndarray, + dones: np.ndarray, + next_value: float = 0.0, + ) -> Tuple[np.ndarray, np.ndarray]: + advantages = np.zeros_like(rewards, dtype=np.float32) + gae = 0.0 + + for t in reversed(range(len(rewards))): + if t == len(rewards) - 1: + next_val = next_value + else: + next_val = values[t + 1] + delta = rewards[t] + self.gamma * next_val * (1.0 - dones[t]) - values[t] + gae = delta + self.gamma * self.gae_lambda * (1.0 - dones[t]) * gae + advantages[t] = gae + + returns = advantages + values + return advantages, returns + + def _normalize(self, values: np.ndarray) -> np.ndarray: + if values.size == 0: + return values + mean = float(values.mean()) + std = float(values.std()) + if std < self.advantage_epsilon: + return np.zeros_like(values, dtype=np.float32) + return ((values - mean) / (std + self.advantage_epsilon)).astype(np.float32) + + def _build_training_targets(self) -> Tuple[np.ndarray, np.ndarray]: + if not self.trajectories: + return np.array([], dtype=np.float32), np.array([], dtype=np.float32) + + gae_advantages = [] + returns = [] + for trajectory in self.trajectories: + trajectory_adv, trajectory_ret = self._compute_gae( + rewards=trajectory["rewards"], + values=trajectory["values"], + dones=trajectory["dones"], + next_value=0.0, + ) + gae_advantages.append(trajectory_adv) + returns.append(trajectory_ret) + + normalized_gae = self._normalize(np.concatenate(gae_advantages, axis=0)) + group_scores = np.asarray( + [trajectory["score"] for trajectory in self.trajectories], dtype=np.float32 + ) + normalized_group_scores = self._normalize(group_scores) + repeated_group_advantages = np.concatenate( + [ + np.full(len(trajectory["states"]), normalized_group_scores[idx], dtype=np.float32) + for idx, trajectory in enumerate(self.trajectories) + ], + axis=0, + ) + + combined_advantages = ( + (1.0 - self.group_advantage_coef) * normalized_gae + + self.group_advantage_coef * repeated_group_advantages + ) + combined_advantages = self._normalize(combined_advantages) + return combined_advantages, np.concatenate(returns, axis=0).astype(np.float32) def update(self) -> Dict[str, float]: if not self.trajectories: return {} - trajectory_advantages = self._build_group_advantages() + advantages, returns = self._build_training_targets() states = torch.FloatTensor( - np.concatenate([traj["states"] for traj in self.trajectories], axis=0) + np.concatenate([trajectory["states"] for trajectory in self.trajectories], axis=0) ).to(self.device) actions = torch.LongTensor( - np.concatenate([traj["actions"] for traj in self.trajectories], axis=0) + np.concatenate([trajectory["actions"] for trajectory in self.trajectories], axis=0) ).to(self.device) old_log_probs = torch.FloatTensor( - np.concatenate([traj["log_probs"] for traj in self.trajectories], axis=0) - ).to(self.device) - advantages = torch.FloatTensor( - np.concatenate(trajectory_advantages, axis=0) + np.concatenate([trajectory["log_probs"] for trajectory in self.trajectories], axis=0) ).to(self.device) + advantages_t = torch.FloatTensor(advantages).to(self.device) + returns_t = torch.FloatTensor(returns).to(self.device) total_loss = 0.0 total_policy_loss = 0.0 + total_value_loss = 0.0 total_entropy_value = 0.0 update_count = 0 dataset_size = states.shape[0] @@ -195,9 +278,10 @@ class GPROAgent: batch_states = states[batch_idx] batch_actions = actions[batch_idx] batch_old_log_probs = old_log_probs[batch_idx] - batch_advantages = advantages[batch_idx] + batch_advantages = advantages_t[batch_idx] + batch_returns = returns_t[batch_idx] - logits_list = self.policy(batch_states) + logits_list, values = self.policy(batch_states) total_new_log_probs = torch.zeros(len(batch_idx), device=self.device) entropy_terms = torch.zeros(len(batch_idx), device=self.device) @@ -215,7 +299,12 @@ class GPROAgent: * batch_advantages ) policy_loss = -torch.min(surr1, surr2).mean() - loss = policy_loss - self.entropy_coef * entropy + value_loss = F.mse_loss(values.squeeze(-1), batch_returns) + loss = ( + policy_loss + + self.value_coef * value_loss + - self.entropy_coef * entropy + ) self.optimizer.zero_grad() loss.backward() @@ -224,25 +313,26 @@ class GPROAgent: total_loss += float(loss.item()) total_policy_loss += float(policy_loss.item()) + total_value_loss += float(value_loss.item()) total_entropy_value += float(entropy.item()) update_count += 1 if self.scheduler is not None: self.scheduler.step() + scores = np.asarray( + [trajectory["score"] for trajectory in self.trajectories], dtype=np.float32 + ) stats = { "loss": total_loss / max(update_count, 1), "policy_loss": total_policy_loss / max(update_count, 1), - "value_loss": 0.0, + "value_loss": total_value_loss / max(update_count, 1), "entropy": total_entropy_value / max(update_count, 1), "lr": self.optimizer.param_groups[0]["lr"], - "group_score_mean": float( - np.mean([traj["score"] for traj in self.trajectories], dtype=np.float32) - ), - "group_score_std": float( - np.std([traj["score"] for traj in self.trajectories], dtype=np.float32) - ), + "group_score_mean": float(scores.mean()) if scores.size else 0.0, + "group_score_std": float(scores.std()) if scores.size else 0.0, "group_size": float(len(self.trajectories)), + "group_advantage_coef": float(self.group_advantage_coef), } self.reset_group_buffers() return stats diff --git a/config_sumo_vsl.yaml b/config_sumo_vsl.yaml index 926d0ae..74e45b5 100644 --- a/config_sumo_vsl.yaml +++ b/config_sumo_vsl.yaml @@ -107,13 +107,17 @@ agents: gpro: hidden_layers: [256, 256] learning_rate: 0.0003 + gamma: 0.99 + gae_lambda: 0.95 clip_epsilon: 0.2 + value_coef: 0.5 entropy_coef: 0.01 max_grad_norm: 0.5 ppo_epochs: 4 batch_size: 15 lr_schedule: "cosine" group_size: 4 + group_advantage_coef: 0.35 advantage_epsilon: 1.0e-8 appo: diff --git a/training/train_gpro.py b/training/train_gpro.py index 298124b..6491260 100644 --- a/training/train_gpro.py +++ b/training/train_gpro.py @@ -60,6 +60,7 @@ def train_sumo_gpro(log_dir=None, checkpoint_dir=None, run_timestamp=None): print(f" Hidden layers: {agent_config.get('hidden_layers', [256, 256])}") print(f" Learning rate: {agent_config.get('learning_rate', 3e-4)}") print(f" Group size: {group_size}") + print(f" Group advantage coef: {agent_config.get('group_advantage_coef', 0.35)}") print(f" Device: {agent_config.get('device', 'cuda')}") print() @@ -68,12 +69,16 @@ def train_sumo_gpro(log_dir=None, checkpoint_dir=None, run_timestamp=None): action_dims=action_dims, hidden_layers=agent_config.get("hidden_layers", [256, 256]), learning_rate=agent_config.get("learning_rate", 3e-4), + gamma=agent_config.get("gamma", 0.99), + gae_lambda=agent_config.get("gae_lambda", 0.95), clip_epsilon=agent_config.get("clip_epsilon", 0.2), + value_coef=agent_config.get("value_coef", 0.5), entropy_coef=agent_config.get("entropy_coef", 0.01), max_grad_norm=agent_config.get("max_grad_norm", 0.5), ppo_epochs=agent_config.get("ppo_epochs", 4), minibatch_size=agent_config.get("batch_size", 64), group_size=group_size, + group_advantage_coef=agent_config.get("group_advantage_coef", 0.35), advantage_epsilon=agent_config.get("advantage_epsilon", 1e-8), device=agent_config.get("device", "cuda"), lr_schedule=agent_config.get("lr_schedule", "cosine"), @@ -125,10 +130,10 @@ def train_sumo_gpro(log_dir=None, checkpoint_dir=None, run_timestamp=None): ) while not done: - action, log_prob, _ = agent.select_action(state, deterministic=False) + action, log_prob, value = agent.select_action(state, deterministic=False) next_state, reward, done, info = env.step(action) - agent.store_transition(state, action, reward, 0.0, log_prob, done) + agent.store_transition(state, action, reward, value, log_prob, done) episode_reward += reward episode_throughput += info["throughput"]