102 lines
3.3 KiB
Plaintext
102 lines
3.3 KiB
Plaintext
class APPOAgent:
|
|
"""Attention-based PPO Agent"""
|
|
|
|
def __init__(
    self,
    state_dim: int,
    num_actions: int,
    hidden_dim: int = 256,
    num_heads: int = 4,
    num_attention_layers: int = 2,
    learning_rate: float = 3e-4,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
    clip_epsilon: float = 0.2,
    value_coef: float = 0.5,
    entropy_coef: float = 0.02,
    max_grad_norm: float = 0.5,
    ppo_epochs: int = 10,
    minibatch_size: int = 64,
    device: str = "cuda",
    lr_schedule: str = "cosine",
    total_episodes: int = 300,
):
    """Create the policy network, optimizer, optional LR scheduler and buffers.

    Args:
        state_dim: Forwarded to ``AttentionActorCritic`` (1st positional arg).
        num_actions: Forwarded to ``AttentionActorCritic`` (2nd positional arg).
        hidden_dim: Forwarded to ``AttentionActorCritic`` (3rd positional arg).
        num_heads: Forwarded to ``AttentionActorCritic`` (4th positional arg).
        num_attention_layers: Forwarded to ``AttentionActorCritic``
            (5th positional arg).
        learning_rate: Initial Adam learning rate. With the cosine schedule
            the floor (``eta_min``) is 10% of this value.
        gamma: Discount factor, read by ``compute_gae``.
        gae_lambda: GAE smoothing factor, read by ``compute_gae``.
        clip_epsilon: Stored on the instance; presumably the PPO clipping
            range used by the update step (not visible here).
        value_coef: Stored on the instance; presumably the value-loss weight.
        entropy_coef: Stored on the instance; presumably the entropy-bonus
            weight.
        max_grad_norm: Stored on the instance; presumably the gradient
            clipping threshold.
        ppo_epochs: Stored on the instance; presumably the number of
            optimisation passes per update.
        minibatch_size: Stored on the instance; presumably the minibatch
            size of the update step.
        device: Requested torch device string. A CUDA request falls back to
            ``"cpu"`` when CUDA is unavailable; any other device is used as
            given.
        lr_schedule: ``"cosine"`` attaches a ``CosineAnnealingLR`` scheduler;
            any other value leaves ``self.scheduler`` as ``None``.
        total_episodes: ``T_max`` of the cosine schedule (so the scheduler
            is presumably stepped once per episode -- confirm with caller).
    """
    # Fall back to CPU only when a CUDA device was requested but CUDA is not
    # available. Previously *every* device string was replaced by "cpu" on
    # machines without CUDA, silently ignoring e.g. an explicit "mps".
    wants_cuda = str(device).startswith("cuda")
    if wants_cuda and not torch.cuda.is_available():
        device = "cpu"
    self.device = torch.device(device)

    # Hyper-parameters kept for the rollout / update code.
    self.gamma = gamma
    self.gae_lambda = gae_lambda
    self.clip_epsilon = clip_epsilon
    self.value_coef = value_coef
    self.entropy_coef = entropy_coef
    self.max_grad_norm = max_grad_norm
    self.ppo_epochs = ppo_epochs
    self.minibatch_size = minibatch_size

    self.policy = AttentionActorCritic(
        state_dim, num_actions, hidden_dim, num_heads, num_attention_layers
    ).to(self.device)

    self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate, eps=1e-5)

    self.lr_schedule = lr_schedule
    if lr_schedule == "cosine":
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=total_episodes, eta_min=learning_rate * 0.1
        )
    else:
        self.scheduler = None

    # Start with empty rollout storage.
    self.reset_buffers()
|
|
|
|
def reset_buffers(self):
    """Discard all stored rollout data by rebinding each buffer to a new empty list."""
    # One fresh list per attribute; the buffers never share storage.
    self.states, self.actions, self.rewards = [], [], []
    self.values, self.log_probs, self.dones = [], [], []
|
|
|
|
def select_action(
    self, state: np.ndarray, deterministic: bool = False
) -> Tuple[int, float, float]:
    """Choose an action for one state, with gradient tracking disabled.

    Args:
        state: A single observation; a leading batch axis is added before it
            is passed to ``self.policy``.
        deterministic: When True, take the highest-probability action;
            otherwise sample from the categorical distribution over actions.

    Returns:
        ``(action, log_prob, value)`` as plain Python scalars: the chosen
        action index, the log of its probability (with a 1e-10 offset), and
        the value output of the policy for this state.
    """
    batch = torch.FloatTensor(state).unsqueeze(0).to(self.device)

    with torch.no_grad():
        logits, value = self.policy(batch)
        action_probs = F.softmax(logits, dim=-1)

        if not deterministic:
            choice = torch.distributions.Categorical(action_probs).sample()
        else:
            choice = torch.argmax(action_probs, dim=-1)
        action = choice.item()

        # The small offset keeps the log finite if the probability is 0.
        log_prob = torch.log(action_probs[0, action] + 1e-10).item()

    return action, log_prob, value.item()
|
|
|
|
def store_transition(self, state, action, reward, value, log_prob, done):
    """Append one step of experience to the rollout buffers.

    Each argument is appended to the buffer of the matching name, so index
    ``i`` across all six lists describes the same step.
    """
    step = (
        ("states", state),
        ("actions", action),
        ("rewards", reward),
        ("values", value),
        ("log_probs", log_prob),
        ("dones", done),
    )
    # Buffers are looked up one at a time, in the same order as the fields.
    for buffer_name, item in step:
        getattr(self, buffer_name).append(item)
|
|
|
|
def compute_gae(self, next_value: float) -> Tuple[np.ndarray, np.ndarray]:
|
|
advantages = []
|
|
gae = 0
|
|
|
|
for t in reversed(range(len(self.rewards))):
|
|
if t == len(self.rewards) - 1:
|
|
next_val = next_value
|
|
else:
|
|
next_val = self.values[t + 1]
|
|
|
|
delta = self.rewards[t] + self.gamma * next_val * (1 - self.dones[t]) - self.values[t]
|
|
gae = delta + self.gamma * self.gae_lambda * (1 - self.dones[t]) * gae
|
|
advantages.insert(0, gae)
|
|
|
|
advantages = np.array(advantages, dtype=np.float32)
|
|
returns = advantages + np.array(self.values, dtype=np.float32)
|
|
|