ctm-dqn/appo_agent.py

329 lines
11 KiB
Python

"""
Attention-based PPO Agent (APPO)
使用多头自注意力机制捕捉zone之间的空间依赖关系
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from typing import Tuple, List, Dict
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention followed by a residual connection and LayerNorm.

    The input is treated as ``(batch, seq_len, d_model)``: dimension 0 is taken
    as the batch size and the last dimension is split into ``num_heads`` heads
    of width ``d_model // num_heads``. The output has the same shape as the
    input.

    Args:
        d_model: Feature width of the input and output; must be divisible by
            ``num_heads`` (enforced with an ``assert``).
        num_heads: Number of attention heads.
        dropout: Dropout probability, applied both to the attention weights
            and to the output projection before the residual add.
    """

    def __init__(self, d_model: int, num_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        head_width, leftover = divmod(d_model, num_heads)
        assert leftover == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = head_width

        # Query / key / value / output projections. The registration order is
        # kept stable because it determines state_dict and parameter order.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply self-attention to ``x`` and return ``LayerNorm(x + attn(x))``."""
        n_batch = x.size(0)

        def split_heads(projected):
            # (batch, seq, d_model) -> (batch, heads, seq, d_k)
            return projected.view(
                n_batch, -1, self.num_heads, self.d_k
            ).transpose(1, 2)

        queries = split_heads(self.W_q(x))
        keys = split_heads(self.W_k(x))
        values = split_heads(self.W_v(x))

        # Scaled dot-product attention weights, normalised over the key axis.
        scale = np.sqrt(self.d_k)
        similarity = torch.matmul(queries, keys.transpose(-2, -1)) / scale
        weights = self.dropout(F.softmax(similarity, dim=-1))

        # Weighted sum of values, then merge the heads back into d_model.
        per_head = torch.matmul(weights, values)
        merged = per_head.transpose(1, 2).contiguous().view(
            n_batch, -1, self.d_model
        )

        projected = self.W_o(merged)
        return self.layer_norm(x + self.dropout(projected))
class AttentionActorCritic(nn.Module):
    """Actor-critic network with a shared attention-based trunk.

    The trunk projects the state to ``hidden_dim``, runs it through
    ``num_attention_layers`` pairs of (self-attention, feed-forward) blocks,
    and feeds the result to a linear actor head (action logits) and a small
    MLP critic head (scalar state value).

    NOTE(review): the projected state is handed to the attention layers as a
    single token (``unsqueeze(1)`` gives sequence length 1), so the attention
    softmax is taken over one element. As written, no cross-zone mixing
    happens inside the attention layers; confirm whether the state was meant
    to be split into per-zone tokens.

    Args:
        state_dim: Size of the last dimension of the input state.
        num_actions: Number of discrete actions (width of the actor logits).
        hidden_dim: Width of the shared trunk.
        num_heads: Number of heads in each attention layer.
        num_attention_layers: Number of (attention, feed-forward) blocks.
    """

    def __init__(
        self,
        state_dim: int,
        num_actions: int,
        hidden_dim: int = 256,
        num_heads: int = 4,
        num_attention_layers: int = 2,
    ):
        super().__init__()
        self.state_dim = state_dim
        self.num_actions = num_actions
        # Input projection
        self.input_proj = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU()
        )
        # Stacked attention layers
        self.attention_layers = nn.ModuleList([
            MultiHeadAttention(hidden_dim, num_heads)
            for _ in range(num_attention_layers)
        ])
        # Position-wise feed-forward layers, one per attention layer
        self.ffn_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim * 2),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_dim * 2, hidden_dim),
                nn.Dropout(0.1),
                nn.LayerNorm(hidden_dim)
            ) for _ in range(num_attention_layers)
        ])
        # Actor head
        self.actor = nn.Linear(hidden_dim, num_actions)
        # Critic head
        self.critic = nn.Sequential(
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )
        self._init_weights()

    def _init_weights(self):
        """Orthogonally initialise every linear layer and zero its bias.

        The actor head is re-initialised afterwards with a much smaller gain
        (0.01), which keeps the initial logits close to zero.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        nn.init.orthogonal_(self.actor.weight, gain=0.01)

    def _encode(self, state):
        """Run the shared trunk and return the per-sample feature vector.

        Shared by ``forward`` and ``get_value`` so both heads always see the
        same features. ``state`` is expected to be ``(batch, state_dim)``; the
        result is ``(batch, hidden_dim)``.
        """
        x = self.input_proj(state)
        # Add a length-1 sequence axis for the attention layers.
        x = x.unsqueeze(1)
        for attn_layer, ffn_layer in zip(self.attention_layers, self.ffn_layers):
            x = attn_layer(x)
            # Residual around the feed-forward block (which ends in LayerNorm).
            x = x + ffn_layer(x)
        return x.squeeze(1)

    def forward(self, state):
        """Return ``(logits, value)`` for a batch of states.

        ``logits`` has ``num_actions`` entries per sample and ``value`` has a
        trailing dimension of size 1.
        """
        features = self._encode(state)
        return self.actor(features), self.critic(features)

    def get_value(self, state):
        """Return only the critic's value estimate for a batch of states."""
        return self.critic(self._encode(state))
class APPOAgent:
    """Attention-based PPO agent.

    Wraps an ``AttentionActorCritic`` policy with an on-policy rollout buffer,
    GAE advantage estimation and the clipped-surrogate PPO update.

    Args:
        state_dim: Size of the state vector passed to the policy.
        num_actions: Number of discrete actions.
        hidden_dim: Trunk width of the policy network.
        num_heads: Attention heads per attention layer.
        num_attention_layers: Number of attention layers in the policy.
        learning_rate: Initial Adam learning rate.
        gamma: Discount factor.
        gae_lambda: GAE lambda.
        clip_epsilon: PPO ratio clipping range.
        value_coef: Weight of the value loss in the total loss.
        entropy_coef: Weight of the entropy bonus in the total loss.
        max_grad_norm: Gradient-norm clipping threshold.
        ppo_epochs: Passes over the rollout per ``update`` call.
        minibatch_size: Minibatch size within each pass.
        device: Requested device; CPU is used whenever CUDA is unavailable.
        lr_schedule: ``"cosine"`` enables cosine annealing; any other value
            disables scheduling.
        total_episodes: ``T_max`` of the cosine schedule.
    """

    def __init__(
        self,
        state_dim: int,
        num_actions: int,
        hidden_dim: int = 256,
        num_heads: int = 4,
        num_attention_layers: int = 2,
        learning_rate: float = 3e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_epsilon: float = 0.2,
        value_coef: float = 0.5,
        entropy_coef: float = 0.02,
        max_grad_norm: float = 0.5,
        ppo_epochs: int = 10,
        minibatch_size: int = 64,
        device: str = "cuda",
        lr_schedule: str = "cosine",
        total_episodes: int = 300,
    ):
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        self.ppo_epochs = ppo_epochs
        self.minibatch_size = minibatch_size
        self.policy = AttentionActorCritic(
            state_dim, num_actions, hidden_dim, num_heads, num_attention_layers
        ).to(self.device)
        self.optimizer = optim.Adam(
            self.policy.parameters(), lr=learning_rate, eps=1e-5
        )
        self.lr_schedule = lr_schedule
        if lr_schedule == "cosine":
            # Anneals from learning_rate down to 10% of it over T_max steps.
            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=total_episodes, eta_min=learning_rate * 0.1
            )
        else:
            self.scheduler = None
        self.reset_buffers()

    def reset_buffers(self):
        """Clear the rollout buffer."""
        self.states = []
        self.actions = []
        self.rewards = []
        self.values = []
        self.log_probs = []
        self.dones = []

    def select_action(
        self, state: np.ndarray, deterministic: bool = False
    ) -> Tuple[int, float, float]:
        """Choose an action for a single state.

        Args:
            state: One unbatched state; a batch axis is added here.
            deterministic: If true, take the arg-max action with the policy in
                eval mode (dropout disabled) so the result is reproducible;
                otherwise sample from the categorical policy.

        Returns:
            ``(action, log_prob, value)`` as plain Python numbers.
        """
        state_tensor = torch.as_tensor(
            np.asarray(state), dtype=torch.float32, device=self.device
        ).unsqueeze(0)

        # Dropout would make "deterministic" selection non-reproducible, so
        # switch to eval mode for that path and restore the old mode after.
        was_training = self.policy.training
        if deterministic:
            self.policy.eval()
        try:
            with torch.no_grad():
                logits, value = self.policy(state_tensor)
                # Built from logits so the stored log-prob uses exactly the
                # same definition as the one recomputed in update().
                dist = torch.distributions.Categorical(logits=logits)
                if deterministic:
                    action_t = torch.argmax(logits, dim=-1)
                else:
                    action_t = dist.sample()
                log_prob = dist.log_prob(action_t)
        finally:
            if deterministic and was_training:
                self.policy.train()
        return action_t.item(), log_prob.item(), value.item()

    def store_transition(self, state, action, reward, value, log_prob, done):
        """Append one transition to the rollout buffer."""
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.values.append(value)
        self.log_probs.append(log_prob)
        self.dones.append(done)

    def compute_gae(self, next_value: float) -> Tuple[np.ndarray, np.ndarray]:
        """Compute GAE advantages and returns for the buffered rollout.

        Args:
            next_value: Value estimate of the state following the last stored
                transition, used to bootstrap the final step.

        Returns:
            ``(advantages, returns)`` as float32 arrays with one entry per
            stored transition, where ``returns = advantages + values``.
        """
        num_steps = len(self.rewards)
        # Preallocate and fill back-to-front instead of list.insert(0, ...),
        # which is quadratic in the rollout length.
        advantages = np.zeros(num_steps, dtype=np.float32)
        gae = 0.0
        next_val = next_value
        for t in reversed(range(num_steps)):
            # A done flag stops both bootstrapping and advantage propagation.
            not_done = 1.0 - float(self.dones[t])
            delta = (
                self.rewards[t]
                + self.gamma * next_val * not_done
                - self.values[t]
            )
            gae = delta + self.gamma * self.gae_lambda * not_done * gae
            advantages[t] = gae
            next_val = self.values[t]
        returns = advantages + np.asarray(self.values, dtype=np.float32)
        return advantages, returns

    def update(self, next_value: float) -> Dict[str, float]:
        """Run the PPO update on the buffered rollout and clear the buffer.

        Args:
            next_value: Bootstrap value for the state after the last stored
                transition.

        Returns:
            Mean ``loss``, ``policy_loss``, ``value_loss`` and ``entropy`` over
            all minibatch steps plus the current ``lr``; an empty dict if the
            buffer is empty.
        """
        if not self.states:
            return {}
        advantages, returns = self.compute_gae(next_value)
        # Normalise advantages over the whole rollout.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        states = torch.as_tensor(
            np.array(self.states), dtype=torch.float32, device=self.device
        )
        actions = torch.as_tensor(
            self.actions, dtype=torch.long, device=self.device
        )
        old_log_probs = torch.as_tensor(
            self.log_probs, dtype=torch.float32, device=self.device
        )
        advantages_t = torch.as_tensor(
            advantages, dtype=torch.float32, device=self.device
        )
        returns_t = torch.as_tensor(
            returns, dtype=torch.float32, device=self.device
        )

        total_loss = 0.0
        total_policy_loss = 0.0
        total_value_loss = 0.0
        total_entropy = 0.0
        update_count = 0
        dataset_size = len(self.states)

        for _ in range(self.ppo_epochs):
            indices = np.random.permutation(dataset_size)
            for start_idx in range(0, dataset_size, self.minibatch_size):
                batch_idx = indices[start_idx:start_idx + self.minibatch_size]
                batch_states = states[batch_idx]
                batch_actions = actions[batch_idx]
                batch_old_lp = old_log_probs[batch_idx]
                batch_adv = advantages_t[batch_idx]
                batch_ret = returns_t[batch_idx]

                logits, values = self.policy(batch_states)
                dist = torch.distributions.Categorical(logits=logits)
                new_log_probs = dist.log_prob(batch_actions)
                entropy = dist.entropy().mean()

                # Clipped surrogate objective.
                ratio = torch.exp(new_log_probs - batch_old_lp)
                surr1 = ratio * batch_adv
                surr2 = torch.clamp(
                    ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon
                ) * batch_adv
                policy_loss = -torch.min(surr1, surr2).mean()

                # squeeze(-1) rather than squeeze(): a size-1 minibatch must
                # stay 1-D so it matches batch_ret's shape.
                value_loss = F.mse_loss(values.squeeze(-1), batch_ret)

                loss = (
                    policy_loss
                    + self.value_coef * value_loss
                    - self.entropy_coef * entropy
                )
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(
                    self.policy.parameters(), self.max_grad_norm
                )
                self.optimizer.step()

                total_loss += loss.item()
                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                total_entropy += entropy.item()
                update_count += 1

        # The schedule advances once per update() call, not per minibatch.
        if self.scheduler is not None:
            self.scheduler.step()
        self.reset_buffers()

        denom = max(update_count, 1)
        return {
            'loss': total_loss / denom,
            'policy_loss': total_policy_loss / denom,
            'value_loss': total_value_loss / denom,
            'entropy': total_entropy / denom,
            'lr': self.optimizer.param_groups[0]['lr'],
        }

    def save(self, path: str):
        """Write policy, optimizer and (if any) scheduler state to ``path``."""
        checkpoint = {
            'policy_state_dict': self.policy.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }
        if self.scheduler is not None:
            # Saved so a resumed run continues the LR schedule where it left
            # off instead of restarting it.
            checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
        torch.save(checkpoint, path)

    def load(self, path: str):
        """Restore state written by ``save``.

        Checkpoints without scheduler state (older format) still load; the
        scheduler is then left as constructed.
        """
        # SECURITY: weights_only=False unpickles arbitrary objects. Only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.policy.load_state_dict(checkpoint['policy_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler_state = checkpoint.get('scheduler_state_dict')
        if self.scheduler is not None and scheduler_state is not None:
            self.scheduler.load_state_dict(scheduler_state)