# ctm-dqn/agents/gpro_agent.py
# (source listing metadata: 353 lines, 14 KiB, Python)

"""PPO-style GPRO agent with group-relative advantages for corridor control."""
from __future__ import annotations

from typing import Dict, List, Sequence, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class MultiDiscreteActorCritic(nn.Module):
    """Shared actor-critic backbone with one categorical head per control segment.

    The shared feature extractor is a stack of ``Linear -> LayerNorm -> ReLU``
    blocks, one per entry of ``hidden_layers``. Every entry of ``action_dims``
    gets its own linear head that produces unnormalized logits. A small MLP
    critic maps the shared features to a single value.

    Args:
        state_dim: Size of the last dimension of the input state tensor.
        action_dims: Number of discrete choices for each action head.
        hidden_layers: Widths of the shared hidden layers. An empty sequence
            makes the feature extractor an identity mapping.
    """

    def __init__(
        self,
        state_dim: int,
        action_dims: Sequence[int],
        hidden_layers: Sequence[int] = (256, 256),
    ):
        super().__init__()
        # Copy the dims so that later mutation of the caller's list cannot
        # desynchronise ``self.action_dims`` from the heads built below.
        self.action_dims = list(action_dims)

        layers: List[nn.Module] = []
        prev_dim = state_dim
        for hidden_dim in hidden_layers:
            layers.extend(
                [nn.Linear(prev_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU()]
            )
            prev_dim = hidden_dim
        self.feature_extractor = nn.Sequential(*layers)

        self.actor_heads = nn.ModuleList(
            [nn.Linear(prev_dim, action_dim) for action_dim in self.action_dims]
        )
        self.critic = nn.Sequential(
            nn.Linear(prev_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )
        self._init_weights()

    def _init_weights(self):
        """Initialise every Linear orthogonally, then rescale the output layers.

        All Linear layers get orthogonal weights (gain sqrt(2)) and zero bias.
        Actor heads are then re-initialised with gain 0.01. With zero bias this
        makes the initial logits close to zero, so the initial per-head softmax
        is close to uniform. The critic's output layer uses gain 1.0.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
                nn.init.constant_(module.bias, 0)
        for head in self.actor_heads:
            nn.init.orthogonal_(head.weight, gain=0.01)
        nn.init.orthogonal_(self.critic[-1].weight, gain=1.0)

    def forward(self, state: torch.Tensor) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """Return ``(logits_per_head, value)``.

        The value tensor's last dimension has size 1.
        """
        features = self.feature_extractor(state)
        logits_list = [head(features) for head in self.actor_heads]
        value = self.critic(features)
        return logits_list, value

    def get_action_probs(
        self, state: torch.Tensor
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """Like :meth:`forward`, but softmax-normalise each head's logits."""
        logits_list, value = self.forward(state)
        probs_list = [F.softmax(logits, dim=-1) for logits in logits_list]
        return probs_list, value

    def get_value(self, state: torch.Tensor) -> torch.Tensor:
        """Return only the critic value, skipping the actor heads."""
        features = self.feature_extractor(state)
        return self.critic(features)
class GPROAgent:
    """PPO actor-critic with group-relative trajectory ranking.

    Intended call pattern, as implied by the public methods:

    1. Each step, call :meth:`select_action` and then :meth:`store_transition`.
    2. At the end of each episode, call :meth:`finish_episode`.
    3. Once a group of episodes has been collected, call :meth:`update`.

    :meth:`update` runs clipped-PPO epochs over every stored step. The
    advantage blends per-step GAE with a per-trajectory term obtained by
    normalising episode ``score`` values across all stored trajectories.
    """

    def __init__(
        self,
        state_dim: int,
        action_dims: List[int],
        hidden_layers: Sequence[int] = (256, 256),
        learning_rate: float = 3e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_epsilon: float = 0.2,
        value_coef: float = 0.5,
        entropy_coef: float = 0.01,
        max_grad_norm: float = 0.5,
        ppo_epochs: int = 4,
        minibatch_size: int = 64,
        group_size: int = 4,
        group_advantage_coef: float = 0.35,
        advantage_epsilon: float = 1e-8,
        device: str = "cuda",
        lr_schedule: str = "cosine",
        total_episodes: int = 300,
    ):
        """Build the policy network, the optimizer and the optional LR scheduler.

        Args:
            state_dim: Dimension of the state vector fed to the first layer.
            action_dims: Number of discrete choices per action head.
            hidden_layers: Widths of the shared hidden layers.
            learning_rate: Initial Adam learning rate.
            gamma: Discount factor used by GAE.
            gae_lambda: GAE smoothing parameter.
            clip_epsilon: PPO probability-ratio clipping range.
            value_coef: Weight of the critic MSE loss.
            entropy_coef: Weight of the entropy bonus.
            max_grad_norm: Global gradient-norm clipping threshold.
            ppo_epochs: Passes over the collected data per :meth:`update`.
            minibatch_size: Number of steps per gradient step.
            group_size: Stored as ``self.group_size``. Values below 2 are
                raised to 2.
            group_advantage_coef: Blend weight of the trajectory-score term.
                Clipped to ``[0, 1]``.
            advantage_epsilon: Std threshold and stabiliser used when
                normalising.
            device: Requested torch device. Falls back to CPU when CUDA is
                unavailable.
            lr_schedule: ``"cosine"`` enables cosine annealing. Any other
                value disables LR scheduling.
            total_episodes: ``T_max`` of the cosine schedule.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        self.ppo_epochs = ppo_epochs
        self.minibatch_size = minibatch_size
        # NOTE(review): group_size is not read anywhere else in this class.
        # Presumably the training loop uses it to decide when to call update();
        # confirm with the caller.
        self.group_size = max(int(group_size), 2)
        self.group_advantage_coef = float(np.clip(group_advantage_coef, 0.0, 1.0))
        self.advantage_epsilon = advantage_epsilon
        self.action_dims = list(action_dims)
        self.num_heads = len(self.action_dims)
        self.policy = MultiDiscreteActorCritic(
            state_dim, self.action_dims, list(hidden_layers)
        ).to(self.device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate, eps=1e-5)
        self.lr_schedule = lr_schedule
        # NOTE(review): the scheduler is stepped once per update() call, so
        # T_max counts updates, not episodes. If update() runs once per group,
        # the schedule lasts `total_episodes` groups. Confirm this is intended.
        self.scheduler = (
            optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=total_episodes, eta_min=learning_rate * 0.1
            )
            if lr_schedule == "cosine"
            else None
        )
        # reset_group_buffers() also initialises the per-episode buffers.
        self.reset_group_buffers()

    def _reset_episode_buffer(self):
        """Clear the per-step buffers of the episode currently being recorded."""
        self.current_states: List[np.ndarray] = []
        self.current_actions: List[np.ndarray] = []
        self.current_rewards: List[float] = []
        self.current_values: List[float] = []
        self.current_log_probs: List[float] = []
        self.current_dones: List[float] = []

    def reset_group_buffers(self):
        """Drop all finished trajectories and any partially recorded episode."""
        self.trajectories: List[Dict[str, np.ndarray | float]] = []
        self._reset_episode_buffer()

    def select_action(
        self, state: np.ndarray, deterministic: bool = False
    ) -> Tuple[np.ndarray, float, float]:
        """Choose one action per head for a single (unbatched) state.

        Args:
            state: A single state. A batch dimension of size 1 is added here.
            deterministic: Take the arg-max action of each head instead of
                sampling.

        Returns:
            A tuple ``(actions, log_prob, value)``:

            - ``actions``: int64 array with one entry per head.
            - ``log_prob``: sum over heads of the chosen actions' log-probabilities.
            - ``value``: the critic's estimate for this state.
        """
        state_tensor = torch.as_tensor(
            np.asarray(state), dtype=torch.float32, device=self.device
        ).unsqueeze(0)
        with torch.no_grad():
            logits_list, value = self.policy(state_tensor)
            actions: List[int] = []
            total_log_prob = 0.0
            for logits in logits_list:
                if deterministic:
                    action = int(torch.argmax(logits, dim=-1).item())
                else:
                    action = int(torch.distributions.Categorical(logits=logits).sample().item())
                # Use log_softmax on the logits so the stored value matches the
                # Categorical(logits=...).log_prob recomputed in update().
                # The old log(p + 1e-10) did not match the clamped log-prob
                # that update() computed, which skewed the PPO ratio.
                log_probs = F.log_softmax(logits, dim=-1)
                actions.append(action)
                total_log_prob += float(log_probs[0, action].item())
        return np.array(actions, dtype=np.int64), total_log_prob, float(value.item())

    def store_transition(self, state, action, reward, value, log_prob, done):
        """Append one step to the episode currently being recorded.

        Pass ``value`` and ``log_prob`` exactly as :meth:`select_action`
        returned them, because :meth:`update` uses them as the old-policy
        reference.
        """
        self.current_states.append(np.asarray(state, dtype=np.float32))
        self.current_actions.append(np.asarray(action, dtype=np.int64))
        self.current_rewards.append(float(reward))
        self.current_values.append(float(value))
        self.current_log_probs.append(float(log_prob))
        self.current_dones.append(float(done))

    def finish_episode(self, score: float, last_value: float = 0.0):
        """Close the current episode and move it into the group buffer.

        Empty episodes are ignored.

        Args:
            score: Episode-level score. Scores are normalised across the group
                to form the trajectory-level advantage term.
            last_value: Critic value used to bootstrap the final step when that
                step is not marked ``done`` (for example, a time-limit
                truncation). The default of 0.0 reproduces the previous
                behaviour. It has no effect when the last step's ``done`` is 1.
        """
        if not self.current_states:
            return
        self.trajectories.append(
            {
                "states": np.asarray(self.current_states, dtype=np.float32),
                "actions": np.asarray(self.current_actions, dtype=np.int64),
                "rewards": np.asarray(self.current_rewards, dtype=np.float32),
                "values": np.asarray(self.current_values, dtype=np.float32),
                "log_probs": np.asarray(self.current_log_probs, dtype=np.float32),
                "dones": np.asarray(self.current_dones, dtype=np.float32),
                "score": float(score),
                "last_value": float(last_value),
            }
        )
        self._reset_episode_buffer()

    def _compute_gae(
        self,
        rewards: np.ndarray,
        values: np.ndarray,
        dones: np.ndarray,
        next_value: float = 0.0,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute GAE advantages and returns for one trajectory.

        ``next_value`` bootstraps the step after the last one. A ``done`` of 1
        at step ``t`` cuts both the bootstrap and the GAE recursion.

        Returns:
            ``(advantages, returns)``, where ``returns = advantages + values``.
        """
        advantages = np.zeros_like(rewards, dtype=np.float32)
        gae = 0.0
        next_val = float(next_value)
        for t in reversed(range(len(rewards))):
            not_done = 1.0 - dones[t]
            delta = rewards[t] + self.gamma * next_val * not_done - values[t]
            gae = delta + self.gamma * self.gae_lambda * not_done * gae
            advantages[t] = gae
            # On the next (earlier) step, the "next value" is this step's value.
            next_val = values[t]
        returns = advantages + values
        return advantages, returns

    def _normalize(self, values: np.ndarray) -> np.ndarray:
        """Standardise to zero mean and unit std.

        Empty input is returned unchanged. Returns all zeros when the std is
        below ``advantage_epsilon``.
        """
        if values.size == 0:
            return values
        mean = float(values.mean())
        std = float(values.std())
        if std < self.advantage_epsilon:
            return np.zeros_like(values, dtype=np.float32)
        return ((values - mean) / (std + self.advantage_epsilon)).astype(np.float32)

    def _build_training_targets(self) -> Tuple[np.ndarray, np.ndarray]:
        """Build the blended advantages and the critic returns for all stored steps.

        Each trajectory's GAE is bootstrapped with its stored ``last_value``,
        or 0.0 if the key is absent. The GAE advantages are normalised over all
        steps. Scores are normalised across all stored trajectories, which form
        a single group, and each score is broadcast to its trajectory's steps.
        The two terms are blended by ``group_advantage_coef`` and the result is
        normalised again.

        Returns:
            ``(advantages, returns)``: float32 arrays with one entry per step,
            ordered as the trajectories were finished.
        """
        if not self.trajectories:
            return np.array([], dtype=np.float32), np.array([], dtype=np.float32)
        gae_advantages = []
        returns = []
        for trajectory in self.trajectories:
            trajectory_adv, trajectory_ret = self._compute_gae(
                rewards=trajectory["rewards"],
                values=trajectory["values"],
                dones=trajectory["dones"],
                next_value=float(trajectory.get("last_value", 0.0)),
            )
            gae_advantages.append(trajectory_adv)
            returns.append(trajectory_ret)
        normalized_gae = self._normalize(np.concatenate(gae_advantages, axis=0))

        group_scores = np.asarray(
            [trajectory["score"] for trajectory in self.trajectories], dtype=np.float32
        )
        normalized_group_scores = self._normalize(group_scores)
        repeated_group_advantages = np.concatenate(
            [
                np.full(len(trajectory["states"]), normalized_group_scores[idx], dtype=np.float32)
                for idx, trajectory in enumerate(self.trajectories)
            ],
            axis=0,
        )
        combined_advantages = (
            (1.0 - self.group_advantage_coef) * normalized_gae
            + self.group_advantage_coef * repeated_group_advantages
        )
        combined_advantages = self._normalize(combined_advantages)
        return combined_advantages, np.concatenate(returns, axis=0).astype(np.float32)

    def _stack_field(self, key: str, dtype: torch.dtype) -> torch.Tensor:
        """Concatenate one per-step field across all trajectories into a tensor on ``self.device``."""
        stacked = np.concatenate([trajectory[key] for trajectory in self.trajectories], axis=0)
        return torch.as_tensor(stacked, dtype=dtype, device=self.device)

    def update(self) -> Dict[str, float]:
        """Run clipped-PPO epochs over every stored step, then clear all buffers.

        The buffers are cleared with :meth:`reset_group_buffers`, which also
        discards any episode not yet passed to :meth:`finish_episode`.

        Returns:
            Averaged ``loss``, ``policy_loss``, ``value_loss`` and ``entropy``,
            plus the current ``lr`` and group-score summaries. Returns an empty
            dict when no trajectory has been finished.
        """
        if not self.trajectories:
            return {}
        advantages, returns = self._build_training_targets()
        states = self._stack_field("states", torch.float32)
        actions = self._stack_field("actions", torch.int64)
        old_log_probs = self._stack_field("log_probs", torch.float32)
        advantages_t = torch.as_tensor(advantages, dtype=torch.float32, device=self.device)
        returns_t = torch.as_tensor(returns, dtype=torch.float32, device=self.device)

        total_loss = 0.0
        total_policy_loss = 0.0
        total_value_loss = 0.0
        total_entropy_value = 0.0
        update_count = 0
        dataset_size = states.shape[0]
        for _ in range(self.ppo_epochs):
            indices = np.random.permutation(dataset_size)
            for start_idx in range(0, dataset_size, self.minibatch_size):
                batch_idx = torch.as_tensor(
                    indices[start_idx : start_idx + self.minibatch_size],
                    dtype=torch.long,
                    device=self.device,
                )
                batch_states = states[batch_idx]
                batch_actions = actions[batch_idx]
                batch_old_log_probs = old_log_probs[batch_idx]
                batch_advantages = advantages_t[batch_idx]
                batch_returns = returns_t[batch_idx]

                logits_list, values = self.policy(batch_states)
                # Build distributions from logits rather than softmax probs.
                # This avoids the probability clamping in Categorical(probs=...),
                # which zeroes gradients for near-deterministic heads. It also
                # matches the log-probs stored by select_action().
                dists = [
                    torch.distributions.Categorical(logits=logits) for logits in logits_list
                ]
                total_new_log_probs = torch.stack(
                    [dist.log_prob(batch_actions[:, head_idx]) for head_idx, dist in enumerate(dists)],
                    dim=0,
                ).sum(dim=0)
                # Entropy is summed over heads, then averaged over the minibatch.
                entropy = torch.stack([dist.entropy() for dist in dists], dim=0).sum(dim=0).mean()

                ratio = torch.exp(total_new_log_probs - batch_old_log_probs)
                surr1 = ratio * batch_advantages
                surr2 = (
                    torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon)
                    * batch_advantages
                )
                policy_loss = -torch.min(surr1, surr2).mean()
                value_loss = F.mse_loss(values.squeeze(-1), batch_returns)
                loss = (
                    policy_loss
                    + self.value_coef * value_loss
                    - self.entropy_coef * entropy
                )

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.optimizer.step()

                total_loss += float(loss.item())
                total_policy_loss += float(policy_loss.item())
                total_value_loss += float(value_loss.item())
                total_entropy_value += float(entropy.item())
                update_count += 1

        if self.scheduler is not None:
            self.scheduler.step()

        scores = np.asarray(
            [trajectory["score"] for trajectory in self.trajectories], dtype=np.float32
        )
        denom = max(update_count, 1)
        stats = {
            "loss": total_loss / denom,
            "policy_loss": total_policy_loss / denom,
            "value_loss": total_value_loss / denom,
            "entropy": total_entropy_value / denom,
            "lr": self.optimizer.param_groups[0]["lr"],
            "group_score_mean": float(scores.mean()) if scores.size else 0.0,
            "group_score_std": float(scores.std()) if scores.size else 0.0,
            "group_size": float(len(self.trajectories)),
            "group_advantage_coef": float(self.group_advantage_coef),
        }
        self.reset_group_buffers()
        return stats

    def save(self, path: str):
        """Write policy, optimizer and (if enabled) LR-scheduler state to ``path``."""
        checkpoint = {
            "policy_state_dict": self.policy.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
        }
        # Persist the scheduler so a resumed run continues the LR schedule
        # instead of restarting it.
        if self.scheduler is not None:
            checkpoint["scheduler_state_dict"] = self.scheduler.state_dict()
        torch.save(checkpoint, path)

    def load(self, path: str):
        """Restore state written by :meth:`save`.

        Checkpoints without scheduler state (older files) still load. In that
        case the scheduler keeps its current state.
        """
        # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Only load checkpoints from trusted
        # sources.
        # NOTE(review): the plain state dicts written by save() may load with
        # weights_only=True. Consider switching after testing.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.policy.load_state_dict(checkpoint["policy_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler_state = checkpoint.get("scheduler_state_dict")
        if self.scheduler is not None and scheduler_state is not None:
            self.scheduler.load_state_dict(scheduler_state)