ctm-dqn/appo_v2_agent.py

364 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
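APPO v2 - improved spatially-aware attention mechanism.
Decomposes the state into zone-level representations and applies attention across zones so that they can interact.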
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from typing import Tuple, Dict
class SpatialAttention(nn.Module):
    """Multi-head self-attention over zone tokens with a post-norm residual.

    Args:
        d_model: Token embedding size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights and to
            the projected output before the residual add (default 0.1).

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model: int, num_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        # Explicit check instead of ``assert`` so it is not stripped by ``python -O``.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape (batch, seq, d_model) into (batch, heads, seq, d_k)."""
        return t.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (batch, seq_len, d_model).

        Returns a tensor of the same shape: LayerNorm(x + Dropout(W_o(attn))).
        """
        batch_size, seq_len, _ = x.size()
        q = self._split_heads(self.W_q(x), batch_size, seq_len)
        k = self._split_heads(self.W_k(x), batch_size, seq_len)
        v = self._split_heads(self.W_v(x), batch_size, seq_len)
        # Scaled dot-product attention scores: (batch, heads, seq, seq).
        scores = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(self.d_k)
        attn = self.dropout(F.softmax(scores, dim=-1))
        context = torch.matmul(attn, v)
        # Merge heads back into (batch, seq, d_model).
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.layer_norm(x + self.dropout(self.W_o(context)))


class SpatialActorCritic(nn.Module):
    """Spatially-aware actor-critic network.

    The flat state vector is projected into ``num_zones`` tokens of width
    ``hidden_dim``, a learned per-zone positional embedding is added, and the
    tokens pass through ``num_layers`` blocks of self-attention + feed-forward.
    The zone tokens are then mean-pooled into a single ``hidden_dim`` feature
    vector that feeds both the policy (actor) and value (critic) heads.

    Args:
        state_dim: Size of the flat input state.
        num_actions: Number of discrete actions (actor output size).
        num_zones: Number of zone tokens the state is projected into.
        hidden_dim: Width of each zone token; must be divisible by ``num_heads``.
        num_heads: Attention heads per layer.
        num_layers: Number of attention + feed-forward blocks.
    """

    def __init__(
        self,
        state_dim: int,
        num_actions: int,
        num_zones: int = 3,
        hidden_dim: int = 128,
        num_heads: int = 4,
        num_layers: int = 2,
    ):
        super().__init__()
        self.state_dim = state_dim
        self.num_actions = num_actions
        self.num_zones = num_zones
        self.zone_dim = hidden_dim

        # State encoder: project the global state to num_zones * hidden_dim
        # features, which forward() reshapes into zone tokens.
        self.state_encoder = nn.Sequential(
            nn.Linear(state_dim, hidden_dim * num_zones),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim * num_zones),
        )
        # Learned positional embedding, one vector per zone (small init scale).
        self.pos_encoding = nn.Parameter(torch.randn(1, num_zones, hidden_dim) * 0.02)

        # Spatial attention layers, each paired with a residual feed-forward block.
        self.attention_layers = nn.ModuleList(
            [SpatialAttention(hidden_dim, num_heads) for _ in range(num_layers)]
        )
        self.ffn_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim * 2),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_dim * 2, hidden_dim),
                nn.LayerNorm(hidden_dim),
            )
            for _ in range(num_layers)
        ])

        # Both heads consume the mean-pooled zone feature, which has hidden_dim
        # features: pooling over zones removes the num_zones factor, so the
        # heads' input width must be hidden_dim, not hidden_dim * num_zones.
        self.actor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_actions),
        )
        self.critic = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )
        self._init_weights()

    def _init_weights(self):
        """Orthogonal init (gain sqrt(2)) with zero bias for every Linear layer.

        The final actor layer is re-initialized with gain 0.01 so initial
        logits are near zero (near-uniform initial policy).
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        nn.init.orthogonal_(self.actor[-1].weight, gain=0.01)

    def forward(self, state):
        """Map a (batch, state_dim) state to ``(logits, value)``.

        Returns:
            logits: (batch, num_actions) unnormalized action scores.
            value: (batch, 1) state-value estimate.
        """
        batch_size = state.size(0)
        x = self.state_encoder(state)
        x = x.view(batch_size, self.num_zones, self.zone_dim)
        x = x + self.pos_encoding
        for attn, ffn in zip(self.attention_layers, self.ffn_layers):
            x = attn(x)
            x = x + ffn(x)
        # Mean-pool over the zone axis -> (batch, hidden_dim).
        global_feat = x.mean(dim=1)
        logits = self.actor(global_feat)
        value = self.critic(global_feat)
        return logits, value

    def get_value(self, state):
        """Return only the critic's value estimate, shape (batch, 1)."""
        _, value = self.forward(state)
        return value
class APPOv2Agent:
    """PPO-style agent (clipped surrogate objective + GAE) on SpatialActorCritic.

    Usage pattern supported by this class: call :meth:`select_action` each
    step, record the result with :meth:`store_transition`, then call
    :meth:`update` with the bootstrap value of the state following the last
    stored transition. ``update`` consumes and clears the rollout buffers.
    """

    def __init__(
        self,
        state_dim: int,
        num_actions: int,
        num_zones: int = 3,
        hidden_dim: int = 128,
        num_heads: int = 4,
        num_layers: int = 2,
        learning_rate: float = 3e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_epsilon: float = 0.2,
        value_coef: float = 0.5,
        entropy_coef: float = 0.02,
        max_grad_norm: float = 0.5,
        ppo_epochs: int = 10,
        minibatch_size: int = 64,
        device: str = "cuda",
        lr_schedule: str = "cosine",
        total_episodes: int = 300,
    ):
        # Falls back to CPU whenever CUDA is unavailable, regardless of ``device``.
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        self.ppo_epochs = ppo_epochs
        self.minibatch_size = minibatch_size
        self.policy = SpatialActorCritic(
            state_dim, num_actions, num_zones, hidden_dim, num_heads, num_layers
        ).to(self.device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate, eps=1e-5)
        # The scheduler is stepped once per update() call; T_max=total_episodes
        # presumably assumes one update per episode.
        if lr_schedule == "cosine":
            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=total_episodes, eta_min=learning_rate * 0.1
            )
        else:
            self.scheduler = None
        self.reset_buffers()

    def reset_buffers(self):
        """Clear all rollout buffers."""
        self.states = []
        self.actions = []
        self.rewards = []
        self.values = []
        self.log_probs = []
        self.dones = []

    def select_action(self, state: np.ndarray, deterministic: bool = False) -> Tuple[int, float, float]:
        """Choose an action for a single (unbatched) state.

        Args:
            state: One state vector; a batch axis is added internally.
            deterministic: If True take the argmax action, otherwise sample.

        Returns:
            ``(action, log_prob, value)`` as Python scalars.
        """
        # NOTE(review): the policy is never switched to eval() in this class, so
        # its dropout layers are active during action selection as well.
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits, value = self.policy(state_tensor)
            # Same logits-parameterised distribution as update(), so the stored
            # log-prob matches what update() recomputes for the same network.
            dist = torch.distributions.Categorical(logits=logits)
            if deterministic:
                action = torch.argmax(logits, dim=-1)
            else:
                action = dist.sample()
            log_prob = dist.log_prob(action).item()
        return action.item(), log_prob, value.item()

    def store_transition(self, state, action, reward, value, log_prob, done):
        """Append one transition to the rollout buffers."""
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.values.append(value)
        self.log_probs.append(log_prob)
        self.dones.append(done)

    def compute_gae(self, next_value: float) -> Tuple[np.ndarray, np.ndarray]:
        """Compute GAE(gamma, lambda) advantages and returns for the buffer.

        Args:
            next_value: Value used to bootstrap after the final stored
                transition (multiplied out if that transition is done).

        Returns:
            ``(advantages, returns)`` as float32 arrays of buffer length, with
            ``returns = advantages + values``.
        """
        length = len(self.rewards)
        # Preallocate and fill by index (O(N)) instead of head-inserting.
        advantages = np.zeros(length, dtype=np.float32)
        gae = 0.0
        for t in reversed(range(length)):
            next_val = next_value if t == length - 1 else self.values[t + 1]
            # A done flag at step t cuts both the bootstrap and the GAE trace.
            not_done = 1 - self.dones[t]
            delta = self.rewards[t] + self.gamma * next_val * not_done - self.values[t]
            gae = delta + self.gamma * self.gae_lambda * not_done * gae
            advantages[t] = gae
        returns = advantages + np.array(self.values, dtype=np.float32)
        return advantages, returns

    def update(self, next_value: float) -> Dict[str, float]:
        """Run PPO epochs over the stored rollout, then clear the buffers.

        Args:
            next_value: Bootstrap value for the state after the last stored
                transition (forwarded to :meth:`compute_gae`).

        Returns:
            Mean ``loss``, ``policy_loss``, ``value_loss`` and ``entropy`` over
            all minibatch steps, plus the optimizer's current ``lr``. Returns
            an empty dict (and does nothing) when the buffer is empty.
        """
        if not self.states:
            return {}

        advantages, returns = self.compute_gae(next_value)
        # Normalize advantages only; returns stay in reward units for the critic.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        states = torch.FloatTensor(np.array(self.states)).to(self.device)
        actions = torch.LongTensor(self.actions).to(self.device)
        old_log_probs = torch.FloatTensor(self.log_probs).to(self.device)
        advantages_t = torch.FloatTensor(advantages).to(self.device)
        returns_t = torch.FloatTensor(returns).to(self.device)

        total_loss = 0.0
        total_policy_loss = 0.0
        total_value_loss = 0.0
        total_entropy = 0.0
        update_count = 0
        dataset_size = len(self.states)

        for _ in range(self.ppo_epochs):
            indices = np.random.permutation(dataset_size)
            for start_idx in range(0, dataset_size, self.minibatch_size):
                # Slicing past the end is safe; the last minibatch may be smaller.
                batch_idx = indices[start_idx:start_idx + self.minibatch_size]
                batch_states = states[batch_idx]
                batch_actions = actions[batch_idx]
                batch_old_lp = old_log_probs[batch_idx]
                batch_adv = advantages_t[batch_idx]
                batch_ret = returns_t[batch_idx]

                logits, values = self.policy(batch_states)
                dist = torch.distributions.Categorical(logits=logits)
                new_log_probs = dist.log_prob(batch_actions)
                entropy = dist.entropy().mean()

                # Clipped surrogate objective.
                ratio = torch.exp(new_log_probs - batch_old_lp)
                surr1 = ratio * batch_adv
                surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * batch_adv
                policy_loss = -torch.min(surr1, surr2).mean()
                # squeeze(-1), not squeeze(): a size-1 minibatch would otherwise
                # become a 0-d tensor and mis-broadcast against batch_ret.
                value_loss = F.mse_loss(values.squeeze(-1), batch_ret)
                loss = policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.optimizer.step()

                total_loss += loss.item()
                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                total_entropy += entropy.item()
                update_count += 1

        if self.scheduler is not None:
            self.scheduler.step()
        self.reset_buffers()

        denom = max(update_count, 1)
        return {
            'loss': total_loss / denom,
            'policy_loss': total_policy_loss / denom,
            'value_loss': total_value_loss / denom,
            'entropy': total_entropy / denom,
            'lr': self.optimizer.param_groups[0]['lr'],
        }

    def save(self, path: str):
        """Save policy, optimizer and (if any) LR-scheduler state to ``path``."""
        checkpoint = {
            'policy_state_dict': self.policy.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }
        # Persist the schedule position so resumed training continues the
        # cosine curve instead of restarting it.
        if self.scheduler is not None:
            checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
        torch.save(checkpoint, path)

    def load(self, path: str):
        """Restore state written by :meth:`save`.

        Checkpoints without a scheduler entry (older format) still load; the
        scheduler is then left as-is.
        """
        # SECURITY: weights_only=False unpickles arbitrary objects and can run
        # code from the file -- only load checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.policy.load_state_dict(checkpoint['policy_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if self.scheduler is not None and 'scheduler_state_dict' in checkpoint:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])