311 lines
11 KiB
Python
311 lines
11 KiB
Python
"""
|
|
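PPO Agent Implementation (optimized version)

Supports MultiDiscrete action spaces, for independent per-zone VSL control.

Optimizations:
1. MultiDiscrete: an independent actor head per zone, avoiding the
   exponentially large joint (flattened) action space.
2. LayerNorm + deeper network: improves training stability.
3. Cosine-annealing learning-rate schedule.
4. Orthogonal initialization.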
|
|
"""
|
|
from typing import Dict, List, Sequence, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
|
|
|
|
|
|
class MultiDiscreteActorCritic(nn.Module):
    """
    MultiDiscrete actor-critic network.

    Actor: shared feature extractor -> one independent linear head per zone,
    each emitting logits over that zone's speed levels.
    Critic: the same shared features -> a small MLP emitting one state value.

    Args:
        state_dim: Size of the flat observation vector.
        action_dims: Number of discrete actions for each zone/head,
            e.g. ``[5, 5, 5, 5, 5]``.
        hidden_layers: Widths of the shared ``Linear -> LayerNorm -> ReLU``
            layers. An empty sequence feeds the raw state to the heads.
    """

    def __init__(
        self,
        state_dim: int,
        action_dims: List[int],  # e.g. [5, 5, 5, 5, 5]: number of actions per zone
        hidden_layers: Sequence[int] = (256, 256),
    ):
        super().__init__()
        self.state_dim = state_dim
        # Copy so later mutation of the caller's list cannot desync this
        # metadata from the heads actually built below.
        self.action_dims = list(action_dims)
        self.num_heads = len(self.action_dims)

        # Shared feature extractor: LayerNorm after every Linear for stability.
        layers = []
        prev_dim = state_dim
        for hidden_dim in hidden_layers:
            layers += [nn.Linear(prev_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU()]
            prev_dim = hidden_dim
        self.feature_extractor = nn.Sequential(*layers)

        # One independent actor head per zone.
        self.actor_heads = nn.ModuleList(
            nn.Linear(prev_dim, adim) for adim in self.action_dims
        )

        # Critic head.
        self.critic = nn.Sequential(
            nn.Linear(prev_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

        self._init_weights()

    def _init_weights(self):
        """Orthogonal initialization with zero biases.

        Hidden Linear layers use gain sqrt(2), actor heads use gain 0.01 and
        the critic output layer uses gain 1.0.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                nn.init.constant_(m.bias, 0)
        # Tiny gain keeps initial logits small, so the initial per-zone
        # distributions start close to uniform.
        for head in self.actor_heads:
            nn.init.orthogonal_(head.weight, gain=0.01)
        # Value output layer uses unit gain (not the small actor gain).
        nn.init.orthogonal_(self.critic[-1].weight, gain=1.0)

    def forward(self, state):
        """Return ``(logits_list, value)``.

        ``logits_list[i]`` has shape ``(..., action_dims[i])``; ``value`` has
        shape ``(..., 1)``.
        """
        features = self.feature_extractor(state)
        logits_list = [head(features) for head in self.actor_heads]
        value = self.critic(features)
        return logits_list, value

    def get_action_probs(self, state):
        """Return ``(probs_list, value)``: per-zone softmax probabilities and the state value."""
        logits_list, value = self.forward(state)
        probs_list = [F.softmax(logits, dim=-1) for logits in logits_list]
        return probs_list, value

    def get_value(self, state):
        """Return only the critic's state value, skipping the actor heads."""
        features = self.feature_extractor(state)
        return self.critic(features)
|
|
|
|
|
|
class PPOAgent:
    """PPO agent for MultiDiscrete action spaces.

    Usage: call :meth:`select_action` and :meth:`store_transition` while
    rolling out, then :meth:`update` to run clipped-surrogate PPO with GAE
    advantages on the collected buffer (which is then cleared). The
    log-probability of a MultiDiscrete action is the sum of the per-head
    log-probabilities, i.e. heads are treated as independent given the state.
    """

    def __init__(
        self,
        state_dim: int,
        action_dims: List[int],  # MultiDiscrete: number of actions per zone
        hidden_layers: Sequence[int] = (256, 256),
        learning_rate: float = 3e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_epsilon: float = 0.2,
        value_coef: float = 0.5,
        entropy_coef: float = 0.02,
        max_grad_norm: float = 0.5,
        ppo_epochs: int = 10,
        minibatch_size: int = 64,
        device: str = "cuda",
        lr_schedule: str = "cosine",  # "cosine" or "none"
        total_episodes: int = 300,
    ):
        """
        Args:
            state_dim: Size of the flat observation vector.
            action_dims: Number of discrete actions for each zone/head.
            hidden_layers: Widths of the network's shared hidden layers.
            learning_rate: Initial Adam learning rate.
            gamma: Discount factor.
            gae_lambda: GAE lambda.
            clip_epsilon: PPO probability-ratio clipping range.
            value_coef: Weight of the value loss in the total loss.
            entropy_coef: Weight of the entropy bonus (subtracted from the loss).
            max_grad_norm: Global gradient-norm clipping threshold.
            ppo_epochs: Passes over the buffer per :meth:`update` call.
            minibatch_size: Minibatch size within each pass.
            device: Requested torch device; CPU is used whenever CUDA is
                unavailable.
            lr_schedule: ``"cosine"`` enables cosine annealing; any other
                value disables scheduling.
            total_episodes: ``T_max`` of the cosine schedule.
        """
        # Falls back to CPU whenever CUDA is unavailable, whatever was requested.
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        self.ppo_epochs = ppo_epochs
        self.minibatch_size = minibatch_size
        self.action_dims = list(action_dims)
        self.num_heads = len(self.action_dims)

        # Create network
        self.policy = MultiDiscreteActorCritic(
            state_dim, self.action_dims, hidden_layers
        ).to(self.device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate, eps=1e-5)

        # Learning-rate schedule. The scheduler is stepped once per update()
        # call, so T_max=total_episodes assumes one update per episode —
        # NOTE(review): confirm with the training loop. Past T_max, PyTorch's
        # CosineAnnealingLR starts raising the lr again.
        self.lr_schedule = lr_schedule
        if lr_schedule == "cosine":
            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=total_episodes, eta_min=learning_rate * 0.1
            )
        else:
            self.scheduler = None

        self.reset_buffers()

    def reset_buffers(self):
        """Clear the rollout buffer."""
        self.states = []
        self.actions = []  # np.ndarray per step, shape (num_heads,)
        self.rewards = []
        self.values = []
        self.log_probs = []  # float per step: sum of all heads' log-probs
        self.dones = []

    def select_action(
        self, state: np.ndarray, deterministic: bool = False
    ) -> Tuple[np.ndarray, float, float]:
        """Pick one action per zone for a single (unbatched) state.

        Args:
            state: A single observation; a batch dimension is added here.
            deterministic: Take each head's argmax instead of sampling.

        Returns:
            action: ``np.int64`` array of shape ``(num_heads,)``, one action
                index per zone.
            log_prob: Sum over zones of the chosen actions' log-probabilities.
            value: Critic value of ``state``.
        """
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)

        actions = []
        total_log_prob = 0.0
        with torch.no_grad():
            logits_list, value = self.policy(state_tensor)
            for logits in logits_list:
                # Built from logits exactly as in _evaluate_actions, so the
                # stored old log-probs and the recomputed new ones use the
                # same formula (no ad-hoc epsilon that would bias the ratio).
                dist = torch.distributions.Categorical(logits=logits)
                action = logits.argmax(dim=-1) if deterministic else dist.sample()
                total_log_prob += dist.log_prob(action).item()
                actions.append(action.item())

        return np.array(actions, dtype=np.int64), total_log_prob, value.item()

    def store_transition(self, state, action, reward, value, log_prob, done):
        """Append one transition to the rollout buffer.

        ``state`` is copied so that an environment reusing/mutating its
        observation array in place cannot corrupt earlier buffer entries.
        """
        self.states.append(np.array(state, copy=True))
        self.actions.append(action)  # np.ndarray, shape (num_heads,)
        self.rewards.append(reward)
        self.values.append(value)
        self.log_probs.append(log_prob)
        self.dones.append(done)

    def compute_gae(self, next_value: float) -> Tuple[np.ndarray, np.ndarray]:
        """Compute GAE advantages and returns for the buffered rollout.

        Args:
            next_value: Value estimate of the state after the last stored
                transition, used to bootstrap the final step.

        Returns:
            ``(advantages, returns)`` as float32 arrays of length
            ``len(self.rewards)``, with ``returns = advantages + values``.
            ``done`` at step t cuts both the bootstrap and the GAE trace.
        """
        n = len(self.rewards)
        advantages = np.zeros(n, dtype=np.float32)
        gae = 0.0
        next_val = next_value
        for t in reversed(range(n)):
            not_done = 1.0 - float(self.dones[t])
            delta = self.rewards[t] + self.gamma * next_val * not_done - self.values[t]
            gae = delta + self.gamma * self.gae_lambda * not_done * gae
            advantages[t] = gae
            next_val = self.values[t]

        returns = advantages + np.asarray(self.values, dtype=np.float32)
        return advantages, returns

    def _evaluate_actions(
        self, states: torch.Tensor, actions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(log_prob, entropy, value)``, each shape ``(B,)``, for
        ``actions`` of shape ``(B, num_heads)`` under the current policy.
        Log-probs and entropies are summed over heads."""
        logits_list, values = self.policy(states)
        log_prob = states.new_zeros(states.shape[0])
        entropy = states.new_zeros(states.shape[0])
        for head, logits in enumerate(logits_list):
            dist = torch.distributions.Categorical(logits=logits)
            log_prob = log_prob + dist.log_prob(actions[:, head])
            entropy = entropy + dist.entropy()
        # squeeze(-1), not squeeze(): a 1-sample minibatch must stay shape (1,)
        # to match the returns tensor in mse_loss.
        return log_prob, entropy, values.squeeze(-1)

    def _minibatch_step(
        self,
        batch_states: torch.Tensor,
        batch_actions: torch.Tensor,
        batch_old_lp: torch.Tensor,
        batch_adv: torch.Tensor,
        batch_ret: torch.Tensor,
    ) -> Dict[str, float]:
        """Apply one clipped-PPO gradient step on a minibatch; return its scalar losses."""
        new_log_probs, entropy_per_sample, values = self._evaluate_actions(
            batch_states, batch_actions
        )
        entropy = entropy_per_sample.mean()

        # PPO clipped surrogate objective.
        ratio = torch.exp(new_log_probs - batch_old_lp)
        surr1 = ratio * batch_adv
        surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * batch_adv
        policy_loss = -torch.min(surr1, surr2).mean()

        value_loss = F.mse_loss(values, batch_ret)

        loss = policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        self.optimizer.step()

        return {
            'loss': loss.item(),
            'policy_loss': policy_loss.item(),
            'value_loss': value_loss.item(),
            'entropy': entropy.item(),
        }

    def update(self, next_value: float) -> Dict[str, float]:
        """Run PPO on the buffered rollout, then clear the buffer.

        Args:
            next_value: Value estimate of the state following the last stored
                transition (GAE bootstrap).

        Returns:
            Mean ``loss``/``policy_loss``/``value_loss``/``entropy`` over all
            minibatch steps plus the current ``lr``; ``{}`` when the buffer is
            empty (nothing is updated and the scheduler is not stepped).
        """
        if not self.states:
            return {}

        advantages, returns = self.compute_gae(next_value)
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        states = torch.FloatTensor(np.array(self.states)).to(self.device)
        actions = torch.LongTensor(np.array(self.actions)).to(self.device)  # (N, num_heads)
        old_log_probs = torch.FloatTensor(self.log_probs).to(self.device)
        advantages_t = torch.FloatTensor(advantages).to(self.device)
        returns_t = torch.FloatTensor(returns).to(self.device)

        totals = {'loss': 0.0, 'policy_loss': 0.0, 'value_loss': 0.0, 'entropy': 0.0}
        update_count = 0
        dataset_size = len(self.states)
        for _ in range(self.ppo_epochs):
            indices = np.random.permutation(dataset_size)
            for start in range(0, dataset_size, self.minibatch_size):
                batch_idx = indices[start:start + self.minibatch_size]
                step_stats = self._minibatch_step(
                    states[batch_idx],
                    actions[batch_idx],
                    old_log_probs[batch_idx],
                    advantages_t[batch_idx],
                    returns_t[batch_idx],
                )
                for key, val in step_stats.items():
                    totals[key] += val
                update_count += 1

        if self.scheduler is not None:
            self.scheduler.step()

        self.reset_buffers()

        stats = {key: total / max(update_count, 1) for key, total in totals.items()}
        stats['lr'] = self.optimizer.param_groups[0]['lr']
        return stats

    def save(self, path: str):
        """Save policy, optimizer and (if enabled) LR-scheduler state to ``path``."""
        checkpoint = {
            'policy_state_dict': self.policy.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }
        # Without this, resuming restarts the cosine schedule from step 0.
        if self.scheduler is not None:
            checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
        torch.save(checkpoint, path)

    def load(self, path: str):
        """Load a checkpoint written by :meth:`save`.

        Checkpoints without scheduler state (older ones) still load; the
        current scheduler is then left as-is.
        """
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # objects — only load checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.policy.load_state_dict(checkpoint['policy_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if self.scheduler is not None and 'scheduler_state_dict' in checkpoint:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|