"""
APPO for SUMO - MultiDiscrete action-space version.

Contains a small attention-based actor-critic network with one categorical
head per action dimension, and an agent that collects rollouts and trains the
network with a clipped-surrogate (PPO-style) objective and GAE advantages.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from typing import Tuple, List, Dict


class SpatialAttentionMultiDiscrete(nn.Module):
    """Multi-head self-attention over the zone axis with residual + LayerNorm.

    Args:
        d_model: Feature size of every token; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, d_model: int, num_heads: int = 4):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (batch, seq_len, d_model).

        Returns a tensor of the same shape: ``LayerNorm(x + dropout(attn(x)))``.
        """
        batch_size, seq_len, _ = x.size()

        # Project, then split the feature axis into heads:
        # (batch, seq, d_model) -> (batch, heads, seq, d_k).
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.d_k)
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # Merge the heads back into a single feature axis.
        context = torch.matmul(attn, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.W_o(context)

        return self.layer_norm(x + self.dropout(output))


class MultiDiscreteActorCritic(nn.Module):
    """MultiDiscrete actor-critic with spatial attention.

    The flat state is encoded into ``len(action_dims)`` zone tokens of size
    ``hidden_dim``, refined by attention + feed-forward layers, and then fed
    to one linear actor head per zone and to a shared critic.

    Args:
        state_dim: Size of the flat state vector.
        action_dims: Number of discrete choices for each zone; its length
            defines the number of zones.
        hidden_dim: Per-zone feature size (must be divisible by
            ``num_heads``).
        num_heads: Attention heads per attention layer.
        num_layers: Number of attention + feed-forward layers.
    """

    def __init__(
        self,
        state_dim: int,
        action_dims: List[int],
        hidden_dim: int = 128,
        num_heads: int = 4,
        num_layers: int = 2,
    ):
        super().__init__()
        self.state_dim = state_dim
        self.action_dims = action_dims
        self.num_zones = len(action_dims)

        # State encoder: one hidden_dim-sized chunk per zone.
        self.state_encoder = nn.Sequential(
            nn.Linear(state_dim, hidden_dim * self.num_zones),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim * self.num_zones)
        )

        # Learned positional encoding, one vector per zone.
        self.pos_encoding = nn.Parameter(torch.randn(1, self.num_zones, hidden_dim))

        # Attention layers.
        self.attention_layers = nn.ModuleList([
            SpatialAttentionMultiDiscrete(hidden_dim, num_heads) for _ in range(num_layers)
        ])

        self.ffn_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim * 2),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_dim * 2, hidden_dim),
                nn.LayerNorm(hidden_dim)
            ) for _ in range(num_layers)
        ])

        # Actor heads (one independent head per zone).
        self.actor_heads = nn.ModuleList([
            nn.Linear(hidden_dim, adim) for adim in action_dims
        ])

        # Critic over the concatenated zone features.
        self.critic = nn.Sequential(
            nn.Linear(hidden_dim * self.num_zones, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

        self._init_weights()

    def _init_weights(self):
        """Orthogonally initialise every Linear layer and zero its bias.

        Actor heads are re-initialised afterwards with a small gain (0.01).
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        for head in self.actor_heads:
            nn.init.orthogonal_(head.weight, gain=0.01)

    def forward(self, state):
        """Compute per-zone action logits and the state value.

        Args:
            state: Tensor of shape (batch, state_dim).

        Returns:
            ``(logits_list, value)`` where ``logits_list[i]`` has shape
            (batch, action_dims[i]) and ``value`` has shape (batch, 1).
        """
        batch_size = state.size(0)

        # Encode into per-zone representations.
        x = self.state_encoder(state)
        x = x.view(batch_size, self.num_zones, -1)
        x = x + self.pos_encoding

        # Attention followed by a residual feed-forward block.
        for attn, ffn in zip(self.attention_layers, self.ffn_layers):
            x = attn(x)
            x = x + ffn(x)

        # Actor: each zone produces its own logits.
        logits_list = [head(x[:, i, :]) for i, head in enumerate(self.actor_heads)]

        # Critic: uses all zone features concatenated.
        global_feat = x.view(batch_size, -1)
        value = self.critic(global_feat)

        return logits_list, value

    def get_value(self, state):
        """Return only the critic output for ``state``."""
        _, value = self.forward(state)
        return value


class APPOAgent:
    """APPO agent for a SUMO MultiDiscrete action space.

    Holds the policy network, its Adam optimizer and a cosine learning-rate
    schedule, plus a simple list-based rollout buffer. ``update`` runs several
    epochs of clipped-surrogate minibatch updates and then clears the buffer.

    Args:
        state_dim: Size of the flat state vector.
        action_dims: Number of discrete choices per action dimension.
        hidden_dim, num_heads, num_layers: Forwarded to
            ``MultiDiscreteActorCritic``.
        learning_rate: Initial Adam learning rate.
        gamma: Discount factor.
        gae_lambda: GAE lambda.
        clip_epsilon: Clipping range for the probability ratio.
        value_coef: Weight of the value loss.
        entropy_coef: Weight of the entropy bonus.
        max_grad_norm: Gradient-norm clipping threshold.
        ppo_epochs: Passes over the buffer per ``update`` call.
        minibatch_size: Minibatch size inside each pass.
        device: Requested device; CPU is used when CUDA is unavailable.
        total_episodes: ``T_max`` of the cosine schedule (the scheduler is
            stepped once per ``update`` call).
    """

    def __init__(
        self,
        state_dim: int,
        action_dims: List[int],
        hidden_dim: int = 128,
        num_heads: int = 4,
        num_layers: int = 2,
        learning_rate: float = 3e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_epsilon: float = 0.2,
        value_coef: float = 0.5,
        entropy_coef: float = 0.02,
        max_grad_norm: float = 0.5,
        ppo_epochs: int = 10,
        minibatch_size: int = 64,
        device: str = "cuda",
        total_episodes: int = 300,
    ):
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        self.ppo_epochs = ppo_epochs
        self.minibatch_size = minibatch_size
        self.action_dims = action_dims

        self.policy = MultiDiscreteActorCritic(
            state_dim, action_dims, hidden_dim, num_heads, num_layers
        ).to(self.device)

        self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate, eps=1e-5)
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=total_episodes, eta_min=learning_rate * 0.1
        )

        self.reset_buffers()

    def reset_buffers(self):
        """Clear the rollout buffer."""
        self.states = []
        self.actions = []
        self.rewards = []
        self.values = []
        self.log_probs = []
        self.dones = []

    def select_action(self, state: np.ndarray, deterministic: bool = False) -> Tuple[np.ndarray, float, float]:
        """Choose one action per action dimension for a single state.

        Args:
            state: Flat state vector of length ``state_dim``.
            deterministic: If True, take the arg-max action of every head
                instead of sampling.

        Returns:
            ``(actions, log_prob, value)``: the chosen action indices, the
            summed log-probability of those actions over all heads, and the
            critic's value estimate.
        """
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)

        # The network contains dropout, so a "deterministic" query must run in
        # eval mode; otherwise repeated calls give different logits/values.
        # The previous mode is restored afterwards so training is unaffected.
        was_training = self.policy.training
        if deterministic:
            self.policy.eval()
        try:
            with torch.no_grad():
                logits_list, value = self.policy(state_tensor)
        finally:
            if deterministic:
                self.policy.train(was_training)

        actions = []
        log_prob_total = 0.0

        for logits in logits_list:
            # Same distribution construction as in `update`, so the stored
            # log-probabilities are directly comparable with the new ones.
            dist = torch.distributions.Categorical(logits=logits)
            if deterministic:
                action = torch.argmax(logits, dim=-1)
            else:
                action = dist.sample()
            log_prob_total += dist.log_prob(action).item()
            actions.append(action.item())

        return np.array(actions), log_prob_total, value.item()

    def store_transition(self, state, action, reward, value, log_prob, done):
        """Append one transition to the rollout buffer."""
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.values.append(value)
        self.log_probs.append(log_prob)
        self.dones.append(done)

    def compute_gae(self, next_value: float) -> Tuple[np.ndarray, np.ndarray]:
        """Compute GAE advantages and returns for the buffered rollout.

        Args:
            next_value: Value estimate of the state following the last
                buffered transition (used for bootstrapping).

        Returns:
            ``(advantages, returns)`` as float32 arrays with one entry per
            buffered transition; ``returns = advantages + values``.
        """
        num_steps = len(self.rewards)
        advantages = np.zeros(num_steps, dtype=np.float32)
        gae = 0.0

        # Walk backwards so each step can reuse the accumulated advantage of
        # the step after it; writing by index avoids O(n^2) list inserts.
        for t in reversed(range(num_steps)):
            next_val = next_value if t == num_steps - 1 else self.values[t + 1]
            not_done = 1.0 - float(self.dones[t])

            delta = self.rewards[t] + self.gamma * next_val * not_done - self.values[t]
            gae = delta + self.gamma * self.gae_lambda * not_done * gae
            advantages[t] = gae

        returns = advantages + np.array(self.values, dtype=np.float32)
        return advantages, returns

    def update(self, next_value: float) -> Dict[str, float]:
        """Train the policy on the buffered rollout, then clear the buffer.

        Args:
            next_value: Bootstrap value passed to ``compute_gae``.

        Returns:
            Mean ``policy_loss``, ``value_loss`` and ``entropy`` over all
            minibatch updates, or an empty dict when the buffer is empty.
        """
        if not self.states:
            return {}

        advantages, returns = self.compute_gae(next_value)
        # Normalise advantages over the whole rollout.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        states = torch.FloatTensor(np.array(self.states)).to(self.device)
        actions = torch.LongTensor(np.array(self.actions)).to(self.device)
        old_log_probs = torch.FloatTensor(self.log_probs).to(self.device)
        advantages_t = torch.FloatTensor(advantages).to(self.device)
        returns_t = torch.FloatTensor(returns).to(self.device)

        total_policy_loss = 0
        total_value_loss = 0
        total_entropy = 0
        update_count = 0

        dataset_size = len(self.states)
        for _ in range(self.ppo_epochs):
            indices = np.random.permutation(dataset_size)

            for start_idx in range(0, dataset_size, self.minibatch_size):
                end_idx = min(start_idx + self.minibatch_size, dataset_size)
                batch_idx = indices[start_idx:end_idx]

                batch_states = states[batch_idx]
                batch_actions = actions[batch_idx]
                batch_old_lp = old_log_probs[batch_idx]
                batch_adv = advantages_t[batch_idx]
                batch_ret = returns_t[batch_idx]

                logits_list, values = self.policy(batch_states)

                new_log_probs = torch.zeros(len(batch_idx), device=self.device)
                entropy = torch.zeros(len(batch_idx), device=self.device)

                # Joint log-prob / entropy are sums over the independent heads.
                for i, logits in enumerate(logits_list):
                    dist = torch.distributions.Categorical(logits=logits)
                    new_log_probs += dist.log_prob(batch_actions[:, i])
                    entropy += dist.entropy()

                entropy_mean = entropy.mean()

                # Clipped surrogate objective.
                ratio = torch.exp(new_log_probs - batch_old_lp)
                surr1 = ratio * batch_adv
                surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * batch_adv
                policy_loss = -torch.min(surr1, surr2).mean()

                # squeeze(-1) keeps the batch axis even for a size-1 minibatch;
                # a bare squeeze() would yield a 0-d tensor and make mse_loss
                # broadcast against the (1,) target.
                value_loss = F.mse_loss(values.squeeze(-1), batch_ret)

                loss = policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy_mean

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.optimizer.step()

                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                total_entropy += entropy_mean.item()
                update_count += 1

        self.scheduler.step()
        self.reset_buffers()

        return {
            'policy_loss': total_policy_loss / max(update_count, 1),
            'value_loss': total_value_loss / max(update_count, 1),
            'entropy': total_entropy / max(update_count, 1),
        }

    def save(self, path: str):
        """Write the policy and optimizer state dicts to ``path``."""
        torch.save({
            'policy_state_dict': self.policy.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)

    def load(self, path: str):
        """Restore the policy and optimizer state dicts from ``path``."""
        # NOTE(security): weights_only=False unpickles arbitrary objects, so
        # only load checkpoints from trusted sources. The checkpoint written
        # by `save` holds only state dicts; consider weights_only=True.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.policy.load_state_dict(checkpoint['policy_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])