"""Attention-based PPO agent (APPO).

Uses multi-head self-attention with the goal of capturing spatial
dependencies between zones.

NOTE(review): ``AttentionActorCritic`` feeds each projected state to the
attention stack as a single token (``unsqueeze(1)``), so every attention
softmax is taken over exactly one position. As written, the attention layers
therefore do not mix information between zones; splitting the state into
per-zone tokens would be required for genuine cross-zone attention.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from typing import Tuple, List, Dict


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention followed by a residual add and LayerNorm (post-norm)."""

    def __init__(self, d_model: int, num_heads: int = 4, dropout: float = 0.1):
        """
        Args:
            d_model: Feature size of each token; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            dropout: Dropout probability applied to the attention weights and to
                the projected output before the residual add.

        Raises:
            ValueError: If ``d_model`` is not divisible by ``num_heads``.
        """
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply self-attention over the token axis.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.

        Returns:
            Tensor of the same shape: ``LayerNorm(x + Dropout(W_o(attention(x))))``.
        """
        batch_size = x.size(0)

        def split_heads(t):
            # (batch, seq, d_model) -> (batch, heads, seq, d_k)
            return t.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        q = split_heads(self.W_q(x))
        k = split_heads(self.W_k(x))
        v = split_heads(self.W_v(x))

        # Scaled dot-product attention.
        scores = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(self.d_k)
        attn = self.dropout(F.softmax(scores, dim=-1))

        # Merge heads: (batch, heads, seq, d_k) -> (batch, seq, d_model)
        context = torch.matmul(attn, v)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        output = self.W_o(context)
        return self.layer_norm(x + self.dropout(output))


class AttentionActorCritic(nn.Module):
    """Actor-critic network built on stacked (self-attention, feed-forward) layers.

    The state is projected to ``hidden_dim`` and passed through
    ``num_attention_layers`` attention + FFN pairs. A linear actor head yields
    action logits; a small MLP critic head yields a scalar state value.
    """

    def __init__(
        self,
        state_dim: int,
        num_actions: int,
        hidden_dim: int = 256,
        num_heads: int = 4,
        num_attention_layers: int = 2,
    ):
        super().__init__()
        self.state_dim = state_dim
        self.num_actions = num_actions

        # Input projection: state_dim -> hidden_dim.
        self.input_proj = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
        )

        # Stacked self-attention layers.
        self.attention_layers = nn.ModuleList([
            MultiHeadAttention(hidden_dim, num_heads)
            for _ in range(num_attention_layers)
        ])

        # One feed-forward block per attention layer; each ends in LayerNorm and
        # is added residually in `_encode`.
        self.ffn_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim * 2),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_dim * 2, hidden_dim),
                nn.Dropout(0.1),
                nn.LayerNorm(hidden_dim),
            )
            for _ in range(num_attention_layers)
        ])

        # Actor head: hidden features -> action logits.
        self.actor = nn.Linear(hidden_dim, num_actions)

        # Critic head: hidden features -> scalar value.
        self.critic = nn.Sequential(
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

        self._init_weights()

    def _init_weights(self):
        """Orthogonal init (gain sqrt(2), zero bias) for every Linear layer.

        The actor head is then re-initialised with gain 0.01 so the initial
        logits are close to zero, i.e. the initial policy is near-uniform.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        nn.init.orthogonal_(self.actor.weight, gain=0.01)

    def _encode(self, state):
        """Encode a ``(batch, state_dim)`` batch into ``(batch, hidden_dim)`` features."""
        # Each sample becomes a single token, so attention runs over seq_len == 1.
        x = self.input_proj(state).unsqueeze(1)
        for attn_layer, ffn_layer in zip(self.attention_layers, self.ffn_layers):
            x = attn_layer(x)
            x = x + ffn_layer(x)
        return x.squeeze(1)

    def forward(self, state):
        """Return ``(logits, value)`` with shapes ``(batch, num_actions)`` and ``(batch, 1)``."""
        x = self._encode(state)
        return self.actor(x), self.critic(x)

    def get_value(self, state):
        """Return only the critic's value estimate, shape ``(batch, 1)``."""
        return self.critic(self._encode(state))


class APPOAgent:
    """PPO agent (clipped surrogate objective + GAE) using `AttentionActorCritic`."""

    def __init__(
        self,
        state_dim: int,
        num_actions: int,
        hidden_dim: int = 256,
        num_heads: int = 4,
        num_attention_layers: int = 2,
        learning_rate: float = 3e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_epsilon: float = 0.2,
        value_coef: float = 0.5,
        entropy_coef: float = 0.02,
        max_grad_norm: float = 0.5,
        ppo_epochs: int = 10,
        minibatch_size: int = 64,
        device: str = "cuda",
        lr_schedule: str = "cosine",
        total_episodes: int = 300,
    ):
        """
        Args:
            state_dim: Size of the flat state vector.
            num_actions: Number of discrete actions.
            hidden_dim, num_heads, num_attention_layers: Network size
                (``hidden_dim`` must be divisible by ``num_heads``).
            learning_rate: Adam learning rate.
            gamma, gae_lambda: Discount and GAE smoothing factors.
            clip_epsilon: PPO ratio clip range.
            value_coef, entropy_coef: Loss weights for value loss and entropy bonus.
            max_grad_norm: Gradient-norm clipping threshold.
            ppo_epochs, minibatch_size: Optimisation passes and minibatch size per update.
            device: Requested device; falls back to CPU when CUDA is unavailable.
            lr_schedule: ``"cosine"`` enables ``CosineAnnealingLR`` (stepped once per
                `update` call, ``T_max=total_episodes``, floor ``0.1 * learning_rate``);
                any other value disables scheduling.
            total_episodes: ``T_max`` for the cosine schedule.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        self.ppo_epochs = ppo_epochs
        self.minibatch_size = minibatch_size

        self.policy = AttentionActorCritic(
            state_dim, num_actions, hidden_dim, num_heads, num_attention_layers
        ).to(self.device)

        self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate, eps=1e-5)

        self.lr_schedule = lr_schedule
        if lr_schedule == "cosine":
            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=total_episodes, eta_min=learning_rate * 0.1
            )
        else:
            self.scheduler = None

        self.reset_buffers()

    def reset_buffers(self):
        """Clear all per-rollout transition buffers."""
        self.states = []
        self.actions = []
        self.rewards = []
        self.values = []
        self.log_probs = []
        self.dones = []

    def select_action(
        self, state: np.ndarray, deterministic: bool = False
    ) -> Tuple[int, float, float]:
        """Choose an action for a single (unbatched) state.

        The forward pass runs in eval mode (dropout disabled) so that acting is
        not perturbed by dropout noise and ``deterministic=True`` is repeatable;
        the policy's previous train/eval mode is restored afterwards.

        Args:
            state: One state vector of length ``state_dim``.
            deterministic: Take the argmax action instead of sampling.

        Returns:
            ``(action, log_prob, value)`` where ``log_prob`` is the log-softmax
            probability of ``action`` (the same quantity `update` recomputes).
        """
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)

        was_training = self.policy.training
        self.policy.eval()
        try:
            with torch.no_grad():
                logits, value = self.policy(state_tensor)
        finally:
            self.policy.train(was_training)

        if deterministic:
            action = int(torch.argmax(logits, dim=-1).item())
        else:
            action = int(torch.distributions.Categorical(logits=logits).sample().item())

        log_prob = torch.log_softmax(logits, dim=-1)[0, action].item()
        return action, log_prob, value.item()

    def store_transition(self, state, action, reward, value, log_prob, done):
        """Append one transition to the rollout buffers."""
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.values.append(value)
        self.log_probs.append(log_prob)
        self.dones.append(done)

    def compute_gae(self, next_value: float) -> Tuple[np.ndarray, np.ndarray]:
        """Compute GAE advantages and returns over the buffered rollout.

        Args:
            next_value: Bootstrap value for the state after the last stored
                transition (ignored if that transition is marked done).

        Returns:
            ``(advantages, returns)`` as float32 arrays, where
            ``returns = advantages + values``.
        """
        n = len(self.rewards)
        advantages = np.zeros(n, dtype=np.float32)
        gae = 0.0
        # Walk backwards, filling a preallocated array (avoids O(n^2) list.insert(0, ...)).
        for t in reversed(range(n)):
            next_val = next_value if t == n - 1 else self.values[t + 1]
            not_done = 1 - self.dones[t]
            delta = self.rewards[t] + self.gamma * next_val * not_done - self.values[t]
            gae = delta + self.gamma * self.gae_lambda * not_done * gae
            advantages[t] = gae

        returns = advantages + np.array(self.values, dtype=np.float32)
        return advantages, returns

    def update(self, next_value: float) -> Dict[str, float]:
        """Run PPO optimisation on the buffered rollout, then clear the buffers.

        NOTE: old log-probs come from `select_action` (eval mode, no dropout);
        here the policy runs in its current mode, so if it is in training mode
        (the nn.Module default) dropout adds noise to the importance ratio.

        Args:
            next_value: Bootstrap value passed to `compute_gae`.

        Returns:
            Mean ``loss``, ``policy_loss``, ``value_loss`` and ``entropy`` over all
            minibatch steps, plus the current ``lr``; ``{}`` if the buffer is empty.
        """
        if len(self.states) == 0:
            return {}

        advantages, returns = self.compute_gae(next_value)
        # Normalize advantages over the whole rollout.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        states = torch.FloatTensor(np.array(self.states)).to(self.device)
        actions = torch.LongTensor(self.actions).to(self.device)
        old_log_probs = torch.FloatTensor(self.log_probs).to(self.device)
        advantages_t = torch.FloatTensor(advantages).to(self.device)
        returns_t = torch.FloatTensor(returns).to(self.device)

        total_loss = 0.0
        total_policy_loss = 0.0
        total_value_loss = 0.0
        total_entropy = 0.0
        update_count = 0

        dataset_size = len(self.states)

        for _ in range(self.ppo_epochs):
            indices = np.random.permutation(dataset_size)

            for start_idx in range(0, dataset_size, self.minibatch_size):
                batch_idx = indices[start_idx:start_idx + self.minibatch_size]

                batch_states = states[batch_idx]
                batch_actions = actions[batch_idx]
                batch_old_lp = old_log_probs[batch_idx]
                batch_adv = advantages_t[batch_idx]
                batch_ret = returns_t[batch_idx]

                logits, values = self.policy(batch_states)
                # Distribution from logits (log-softmax internally) so log-probs
                # match those recorded by `select_action`.
                dist = torch.distributions.Categorical(logits=logits)
                new_log_probs = dist.log_prob(batch_actions)
                entropy = dist.entropy().mean()

                # Clipped surrogate objective.
                ratio = torch.exp(new_log_probs - batch_old_lp)
                surr1 = ratio * batch_adv
                surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * batch_adv
                policy_loss = -torch.min(surr1, surr2).mean()

                # squeeze(-1) rather than squeeze(): a size-1 final minibatch must
                # stay 1-D to match batch_ret instead of broadcasting from a scalar.
                value_loss = F.mse_loss(values.squeeze(-1), batch_ret)

                loss = policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.optimizer.step()

                total_loss += loss.item()
                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                total_entropy += entropy.item()
                update_count += 1

        # The scheduler advances once per `update` call.
        if self.scheduler is not None:
            self.scheduler.step()

        self.reset_buffers()

        denom = max(update_count, 1)
        return {
            'loss': total_loss / denom,
            'policy_loss': total_policy_loss / denom,
            'value_loss': total_value_loss / denom,
            'entropy': total_entropy / denom,
            'lr': self.optimizer.param_groups[0]['lr'],
        }

    def save(self, path: str):
        """Save policy, optimizer and (if enabled) LR-scheduler state to ``path``."""
        torch.save({
            'policy_state_dict': self.policy.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': (
                self.scheduler.state_dict() if self.scheduler is not None else None
            ),
        }, path)

    def load(self, path: str):
        """Restore state written by `save`.

        Checkpoints without scheduler state (older format) are still accepted;
        the scheduler is then left as-is.
        """
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.policy.load_state_dict(checkpoint['policy_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler_state = checkpoint.get('scheduler_state_dict')
        if self.scheduler is not None and scheduler_state is not None:
            self.scheduler.load_state_dict(scheduler_state)