调整gpro-ppo模型架构
This commit is contained in:
parent
cea9d42397
commit
e45b083067
|
|
@ -1,4 +1,4 @@
|
||||||
"""GRPO-inspired PPO agent for grouped corridor control rollouts."""
|
"""PPO-style GPRO agent with group-relative advantages for corridor control."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
@ -10,8 +10,8 @@ import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
|
|
||||||
|
|
||||||
class MultiDiscreteActor(nn.Module):
|
class MultiDiscreteActorCritic(nn.Module):
|
||||||
"""Shared trunk plus one categorical head per controlled edge."""
|
"""Shared actor-critic backbone with one categorical head per control segment."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -20,9 +20,7 @@ class MultiDiscreteActor(nn.Module):
|
||||||
hidden_layers: List[int] = [256, 256],
|
hidden_layers: List[int] = [256, 256],
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.state_dim = state_dim
|
|
||||||
self.action_dims = action_dims
|
self.action_dims = action_dims
|
||||||
self.num_heads = len(action_dims)
|
|
||||||
|
|
||||||
layers = []
|
layers = []
|
||||||
prev_dim = state_dim
|
prev_dim = state_dim
|
||||||
|
|
@ -32,9 +30,15 @@ class MultiDiscreteActor(nn.Module):
|
||||||
layers.append(nn.ReLU())
|
layers.append(nn.ReLU())
|
||||||
prev_dim = hidden_dim
|
prev_dim = hidden_dim
|
||||||
self.feature_extractor = nn.Sequential(*layers)
|
self.feature_extractor = nn.Sequential(*layers)
|
||||||
|
|
||||||
self.actor_heads = nn.ModuleList(
|
self.actor_heads = nn.ModuleList(
|
||||||
[nn.Linear(prev_dim, action_dim) for action_dim in action_dims]
|
[nn.Linear(prev_dim, action_dim) for action_dim in action_dims]
|
||||||
)
|
)
|
||||||
|
self.critic = nn.Sequential(
|
||||||
|
nn.Linear(prev_dim, 128),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(128, 1),
|
||||||
|
)
|
||||||
self._init_weights()
|
self._init_weights()
|
||||||
|
|
||||||
def _init_weights(self):
|
def _init_weights(self):
|
||||||
|
|
@ -44,17 +48,26 @@ class MultiDiscreteActor(nn.Module):
|
||||||
nn.init.constant_(module.bias, 0)
|
nn.init.constant_(module.bias, 0)
|
||||||
for head in self.actor_heads:
|
for head in self.actor_heads:
|
||||||
nn.init.orthogonal_(head.weight, gain=0.01)
|
nn.init.orthogonal_(head.weight, gain=0.01)
|
||||||
|
nn.init.orthogonal_(self.critic[-1].weight, gain=1.0)
|
||||||
|
|
||||||
def forward(self, state: torch.Tensor) -> List[torch.Tensor]:
|
def forward(self, state: torch.Tensor) -> Tuple[List[torch.Tensor], torch.Tensor]:
|
||||||
features = self.feature_extractor(state)
|
features = self.feature_extractor(state)
|
||||||
return [head(features) for head in self.actor_heads]
|
logits_list = [head(features) for head in self.actor_heads]
|
||||||
|
value = self.critic(features)
|
||||||
|
return logits_list, value
|
||||||
|
|
||||||
def get_action_probs(self, state: torch.Tensor) -> List[torch.Tensor]:
|
def get_action_probs(self, state: torch.Tensor) -> Tuple[List[torch.Tensor], torch.Tensor]:
|
||||||
return [F.softmax(logits, dim=-1) for logits in self.forward(state)]
|
logits_list, value = self.forward(state)
|
||||||
|
probs_list = [F.softmax(logits, dim=-1) for logits in logits_list]
|
||||||
|
return probs_list, value
|
||||||
|
|
||||||
|
def get_value(self, state: torch.Tensor) -> torch.Tensor:
|
||||||
|
features = self.feature_extractor(state)
|
||||||
|
return self.critic(features)
|
||||||
|
|
||||||
|
|
||||||
class GPROAgent:
|
class GPROAgent:
|
||||||
"""Grouped relative PPO without value critic."""
|
"""PPO actor-critic with group-relative trajectory ranking."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -62,29 +75,39 @@ class GPROAgent:
|
||||||
action_dims: List[int],
|
action_dims: List[int],
|
||||||
hidden_layers: List[int] = [256, 256],
|
hidden_layers: List[int] = [256, 256],
|
||||||
learning_rate: float = 3e-4,
|
learning_rate: float = 3e-4,
|
||||||
|
gamma: float = 0.99,
|
||||||
|
gae_lambda: float = 0.95,
|
||||||
clip_epsilon: float = 0.2,
|
clip_epsilon: float = 0.2,
|
||||||
|
value_coef: float = 0.5,
|
||||||
entropy_coef: float = 0.01,
|
entropy_coef: float = 0.01,
|
||||||
max_grad_norm: float = 0.5,
|
max_grad_norm: float = 0.5,
|
||||||
ppo_epochs: int = 4,
|
ppo_epochs: int = 4,
|
||||||
minibatch_size: int = 64,
|
minibatch_size: int = 64,
|
||||||
group_size: int = 4,
|
group_size: int = 4,
|
||||||
|
group_advantage_coef: float = 0.35,
|
||||||
advantage_epsilon: float = 1e-8,
|
advantage_epsilon: float = 1e-8,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
lr_schedule: str = "cosine",
|
lr_schedule: str = "cosine",
|
||||||
total_episodes: int = 300,
|
total_episodes: int = 300,
|
||||||
):
|
):
|
||||||
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
|
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
|
||||||
|
self.gamma = gamma
|
||||||
|
self.gae_lambda = gae_lambda
|
||||||
self.clip_epsilon = clip_epsilon
|
self.clip_epsilon = clip_epsilon
|
||||||
|
self.value_coef = value_coef
|
||||||
self.entropy_coef = entropy_coef
|
self.entropy_coef = entropy_coef
|
||||||
self.max_grad_norm = max_grad_norm
|
self.max_grad_norm = max_grad_norm
|
||||||
self.ppo_epochs = ppo_epochs
|
self.ppo_epochs = ppo_epochs
|
||||||
self.minibatch_size = minibatch_size
|
self.minibatch_size = minibatch_size
|
||||||
self.group_size = max(int(group_size), 2)
|
self.group_size = max(int(group_size), 2)
|
||||||
|
self.group_advantage_coef = float(np.clip(group_advantage_coef, 0.0, 1.0))
|
||||||
self.advantage_epsilon = advantage_epsilon
|
self.advantage_epsilon = advantage_epsilon
|
||||||
self.action_dims = action_dims
|
self.action_dims = action_dims
|
||||||
self.num_heads = len(action_dims)
|
self.num_heads = len(action_dims)
|
||||||
|
|
||||||
self.policy = MultiDiscreteActor(state_dim, action_dims, hidden_layers).to(self.device)
|
self.policy = MultiDiscreteActorCritic(state_dim, action_dims, hidden_layers).to(
|
||||||
|
self.device
|
||||||
|
)
|
||||||
self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate, eps=1e-5)
|
self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate, eps=1e-5)
|
||||||
|
|
||||||
self.lr_schedule = lr_schedule
|
self.lr_schedule = lr_schedule
|
||||||
|
|
@ -101,10 +124,13 @@ class GPROAgent:
|
||||||
def _reset_episode_buffer(self):
|
def _reset_episode_buffer(self):
|
||||||
self.current_states: List[np.ndarray] = []
|
self.current_states: List[np.ndarray] = []
|
||||||
self.current_actions: List[np.ndarray] = []
|
self.current_actions: List[np.ndarray] = []
|
||||||
|
self.current_rewards: List[float] = []
|
||||||
|
self.current_values: List[float] = []
|
||||||
self.current_log_probs: List[float] = []
|
self.current_log_probs: List[float] = []
|
||||||
|
self.current_dones: List[float] = []
|
||||||
|
|
||||||
def reset_group_buffers(self):
|
def reset_group_buffers(self):
|
||||||
self.trajectories: List[Dict] = []
|
self.trajectories: List[Dict[str, np.ndarray | float]] = []
|
||||||
self._reset_episode_buffer()
|
self._reset_episode_buffer()
|
||||||
|
|
||||||
def select_action(
|
def select_action(
|
||||||
|
|
@ -112,7 +138,7 @@ class GPROAgent:
|
||||||
) -> Tuple[np.ndarray, float, float]:
|
) -> Tuple[np.ndarray, float, float]:
|
||||||
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
|
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
probs_list = self.policy.get_action_probs(state_tensor)
|
probs_list, value = self.policy.get_action_probs(state_tensor)
|
||||||
|
|
||||||
actions = []
|
actions = []
|
||||||
total_log_prob = 0.0
|
total_log_prob = 0.0
|
||||||
|
|
@ -126,13 +152,15 @@ class GPROAgent:
|
||||||
actions.append(action)
|
actions.append(action)
|
||||||
total_log_prob += log_prob
|
total_log_prob += log_prob
|
||||||
|
|
||||||
return np.array(actions, dtype=np.int64), total_log_prob, 0.0
|
return np.array(actions, dtype=np.int64), total_log_prob, value.item()
|
||||||
|
|
||||||
def store_transition(self, state, action, reward, value, log_prob, done):
|
def store_transition(self, state, action, reward, value, log_prob, done):
|
||||||
del reward, value, done
|
|
||||||
self.current_states.append(np.asarray(state, dtype=np.float32))
|
self.current_states.append(np.asarray(state, dtype=np.float32))
|
||||||
self.current_actions.append(np.asarray(action, dtype=np.int64))
|
self.current_actions.append(np.asarray(action, dtype=np.int64))
|
||||||
|
self.current_rewards.append(float(reward))
|
||||||
|
self.current_values.append(float(value))
|
||||||
self.current_log_probs.append(float(log_prob))
|
self.current_log_probs.append(float(log_prob))
|
||||||
|
self.current_dones.append(float(done))
|
||||||
|
|
||||||
def finish_episode(self, score: float):
|
def finish_episode(self, score: float):
|
||||||
if not self.current_states:
|
if not self.current_states:
|
||||||
|
|
@ -141,47 +169,102 @@ class GPROAgent:
|
||||||
{
|
{
|
||||||
"states": np.asarray(self.current_states, dtype=np.float32),
|
"states": np.asarray(self.current_states, dtype=np.float32),
|
||||||
"actions": np.asarray(self.current_actions, dtype=np.int64),
|
"actions": np.asarray(self.current_actions, dtype=np.int64),
|
||||||
|
"rewards": np.asarray(self.current_rewards, dtype=np.float32),
|
||||||
|
"values": np.asarray(self.current_values, dtype=np.float32),
|
||||||
"log_probs": np.asarray(self.current_log_probs, dtype=np.float32),
|
"log_probs": np.asarray(self.current_log_probs, dtype=np.float32),
|
||||||
|
"dones": np.asarray(self.current_dones, dtype=np.float32),
|
||||||
"score": float(score),
|
"score": float(score),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
self._reset_episode_buffer()
|
self._reset_episode_buffer()
|
||||||
|
|
||||||
def _build_group_advantages(self) -> List[np.ndarray]:
|
def _compute_gae(
|
||||||
scores = np.asarray([traj["score"] for traj in self.trajectories], dtype=np.float32)
|
self,
|
||||||
if scores.size == 0:
|
rewards: np.ndarray,
|
||||||
return []
|
values: np.ndarray,
|
||||||
score_mean = float(scores.mean())
|
dones: np.ndarray,
|
||||||
score_std = float(scores.std())
|
next_value: float = 0.0,
|
||||||
if score_std < self.advantage_epsilon:
|
) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
normalized_scores = np.zeros_like(scores, dtype=np.float32)
|
advantages = np.zeros_like(rewards, dtype=np.float32)
|
||||||
|
gae = 0.0
|
||||||
|
|
||||||
|
for t in reversed(range(len(rewards))):
|
||||||
|
if t == len(rewards) - 1:
|
||||||
|
next_val = next_value
|
||||||
else:
|
else:
|
||||||
normalized_scores = (scores - score_mean) / (score_std + self.advantage_epsilon)
|
next_val = values[t + 1]
|
||||||
return [
|
delta = rewards[t] + self.gamma * next_val * (1.0 - dones[t]) - values[t]
|
||||||
np.full(len(traj["states"]), normalized_scores[idx], dtype=np.float32)
|
gae = delta + self.gamma * self.gae_lambda * (1.0 - dones[t]) * gae
|
||||||
for idx, traj in enumerate(self.trajectories)
|
advantages[t] = gae
|
||||||
]
|
|
||||||
|
returns = advantages + values
|
||||||
|
return advantages, returns
|
||||||
|
|
||||||
|
def _normalize(self, values: np.ndarray) -> np.ndarray:
|
||||||
|
if values.size == 0:
|
||||||
|
return values
|
||||||
|
mean = float(values.mean())
|
||||||
|
std = float(values.std())
|
||||||
|
if std < self.advantage_epsilon:
|
||||||
|
return np.zeros_like(values, dtype=np.float32)
|
||||||
|
return ((values - mean) / (std + self.advantage_epsilon)).astype(np.float32)
|
||||||
|
|
||||||
|
def _build_training_targets(self) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
|
if not self.trajectories:
|
||||||
|
return np.array([], dtype=np.float32), np.array([], dtype=np.float32)
|
||||||
|
|
||||||
|
gae_advantages = []
|
||||||
|
returns = []
|
||||||
|
for trajectory in self.trajectories:
|
||||||
|
trajectory_adv, trajectory_ret = self._compute_gae(
|
||||||
|
rewards=trajectory["rewards"],
|
||||||
|
values=trajectory["values"],
|
||||||
|
dones=trajectory["dones"],
|
||||||
|
next_value=0.0,
|
||||||
|
)
|
||||||
|
gae_advantages.append(trajectory_adv)
|
||||||
|
returns.append(trajectory_ret)
|
||||||
|
|
||||||
|
normalized_gae = self._normalize(np.concatenate(gae_advantages, axis=0))
|
||||||
|
group_scores = np.asarray(
|
||||||
|
[trajectory["score"] for trajectory in self.trajectories], dtype=np.float32
|
||||||
|
)
|
||||||
|
normalized_group_scores = self._normalize(group_scores)
|
||||||
|
repeated_group_advantages = np.concatenate(
|
||||||
|
[
|
||||||
|
np.full(len(trajectory["states"]), normalized_group_scores[idx], dtype=np.float32)
|
||||||
|
for idx, trajectory in enumerate(self.trajectories)
|
||||||
|
],
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
combined_advantages = (
|
||||||
|
(1.0 - self.group_advantage_coef) * normalized_gae
|
||||||
|
+ self.group_advantage_coef * repeated_group_advantages
|
||||||
|
)
|
||||||
|
combined_advantages = self._normalize(combined_advantages)
|
||||||
|
return combined_advantages, np.concatenate(returns, axis=0).astype(np.float32)
|
||||||
|
|
||||||
def update(self) -> Dict[str, float]:
|
def update(self) -> Dict[str, float]:
|
||||||
if not self.trajectories:
|
if not self.trajectories:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
trajectory_advantages = self._build_group_advantages()
|
advantages, returns = self._build_training_targets()
|
||||||
states = torch.FloatTensor(
|
states = torch.FloatTensor(
|
||||||
np.concatenate([traj["states"] for traj in self.trajectories], axis=0)
|
np.concatenate([trajectory["states"] for trajectory in self.trajectories], axis=0)
|
||||||
).to(self.device)
|
).to(self.device)
|
||||||
actions = torch.LongTensor(
|
actions = torch.LongTensor(
|
||||||
np.concatenate([traj["actions"] for traj in self.trajectories], axis=0)
|
np.concatenate([trajectory["actions"] for trajectory in self.trajectories], axis=0)
|
||||||
).to(self.device)
|
).to(self.device)
|
||||||
old_log_probs = torch.FloatTensor(
|
old_log_probs = torch.FloatTensor(
|
||||||
np.concatenate([traj["log_probs"] for traj in self.trajectories], axis=0)
|
np.concatenate([trajectory["log_probs"] for trajectory in self.trajectories], axis=0)
|
||||||
).to(self.device)
|
|
||||||
advantages = torch.FloatTensor(
|
|
||||||
np.concatenate(trajectory_advantages, axis=0)
|
|
||||||
).to(self.device)
|
).to(self.device)
|
||||||
|
advantages_t = torch.FloatTensor(advantages).to(self.device)
|
||||||
|
returns_t = torch.FloatTensor(returns).to(self.device)
|
||||||
|
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
total_policy_loss = 0.0
|
total_policy_loss = 0.0
|
||||||
|
total_value_loss = 0.0
|
||||||
total_entropy_value = 0.0
|
total_entropy_value = 0.0
|
||||||
update_count = 0
|
update_count = 0
|
||||||
dataset_size = states.shape[0]
|
dataset_size = states.shape[0]
|
||||||
|
|
@ -195,9 +278,10 @@ class GPROAgent:
|
||||||
batch_states = states[batch_idx]
|
batch_states = states[batch_idx]
|
||||||
batch_actions = actions[batch_idx]
|
batch_actions = actions[batch_idx]
|
||||||
batch_old_log_probs = old_log_probs[batch_idx]
|
batch_old_log_probs = old_log_probs[batch_idx]
|
||||||
batch_advantages = advantages[batch_idx]
|
batch_advantages = advantages_t[batch_idx]
|
||||||
|
batch_returns = returns_t[batch_idx]
|
||||||
|
|
||||||
logits_list = self.policy(batch_states)
|
logits_list, values = self.policy(batch_states)
|
||||||
total_new_log_probs = torch.zeros(len(batch_idx), device=self.device)
|
total_new_log_probs = torch.zeros(len(batch_idx), device=self.device)
|
||||||
entropy_terms = torch.zeros(len(batch_idx), device=self.device)
|
entropy_terms = torch.zeros(len(batch_idx), device=self.device)
|
||||||
|
|
||||||
|
|
@ -215,7 +299,12 @@ class GPROAgent:
|
||||||
* batch_advantages
|
* batch_advantages
|
||||||
)
|
)
|
||||||
policy_loss = -torch.min(surr1, surr2).mean()
|
policy_loss = -torch.min(surr1, surr2).mean()
|
||||||
loss = policy_loss - self.entropy_coef * entropy
|
value_loss = F.mse_loss(values.squeeze(-1), batch_returns)
|
||||||
|
loss = (
|
||||||
|
policy_loss
|
||||||
|
+ self.value_coef * value_loss
|
||||||
|
- self.entropy_coef * entropy
|
||||||
|
)
|
||||||
|
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
@ -224,25 +313,26 @@ class GPROAgent:
|
||||||
|
|
||||||
total_loss += float(loss.item())
|
total_loss += float(loss.item())
|
||||||
total_policy_loss += float(policy_loss.item())
|
total_policy_loss += float(policy_loss.item())
|
||||||
|
total_value_loss += float(value_loss.item())
|
||||||
total_entropy_value += float(entropy.item())
|
total_entropy_value += float(entropy.item())
|
||||||
update_count += 1
|
update_count += 1
|
||||||
|
|
||||||
if self.scheduler is not None:
|
if self.scheduler is not None:
|
||||||
self.scheduler.step()
|
self.scheduler.step()
|
||||||
|
|
||||||
|
scores = np.asarray(
|
||||||
|
[trajectory["score"] for trajectory in self.trajectories], dtype=np.float32
|
||||||
|
)
|
||||||
stats = {
|
stats = {
|
||||||
"loss": total_loss / max(update_count, 1),
|
"loss": total_loss / max(update_count, 1),
|
||||||
"policy_loss": total_policy_loss / max(update_count, 1),
|
"policy_loss": total_policy_loss / max(update_count, 1),
|
||||||
"value_loss": 0.0,
|
"value_loss": total_value_loss / max(update_count, 1),
|
||||||
"entropy": total_entropy_value / max(update_count, 1),
|
"entropy": total_entropy_value / max(update_count, 1),
|
||||||
"lr": self.optimizer.param_groups[0]["lr"],
|
"lr": self.optimizer.param_groups[0]["lr"],
|
||||||
"group_score_mean": float(
|
"group_score_mean": float(scores.mean()) if scores.size else 0.0,
|
||||||
np.mean([traj["score"] for traj in self.trajectories], dtype=np.float32)
|
"group_score_std": float(scores.std()) if scores.size else 0.0,
|
||||||
),
|
|
||||||
"group_score_std": float(
|
|
||||||
np.std([traj["score"] for traj in self.trajectories], dtype=np.float32)
|
|
||||||
),
|
|
||||||
"group_size": float(len(self.trajectories)),
|
"group_size": float(len(self.trajectories)),
|
||||||
|
"group_advantage_coef": float(self.group_advantage_coef),
|
||||||
}
|
}
|
||||||
self.reset_group_buffers()
|
self.reset_group_buffers()
|
||||||
return stats
|
return stats
|
||||||
|
|
|
||||||
|
|
@ -107,13 +107,17 @@ agents:
|
||||||
gpro:
|
gpro:
|
||||||
hidden_layers: [256, 256]
|
hidden_layers: [256, 256]
|
||||||
learning_rate: 0.0003
|
learning_rate: 0.0003
|
||||||
|
gamma: 0.99
|
||||||
|
gae_lambda: 0.95
|
||||||
clip_epsilon: 0.2
|
clip_epsilon: 0.2
|
||||||
|
value_coef: 0.5
|
||||||
entropy_coef: 0.01
|
entropy_coef: 0.01
|
||||||
max_grad_norm: 0.5
|
max_grad_norm: 0.5
|
||||||
ppo_epochs: 4
|
ppo_epochs: 4
|
||||||
batch_size: 15
|
batch_size: 15
|
||||||
lr_schedule: "cosine"
|
lr_schedule: "cosine"
|
||||||
group_size: 4
|
group_size: 4
|
||||||
|
group_advantage_coef: 0.35
|
||||||
advantage_epsilon: 1.0e-8
|
advantage_epsilon: 1.0e-8
|
||||||
|
|
||||||
appo:
|
appo:
|
||||||
|
|
|
||||||
|
|
@ -60,6 +60,7 @@ def train_sumo_gpro(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||||
print(f" Hidden layers: {agent_config.get('hidden_layers', [256, 256])}")
|
print(f" Hidden layers: {agent_config.get('hidden_layers', [256, 256])}")
|
||||||
print(f" Learning rate: {agent_config.get('learning_rate', 3e-4)}")
|
print(f" Learning rate: {agent_config.get('learning_rate', 3e-4)}")
|
||||||
print(f" Group size: {group_size}")
|
print(f" Group size: {group_size}")
|
||||||
|
print(f" Group advantage coef: {agent_config.get('group_advantage_coef', 0.35)}")
|
||||||
print(f" Device: {agent_config.get('device', 'cuda')}")
|
print(f" Device: {agent_config.get('device', 'cuda')}")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
@ -68,12 +69,16 @@ def train_sumo_gpro(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||||
action_dims=action_dims,
|
action_dims=action_dims,
|
||||||
hidden_layers=agent_config.get("hidden_layers", [256, 256]),
|
hidden_layers=agent_config.get("hidden_layers", [256, 256]),
|
||||||
learning_rate=agent_config.get("learning_rate", 3e-4),
|
learning_rate=agent_config.get("learning_rate", 3e-4),
|
||||||
|
gamma=agent_config.get("gamma", 0.99),
|
||||||
|
gae_lambda=agent_config.get("gae_lambda", 0.95),
|
||||||
clip_epsilon=agent_config.get("clip_epsilon", 0.2),
|
clip_epsilon=agent_config.get("clip_epsilon", 0.2),
|
||||||
|
value_coef=agent_config.get("value_coef", 0.5),
|
||||||
entropy_coef=agent_config.get("entropy_coef", 0.01),
|
entropy_coef=agent_config.get("entropy_coef", 0.01),
|
||||||
max_grad_norm=agent_config.get("max_grad_norm", 0.5),
|
max_grad_norm=agent_config.get("max_grad_norm", 0.5),
|
||||||
ppo_epochs=agent_config.get("ppo_epochs", 4),
|
ppo_epochs=agent_config.get("ppo_epochs", 4),
|
||||||
minibatch_size=agent_config.get("batch_size", 64),
|
minibatch_size=agent_config.get("batch_size", 64),
|
||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
|
group_advantage_coef=agent_config.get("group_advantage_coef", 0.35),
|
||||||
advantage_epsilon=agent_config.get("advantage_epsilon", 1e-8),
|
advantage_epsilon=agent_config.get("advantage_epsilon", 1e-8),
|
||||||
device=agent_config.get("device", "cuda"),
|
device=agent_config.get("device", "cuda"),
|
||||||
lr_schedule=agent_config.get("lr_schedule", "cosine"),
|
lr_schedule=agent_config.get("lr_schedule", "cosine"),
|
||||||
|
|
@ -125,10 +130,10 @@ def train_sumo_gpro(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||||
)
|
)
|
||||||
|
|
||||||
while not done:
|
while not done:
|
||||||
action, log_prob, _ = agent.select_action(state, deterministic=False)
|
action, log_prob, value = agent.select_action(state, deterministic=False)
|
||||||
next_state, reward, done, info = env.step(action)
|
next_state, reward, done, info = env.step(action)
|
||||||
|
|
||||||
agent.store_transition(state, action, reward, 0.0, log_prob, done)
|
agent.store_transition(state, action, reward, value, log_prob, done)
|
||||||
|
|
||||||
episode_reward += reward
|
episode_reward += reward
|
||||||
episode_throughput += info["throughput"]
|
episode_throughput += info["throughput"]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue