""" PPO Agent Implementation (优化版) 支持 MultiDiscrete 动作空间,适配 VSL 多zone独立控制 优化点: 1. MultiDiscrete: 每zone独立actor head, 避免指数爆炸的动作空间 2. LayerNorm + 更深网络: 改善训练稳定性 3. 学习率余弦退火调度 4. 正交初始化 """ import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import numpy as np from typing import Tuple, List, Dict class MultiDiscreteActorCritic(nn.Module): """ MultiDiscrete Actor-Critic 网络 Actor: 共享特征提取 -> 每zone独立head输出该zone的速度档位概率 Critic: 共享特征提取 -> 单头输出状态价值 """ def __init__( self, state_dim: int, action_dims: List[int], # e.g. [5, 5, 5, 5, 5] 每zone的动作数 hidden_layers: List[int] = [256, 256], ): super().__init__() self.state_dim = state_dim self.action_dims = action_dims self.num_heads = len(action_dims) # Shared feature extractor with LayerNorm layers = [] prev_dim = state_dim for hidden_dim in hidden_layers: layers.append(nn.Linear(prev_dim, hidden_dim)) layers.append(nn.LayerNorm(hidden_dim)) layers.append(nn.ReLU()) prev_dim = hidden_dim self.feature_extractor = nn.Sequential(*layers) # 每zone独立的 actor head self.actor_heads = nn.ModuleList([ nn.Linear(prev_dim, adim) for adim in action_dims ]) # Critic head self.critic = nn.Sequential( nn.Linear(prev_dim, 128), nn.ReLU(), nn.Linear(128, 1), ) # 正交初始化 self._init_weights() def _init_weights(self): for m in self.modules(): if isinstance(m, nn.Linear): nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) nn.init.constant_(m.bias, 0) # Actor heads 用更小的gain for head in self.actor_heads: nn.init.orthogonal_(head.weight, gain=0.01) # Critic 最后一层也用小gain nn.init.orthogonal_(self.critic[-1].weight, gain=1.0) def forward(self, state): features = self.feature_extractor(state) logits_list = [head(features) for head in self.actor_heads] value = self.critic(features) return logits_list, value def get_action_probs(self, state): """获取每个zone的动作概率分布""" logits_list, value = self.forward(state) probs_list = [F.softmax(logits, dim=-1) for logits in logits_list] return probs_list, value def get_value(self, state): features = self.feature_extractor(state) return self.critic(features) class PPOAgent: """PPO智能体 (MultiDiscrete 版)""" def __init__( self, state_dim: int, action_dims: List[int], # MultiDiscrete: 每zone的动作数 hidden_layers: List[int] = [256, 256], learning_rate: float = 3e-4, gamma: float = 0.99, gae_lambda: float = 0.95, clip_epsilon: float = 0.2, value_coef: float = 0.5, entropy_coef: float = 0.02, max_grad_norm: float = 0.5, ppo_epochs: int = 10, minibatch_size: int = 64, device: str = "cuda", lr_schedule: str = "cosine", # "cosine" or "none" total_episodes: int = 300, ): self.device = torch.device(device if torch.cuda.is_available() else "cpu") self.gamma = gamma self.gae_lambda = gae_lambda self.clip_epsilon = clip_epsilon self.value_coef = value_coef self.entropy_coef = entropy_coef self.max_grad_norm = max_grad_norm self.ppo_epochs = ppo_epochs self.minibatch_size = minibatch_size self.action_dims = action_dims self.num_heads = len(action_dims) # Create network self.policy = MultiDiscreteActorCritic( state_dim, action_dims, hidden_layers ).to(self.device) self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate, eps=1e-5) # 学习率调度 self.lr_schedule = lr_schedule if lr_schedule == "cosine": self.scheduler = optim.lr_scheduler.CosineAnnealingLR( self.optimizer, T_max=total_episodes, eta_min=learning_rate * 0.1 ) else: self.scheduler = None # Buffers self.reset_buffers() def reset_buffers(self): """重置经验缓冲""" self.states = [] self.actions = [] # List of np.ndarray, each shape=(num_heads,) self.rewards = [] self.values = [] self.log_probs = [] # List of float (sum of all heads' log_probs) self.dones = [] def select_action( self, state: np.ndarray, deterministic: bool = False ) -> Tuple[np.ndarray, float, float]: """ 选择动作 Returns: action: np.ndarray shape=(num_heads,), 每个zone的动作index log_prob: float, 所有zone log_prob之和 value: float """ state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device) with torch.no_grad(): probs_list, value = self.policy.get_action_probs(state_tensor) actions = [] total_log_prob = 0.0 for i, probs in enumerate(probs_list): if deterministic: a = torch.argmax(probs, dim=-1).item() else: dist = torch.distributions.Categorical(probs) a = dist.sample().item() log_p = torch.log(probs[0, a] + 1e-10).item() actions.append(a) total_log_prob += log_p return np.array(actions, dtype=np.int64), total_log_prob, value.item() def store_transition(self, state, action, reward, value, log_prob, done): """存储转换""" self.states.append(state) self.actions.append(action) # np.ndarray shape=(num_heads,) self.rewards.append(reward) self.values.append(value) self.log_probs.append(log_prob) self.dones.append(done) def compute_gae(self, next_value: float) -> Tuple[np.ndarray, np.ndarray]: """计算GAE优势函数""" advantages = [] gae = 0 for t in reversed(range(len(self.rewards))): if t == len(self.rewards) - 1: next_val = next_value else: next_val = self.values[t + 1] delta = self.rewards[t] + self.gamma * next_val * (1 - self.dones[t]) - self.values[t] gae = delta + self.gamma * self.gae_lambda * (1 - self.dones[t]) * gae advantages.insert(0, gae) advantages = np.array(advantages, dtype=np.float32) returns = advantages + np.array(self.values, dtype=np.float32) return advantages, returns def update(self, next_value: float) -> Dict[str, float]: """PPO更新""" if len(self.states) == 0: return {} advantages, returns = self.compute_gae(next_value) advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # Convert to tensors states = torch.FloatTensor(np.array(self.states)).to(self.device) # actions: shape (N, num_heads) actions = torch.LongTensor(np.array(self.actions)).to(self.device) old_log_probs = torch.FloatTensor(self.log_probs).to(self.device) advantages_t = torch.FloatTensor(advantages).to(self.device) returns_t = torch.FloatTensor(returns).to(self.device) total_loss = 0 total_policy_loss = 0 total_value_loss = 0 total_entropy = 0 update_count = 0 dataset_size = len(self.states) for _ in range(self.ppo_epochs): indices = np.random.permutation(dataset_size) for start_idx in range(0, dataset_size, self.minibatch_size): end_idx = min(start_idx + self.minibatch_size, dataset_size) batch_idx = indices[start_idx:end_idx] batch_states = states[batch_idx] batch_actions = actions[batch_idx] # (B, num_heads) batch_old_lp = old_log_probs[batch_idx] batch_adv = advantages_t[batch_idx] batch_ret = returns_t[batch_idx] # Forward logits_list, values = self.policy(batch_states) # 计算每个head的log_prob和entropy, 然后求和 total_new_lp = torch.zeros(len(batch_idx), device=self.device) total_ent = torch.zeros(len(batch_idx), device=self.device) for h in range(self.num_heads): probs_h = F.softmax(logits_list[h], dim=-1) dist_h = torch.distributions.Categorical(probs_h) total_new_lp += dist_h.log_prob(batch_actions[:, h]) total_ent += dist_h.entropy() entropy = total_ent.mean() # PPO clipped loss ratio = torch.exp(total_new_lp - batch_old_lp) surr1 = ratio * batch_adv surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * batch_adv policy_loss = -torch.min(surr1, surr2).mean() # Value loss value_loss = F.mse_loss(values.squeeze(), batch_ret) # Total loss = policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy self.optimizer.zero_grad() loss.backward() nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) self.optimizer.step() total_loss += loss.item() total_policy_loss += policy_loss.item() total_value_loss += value_loss.item() total_entropy += entropy.item() update_count += 1 # 学习率调度 step if self.scheduler is not None: self.scheduler.step() self.reset_buffers() return { 'loss': total_loss / max(update_count, 1), 'policy_loss': total_policy_loss / max(update_count, 1), 'value_loss': total_value_loss / max(update_count, 1), 'entropy': total_entropy / max(update_count, 1), 'lr': self.optimizer.param_groups[0]['lr'], } def save(self, path: str): """保存模型""" torch.save({ 'policy_state_dict': self.policy.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict(), }, path) def load(self, path: str): """加载模型""" checkpoint = torch.load(path, map_location=self.device, weights_only=False) self.policy.load_state_dict(checkpoint['policy_state_dict']) self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])