"""
Directional Corridor MAPPO for ordered motorway VSL control.

This revision borrows three ideas that are common in newer sequence/control
architectures from other domains:
- Variable selection style gating to suppress noisy neighbor features
- Conditional modulation (FiLM-like) so global state can reshape local features
- Dual-path corridor mixing that keeps both local smoothing and directional flow
"""
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from typing import Dict, Tuple
|
|
|
|
|
|
class SelectiveFeatureFusion(nn.Module):
    """Blend a set of component vectors using content- and condition-aware weights.

    Every component is embedded by one shared projection. A per-component
    score (from the embedding) is added to a per-slot score (from the
    conditioning vector); the softmax of that sum weights the embeddings.
    The weighted sum is finally rescaled and shifted by a FiLM-style pair
    derived from the conditioning vector.
    """

    def __init__(self, component_dim: int, condition_dim: int, hidden_dim: int, num_components: int):
        """Build the projection, scoring and modulation sub-networks.

        Args:
            component_dim: Feature size of each incoming component.
            condition_dim: Feature size of the conditioning vector.
            hidden_dim: Width of the fused output.
            num_components: Number of components scored by the condition
                branch; must match dim 1 of ``components`` in ``forward``.
        """
        super().__init__()
        film_width = hidden_dim * 2

        # Shared embedding applied to each component independently.
        projection_layers = [
            nn.Linear(component_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
        ]
        self.component_proj = nn.Sequential(*projection_layers)

        # Content-driven score: one scalar per embedded component.
        self.component_score = nn.Linear(hidden_dim, 1)

        # Condition-driven score: one scalar per component slot.
        condition_layers = [
            nn.Linear(condition_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, num_components),
        ]
        self.condition_score = nn.Sequential(*condition_layers)

        # Produces the concatenated (scale, shift) modulation pair.
        film_layers = [
            nn.Linear(condition_dim, film_width),
            nn.GELU(),
            nn.Linear(film_width, film_width),
        ]
        self.film = nn.Sequential(*film_layers)

        self._init_weights()

    def _init_weights(self):
        """Orthogonally initialise every linear layer (gain sqrt(2)) with zero bias."""
        linear_layers = (layer for layer in self.modules() if isinstance(layer, nn.Linear))
        for layer in linear_layers:
            nn.init.orthogonal_(layer.weight, gain=np.sqrt(2))
            nn.init.constant_(layer.bias, 0)

    def forward(self, components: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        """Fuse the components into a single vector per batch row.

        Args:
            components: Tensor of shape ``(batch, num_components, component_dim)``.
            condition: Tensor of shape ``(batch, condition_dim)``.

        Returns:
            Tensor of shape ``(batch, hidden_dim)``.
        """
        embedded = self.component_proj(components)

        # Sum of content score and condition score, normalised over components.
        logits = self.component_score(embedded).squeeze(-1) + self.condition_score(condition)
        mixing = torch.softmax(logits, dim=1).unsqueeze(-1)
        blended = torch.sum(mixing * embedded, dim=1)

        raw_scale, raw_shift = torch.chunk(self.film(condition), 2, dim=-1)
        # The scale is bounded to [0.9, 1.1] by the tanh; the shift is only damped.
        scale = 1.0 + 0.1 * torch.tanh(raw_scale)
        shift = 0.1 * raw_shift
        return blended * scale + shift
|
|
|
|
|
|
class DualPathCorridorBlock(nn.Module):
    """Combine local convolutional mixing and directional recurrent propagation.

    The block applies three pre-norm residual stages to a sequence of shape
    ``(batch, sequence, hidden_dim)``:

    1. a depthwise + pointwise 1-D convolution along the sequence axis,
    2. a bidirectional GRU along the sequence axis,
    3. a position-wise feed-forward network.

    The first two stages are gated by a sigmoid computed from the stage input
    concatenated with the stage output.
    """

    def __init__(self, hidden_dim: int, kernel_size: int = 5, dropout: float = 0.05):
        """Create the block.

        Args:
            hidden_dim: Feature width. Must be a positive even number because
                the bidirectional GRU emits ``2 * (hidden_dim // 2)`` features,
                which has to equal ``hidden_dim`` for the residual sum.
            kernel_size: Width of the depthwise convolution. Must be a positive
                odd number so that ``padding = kernel_size // 2`` preserves the
                sequence length.
            dropout: Dropout probability used on each residual branch and
                inside the feed-forward network.

        Raises:
            ValueError: If ``hidden_dim`` is not a positive even integer or
                ``kernel_size`` is not a positive odd integer.
        """
        super().__init__()
        # Fail at construction time; otherwise these misconfigurations only
        # show up as opaque shape errors inside forward().
        if hidden_dim <= 0 or hidden_dim % 2 != 0:
            raise ValueError(
                f"hidden_dim must be a positive even integer, got {hidden_dim}"
            )
        if kernel_size <= 0 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )

        padding = kernel_size // 2

        # --- local (convolutional) path ---
        self.local_norm = nn.LayerNorm(hidden_dim)
        self.depthwise = nn.Conv1d(
            hidden_dim,
            hidden_dim,
            kernel_size=kernel_size,
            padding=padding,
            groups=hidden_dim,  # one filter per channel
        )
        self.pointwise = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1)
        self.local_gate = nn.Linear(hidden_dim * 2, hidden_dim)

        # --- directional (recurrent) path ---
        self.seq_norm = nn.LayerNorm(hidden_dim)
        self.seq_mixer = nn.GRU(
            input_size=hidden_dim,
            hidden_size=hidden_dim // 2,  # two directions concatenate to hidden_dim
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.seq_gate = nn.Linear(hidden_dim * 2, hidden_dim)

        # --- position-wise feed-forward ---
        self.ffn_norm = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim),
        )
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()
        self._init_weights()

    def _init_weights(self):
        """Initialise linear, convolutional and GRU parameters.

        Linear layers: orthogonal (gain sqrt(2)), zero bias.
        Conv1d layers: Kaiming-uniform with ``a = sqrt(5)``, zero bias.
        GRU: orthogonal weights, zero biases.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Conv1d):
                nn.init.kaiming_uniform_(module.weight, a=np.sqrt(5))
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        for name, param in self.seq_mixer.named_parameters():
            if "weight" in name:
                nn.init.orthogonal_(param)
            elif "bias" in name:
                nn.init.constant_(param, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the three residual stages.

        Args:
            x: Tensor of shape ``(batch, sequence, hidden_dim)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Stage 1: local convolution. Conv1d wants (batch, channels, sequence),
        # hence the transpose in and out.
        residual = x
        local = self.local_norm(x).transpose(1, 2)
        local = self.activation(self.pointwise(self.depthwise(local))).transpose(1, 2)
        local_gate = torch.sigmoid(self.local_gate(torch.cat([residual, local], dim=-1)))
        x = residual + self.dropout(local_gate * local)

        # Stage 2: bidirectional recurrence along the sequence.
        residual = x
        seq_input = self.seq_norm(x)
        seq_out, _ = self.seq_mixer(seq_input)
        seq_gate = torch.sigmoid(self.seq_gate(torch.cat([residual, seq_out], dim=-1)))
        x = residual + self.dropout(seq_gate * seq_out)

        # Stage 3: position-wise feed-forward.
        residual = x
        x = residual + self.dropout(self.ffn(self.ffn_norm(x)))
        return x
|
|
|
|
|
|
class CorridorActor(nn.Module):
    """Policy network producing per-agent action logits along a corridor.

    Each agent's stacked components are fused into one vector by
    ``SelectiveFeatureFusion``; the resulting agent sequence receives a
    learned position embedding, passes through a stack of
    ``DualPathCorridorBlock`` layers and is mapped to action logits.
    """

    def __init__(
        self,
        edge_token_dim: int,
        condition_dim: int,
        num_agents: int,
        num_actions: int,
        hidden_dim: int = 256,
        num_blocks: int = 2,
        kernel_size: int = 5,
        dropout: float = 0.05,
    ):
        """Build the actor.

        Args:
            edge_token_dim: Feature size of each component token.
            condition_dim: Feature size of the per-agent conditioning vector.
            num_agents: Number of agents; sizes the position embedding.
            num_actions: Number of discrete actions per agent.
            hidden_dim: Model width.
            num_blocks: Number of corridor blocks.
            kernel_size: Convolution width forwarded to each corridor block.
            dropout: Dropout probability forwarded to each corridor block.
        """
        super().__init__()
        self.fusion = SelectiveFeatureFusion(
            component_dim=edge_token_dim,
            condition_dim=condition_dim,
            hidden_dim=hidden_dim,
            # The fusion module is built for exactly five components per agent.
            num_components=5,
        )
        self.position_embedding = nn.Parameter(torch.zeros(1, num_agents, hidden_dim))
        self.blocks = nn.ModuleList(
            [DualPathCorridorBlock(hidden_dim, kernel_size=kernel_size, dropout=dropout) for _ in range(num_blocks)]
        )
        self.head_norm = nn.LayerNorm(hidden_dim)
        self.head = nn.Linear(hidden_dim, num_actions)
        self._init_weights()

    def _init_weights(self):
        """Initialise the position embedding and the output head.

        The head uses a very small orthogonal gain (0.01) and zero bias, which
        keeps the initial logits small. Sub-modules initialise themselves.
        """
        nn.init.normal_(self.position_embedding, mean=0.0, std=0.02)
        nn.init.orthogonal_(self.head.weight, gain=0.01)
        nn.init.constant_(self.head.bias, 0)

    def forward(self, components: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        """Compute action logits for every agent.

        Args:
            components: Tensor of shape
                ``(batch, num_agents, num_components, edge_token_dim)``.
            condition: Tensor whose leading dims are ``(batch, num_agents)``
                and whose last dim is ``condition_dim``.

        Returns:
            Logits of shape ``(batch, num_agents, num_actions)``.
        """
        batch_size, num_agents, num_components, token_dim = components.shape

        # Fold the agent axis into the batch so the fusion module sees one row
        # per agent. reshape (not view) so that non-contiguous inputs such as
        # expanded or transposed tensors are accepted; for contiguous inputs
        # it is still a zero-copy view.
        flat_components = components.reshape(batch_size * num_agents, num_components, token_dim)
        flat_condition = condition.reshape(batch_size * num_agents, condition.size(-1))
        fused = self.fusion(flat_components, flat_condition)

        x = fused.reshape(batch_size, num_agents, -1) + self.position_embedding
        for block in self.blocks:
            x = block(x)
        return self.head(self.head_norm(x))
|
|
|
|
|
|
class StructuredCorridorCritic(nn.Module):
    """Centralised value network that keeps the corridor's agent ordering.

    Per-agent tokens are embedded, modulated by the global features
    (FiLM-style scale and shift), mixed along the corridor by
    ``DualPathCorridorBlock`` layers, pooled over agents (mean and max) and
    reduced to a single scalar value per batch row.
    """

    def __init__(
        self,
        num_agents: int,
        edge_feature_dim: int,
        time_feature_dim: int,
        hidden_dim: int = 256,
        num_blocks: int = 2,
        kernel_size: int = 5,
        dropout: float = 0.05,
    ):
        """Build the critic.

        Args:
            num_agents: Number of agents; sizes the position embedding.
            edge_feature_dim: Per-agent edge feature count (one extra speed
                feature is appended to form the token).
            time_feature_dim: Number of time features (one extra last-reward
                feature is appended to form the global vector).
            hidden_dim: Model width.
            num_blocks: Number of corridor blocks.
            kernel_size: Convolution width forwarded to each corridor block.
            dropout: Dropout probability forwarded to each corridor block.
        """
        super().__init__()
        self.num_agents = num_agents
        self.edge_feature_dim = edge_feature_dim
        self.time_feature_dim = time_feature_dim
        self.speed_feature_dim = 1
        self.last_reward_dim = 1
        self.global_feature_dim = self.time_feature_dim + self.last_reward_dim
        self.edge_token_dim = self.edge_feature_dim + self.speed_feature_dim

        film_width = hidden_dim * 2
        half_width = hidden_dim // 2

        # Token embedding.
        token_layers = [
            nn.Linear(self.edge_token_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
        ]
        self.edge_proj = nn.Sequential(*token_layers)

        # Global features -> concatenated (scale, shift).
        film_layers = [
            nn.Linear(self.global_feature_dim, film_width),
            nn.GELU(),
            nn.Linear(film_width, film_width),
        ]
        self.global_film = nn.Sequential(*film_layers)

        self.position_embedding = nn.Parameter(torch.zeros(1, num_agents, hidden_dim))

        corridor_layers = []
        for _ in range(num_blocks):
            corridor_layers.append(
                DualPathCorridorBlock(hidden_dim, kernel_size=kernel_size, dropout=dropout)
            )
        self.blocks = nn.ModuleList(corridor_layers)

        # Input is [mean pool | max pool | raw global features].
        head_layers = [
            nn.Linear(hidden_dim * 2 + self.global_feature_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, half_width),
            nn.ReLU(),
            nn.Linear(half_width, 1),
        ]
        self.head = nn.Sequential(*head_layers)

        self._init_weights()

    def _init_weights(self):
        """Initialise the value head and the position embedding.

        Head linears get orthogonal weights (gain sqrt(2)) and zero bias; the
        final layer is then re-initialised with gain 1.0. ``edge_proj`` and
        ``global_film`` are left at PyTorch's default initialisation.
        """
        head_linears = [layer for layer in self.head.modules() if isinstance(layer, nn.Linear)]
        for layer in head_linears:
            nn.init.orthogonal_(layer.weight, gain=np.sqrt(2))
            nn.init.constant_(layer.bias, 0)

        value_layer = self.head[-1]
        nn.init.orthogonal_(value_layer.weight, gain=1.0)
        nn.init.normal_(self.position_embedding, mean=0.0, std=0.02)

    def forward(self, edge_tokens: torch.Tensor, global_features: torch.Tensor) -> torch.Tensor:
        """Estimate the state value.

        Args:
            edge_tokens: Tensor of shape ``(batch, num_agents, edge_token_dim)``.
            global_features: Tensor of shape ``(batch, global_feature_dim)``.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        hidden = self.edge_proj(edge_tokens)

        # One scale/shift pair per batch row, broadcast over the agent axis.
        raw_scale, raw_shift = torch.chunk(self.global_film(global_features), 2, dim=-1)
        scale = 0.1 * torch.tanh(raw_scale).unsqueeze(1)
        shift = 0.1 * raw_shift.unsqueeze(1)
        hidden = hidden * (1.0 + scale) + shift

        hidden = hidden + self.position_embedding
        for corridor_block in self.blocks:
            hidden = corridor_block(hidden)

        # Pool over the agent axis and re-attach the raw global features.
        mean_pool = torch.mean(hidden, dim=1)
        max_pool = torch.max(hidden, dim=1).values
        head_input = torch.cat([mean_pool, max_pool, global_features], dim=-1)
        return self.head(head_input)
|
|
|
|
|
|
class DCMAPPOAgent:
    """Directional Corridor MAPPO with gated actor and structured critic.

    The agent owns a ``CorridorActor`` (per-agent discrete policy) and a
    ``StructuredCorridorCritic`` (single shared value), an on-policy rollout
    buffer, and a clipped-surrogate PPO update with GAE advantages.

    Flat state layout expected by ``_parse_state`` (per row)::

        [ edge features : num_agents * edge_feature_dim
        | local_speed_limits : num_agents
        | global features : time_feature_dim + last_reward_dim (1) ]

    Any trailing entries beyond that layout are ignored.
    """

    def __init__(
        self,
        state_dim: int,
        num_agents: int,
        num_actions: int,
        edge_feature_dim: int = 3,
        time_feature_dim: int = 3,
        hidden_dim: int = 256,
        critic_hidden_dim: int = 256,
        num_corridor_blocks: int = 2,
        corridor_kernel_size: int = 5,
        corridor_dropout: float = 0.05,
        learning_rate: float = 3e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_epsilon: float = 0.2,
        value_coef: float = 0.5,
        entropy_coef: float = 0.01,
        max_grad_norm: float = 0.5,
        ppo_epochs: int = 4,
        minibatch_size: int = 15,
        device: str = "cuda",
        lr_schedule: str = "cosine",
        total_episodes: int = 300,
    ):
        """Create networks, optimizers, optional LR schedulers and empty buffers.

        Args:
            state_dim: Length of the flat state vector (stored, not otherwise
                used by this class).
            num_agents: Number of agents along the corridor.
            num_actions: Number of discrete actions per agent.
            edge_feature_dim: Edge features per agent in the flat state.
            time_feature_dim: Number of time features in the global block.
            hidden_dim: Actor width.
            critic_hidden_dim: Critic width.
            num_corridor_blocks: Corridor blocks in the actor; the critic uses
                at least one.
            corridor_kernel_size: Convolution width of the corridor blocks.
            corridor_dropout: Dropout probability of the corridor blocks.
            learning_rate: Initial Adam learning rate for both optimizers.
            gamma: Discount factor.
            gae_lambda: GAE lambda.
            clip_epsilon: PPO ratio clipping range.
            value_coef: Multiplier on the value loss.
            entropy_coef: Multiplier on the entropy bonus.
            max_grad_norm: Gradient-norm clipping threshold.
            ppo_epochs: Passes over the buffer per ``update`` call.
            minibatch_size: Transitions per minibatch.
            device: Requested device; CPU is used whenever CUDA is unavailable.
            lr_schedule: ``"cosine"`` enables cosine annealing; anything else
                disables scheduling.
            total_episodes: ``T_max`` of the cosine schedule, which is stepped
                once per ``update`` call.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.state_dim = state_dim
        self.num_agents = num_agents
        self.num_actions = num_actions
        self.edge_feature_dim = edge_feature_dim
        self.time_feature_dim = time_feature_dim
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        self.ppo_epochs = ppo_epochs
        self.minibatch_size = minibatch_size

        # Derived feature sizes of the flat state layout.
        self.speed_feature_dim = 1
        self.last_reward_dim = 1
        self.global_feature_dim = self.time_feature_dim + self.last_reward_dim
        self.agent_id_dim = 1
        self.edge_token_dim = self.edge_feature_dim + self.speed_feature_dim
        self.condition_dim = self.global_feature_dim + self.agent_id_dim

        self.actor = CorridorActor(
            edge_token_dim=self.edge_token_dim,
            condition_dim=self.condition_dim,
            num_agents=num_agents,
            num_actions=num_actions,
            hidden_dim=hidden_dim,
            num_blocks=num_corridor_blocks,
            kernel_size=corridor_kernel_size,
            dropout=corridor_dropout,
        ).to(self.device)
        self.critic = StructuredCorridorCritic(
            num_agents=num_agents,
            edge_feature_dim=edge_feature_dim,
            time_feature_dim=time_feature_dim,
            hidden_dim=critic_hidden_dim,
            num_blocks=max(1, num_corridor_blocks),
            kernel_size=corridor_kernel_size,
            dropout=corridor_dropout,
        ).to(self.device)

        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=learning_rate, eps=1e-5)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=learning_rate, eps=1e-5)

        if lr_schedule == "cosine":
            self.actor_scheduler = optim.lr_scheduler.CosineAnnealingLR(
                self.actor_optimizer,
                T_max=total_episodes,
                eta_min=learning_rate * 0.1,
            )
            self.critic_scheduler = optim.lr_scheduler.CosineAnnealingLR(
                self.critic_optimizer,
                T_max=total_episodes,
                eta_min=learning_rate * 0.1,
            )
        else:
            self.actor_scheduler = None
            self.critic_scheduler = None

        # Normalised agent position in [0, 1], used as part of the condition.
        agent_ids = np.linspace(0.0, 1.0, num_agents, dtype=np.float32)
        self.agent_id_features = torch.tensor(agent_ids, device=self.device).view(1, num_agents, 1)
        self.reset_buffers()

    def reset_buffers(self):
        """Drop all stored transitions."""
        self.states = []
        self.actions = []
        self.rewards = []
        self.values = []
        self.log_probs = []
        self.dones = []

    @staticmethod
    def _shift_left(x: torch.Tensor) -> torch.Tensor:
        """Return, for each index along dim 1, the element at index - 1.

        The first element is repeated to fill the boundary.
        """
        return torch.cat([x[:, :1, :], x[:, :-1, :]], dim=1)

    @staticmethod
    def _shift_right(x: torch.Tensor) -> torch.Tensor:
        """Return, for each index along dim 1, the element at index + 1.

        The last element is repeated to fill the boundary.
        """
        return torch.cat([x[:, 1:, :], x[:, -1:, :]], dim=1)

    def _parse_state(self, state_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split flat states into per-agent tokens, global features and condition.

        Args:
            state_tensor: ``(state_len,)`` or ``(batch, state_len)`` tensor.

        Returns:
            ``edge_tokens`` of shape ``(batch, num_agents, edge_feature_dim + 1)``,
            ``global_features`` of shape ``(batch, global_feature_dim)`` and
            ``condition`` of shape ``(batch, num_agents, global_feature_dim + 1)``
            (global features repeated per agent plus the agent id feature).
        """
        if state_tensor.dim() == 1:
            state_tensor = state_tensor.unsqueeze(0)

        batch_size = state_tensor.size(0)
        edge_block = self.num_agents * self.edge_feature_dim
        speed_block_start = edge_block
        speed_block_end = speed_block_start + self.num_agents
        global_block_start = speed_block_end
        global_block_end = global_block_start + self.global_feature_dim

        edge_features = state_tensor[:, :edge_block].view(batch_size, self.num_agents, self.edge_feature_dim)
        local_speed_limits = state_tensor[:, speed_block_start:speed_block_end].view(batch_size, self.num_agents, 1)
        edge_tokens = torch.cat([edge_features, local_speed_limits], dim=-1)
        global_features = state_tensor[:, global_block_start:global_block_end]
        agent_ids = self.agent_id_features.expand(batch_size, -1, -1)
        condition = torch.cat(
            [global_features.unsqueeze(1).expand(-1, self.num_agents, -1), agent_ids],
            dim=-1,
        )
        return edge_tokens, global_features, condition

    def _build_actor_inputs(self, state_tensor: torch.Tensor):
        """Build the five-component actor input from flat states.

        Components, stacked on dim 2: own token, left-neighbour token,
        right-neighbour token, own minus left, right minus own.

        Returns:
            ``(components, condition, edge_tokens, global_features)``.
        """
        edge_tokens, global_features, condition = self._parse_state(state_tensor)
        left_tokens = self._shift_left(edge_tokens)
        right_tokens = self._shift_right(edge_tokens)
        delta_left = edge_tokens - left_tokens
        delta_right = right_tokens - edge_tokens
        components = torch.stack(
            [edge_tokens, left_tokens, right_tokens, delta_left, delta_right],
            dim=2,
        )
        return components, condition, edge_tokens, global_features

    def select_action(
        self,
        state: np.ndarray,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray, float]:
        """Choose one action per agent for a single state.

        Args:
            state: Flat state vector.
            deterministic: If true, take the arg-max action per agent with the
                networks temporarily in eval mode, so dropout cannot perturb
                the result. If false, sample from the categorical policy with
                the networks left in their current mode.

        Returns:
            ``(actions, log_probs, value)``: int64 array and float32 array of
            length ``num_agents``, plus the critic's scalar value estimate.
        """
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        components, condition, edge_tokens, global_features = self._build_actor_inputs(state_tensor)

        # Remember the current modes so they can be restored afterwards.
        actor_was_training = self.actor.training
        critic_was_training = self.critic.training
        if deterministic:
            self.actor.eval()
            self.critic.eval()
        try:
            with torch.no_grad():
                logits = self.actor(components, condition)
                value = self.critic(edge_tokens, global_features)

                # One batched distribution over all agents replaces the
                # per-agent Python loop.
                agent_logits = logits[0]
                dist = torch.distributions.Categorical(logits=agent_logits)
                if deterministic:
                    actions = torch.argmax(agent_logits, dim=-1)
                else:
                    actions = dist.sample()
                log_probs = dist.log_prob(actions)
        finally:
            if deterministic:
                self.actor.train(actor_was_training)
                self.critic.train(critic_was_training)

        return (
            actions.cpu().numpy().astype(np.int64),
            log_probs.cpu().numpy().astype(np.float32),
            value.item(),
        )

    def get_value(self, state: np.ndarray) -> float:
        """Return the critic's value estimate for a single flat state."""
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        # The critic only needs the parsed tokens, not the actor components.
        edge_tokens, global_features, _ = self._parse_state(state_tensor)
        with torch.no_grad():
            value = self.critic(edge_tokens, global_features)
        return value.item()

    def store_transition(self, state, action, reward, value, log_prob, done):
        """Append one transition to the rollout buffer."""
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.values.append(value)
        self.log_probs.append(log_prob)
        self.dones.append(done)

    def compute_gae(self, next_value: float):
        """Compute GAE advantages and returns for the buffered rollout.

        Uses ``delta_t = r_t + gamma * V_{t+1} * (1 - done_t) - V_t`` and
        ``A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}``.

        Args:
            next_value: Bootstrap value for the state after the last
                buffered transition.

        Returns:
            ``(advantages, returns)`` as float32 arrays, where
            ``returns = advantages + values``.
        """
        num_steps = len(self.rewards)
        # Filled back to front; avoids the quadratic cost of list.insert(0, ...).
        advantages = np.zeros(num_steps, dtype=np.float32)
        gae = 0.0
        next_val = next_value

        for t in reversed(range(num_steps)):
            not_done = 1.0 - float(self.dones[t])
            delta = self.rewards[t] + self.gamma * next_val * not_done - self.values[t]
            gae = delta + self.gamma * self.gae_lambda * not_done * gae
            advantages[t] = gae
            next_val = self.values[t]

        returns = advantages + np.asarray(self.values, dtype=np.float32)
        return advantages, returns

    def update(self, next_value: float) -> Dict[str, float]:
        """Run the PPO update on the buffered rollout and clear the buffer.

        Advantages are normalised over the whole buffer and the same scalar
        advantage is applied to every agent of a transition. Actor and critic
        are optimised separately. Schedulers, if any, are stepped once.

        Args:
            next_value: Bootstrap value passed to ``compute_gae``.

        Returns:
            Mean ``policy_loss``, ``value_loss`` and ``entropy`` over all
            minibatch updates, or an empty dict when the buffer is empty.
        """
        if not self.states:
            return {}

        advantages, returns = self.compute_gae(next_value)
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        states = torch.FloatTensor(np.array(self.states)).to(self.device)
        actions = torch.LongTensor(np.array(self.actions)).to(self.device)
        old_log_probs = torch.FloatTensor(np.array(self.log_probs)).to(self.device)
        advantages_t = torch.FloatTensor(advantages).to(self.device)
        returns_t = torch.FloatTensor(returns).to(self.device)

        total_policy_loss = 0.0
        total_value_loss = 0.0
        total_entropy = 0.0
        update_count = 0

        dataset_size = len(self.states)
        for _ in range(self.ppo_epochs):
            indices = np.random.permutation(dataset_size)
            for start_idx in range(0, dataset_size, self.minibatch_size):
                end_idx = min(start_idx + self.minibatch_size, dataset_size)
                batch_idx = indices[start_idx:end_idx]

                batch_states = states[batch_idx]
                batch_actions = actions[batch_idx]
                batch_old_lp = old_log_probs[batch_idx]
                batch_adv = advantages_t[batch_idx]
                batch_ret = returns_t[batch_idx]

                components, condition, edge_tokens, global_features = self._build_actor_inputs(batch_states)
                logits = self.actor(components, condition)
                dist = torch.distributions.Categorical(logits=logits)

                new_log_probs = dist.log_prob(batch_actions)
                entropy = dist.entropy().mean()

                # Shared team advantage, broadcast to every agent.
                expanded_adv = batch_adv.unsqueeze(1).expand(-1, self.num_agents)
                ratio = torch.exp(new_log_probs - batch_old_lp)
                surr1 = ratio * expanded_adv
                surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * expanded_adv
                policy_loss = -torch.min(surr1, surr2).mean()

                values = self.critic(edge_tokens, global_features).squeeze(-1)
                value_loss = nn.functional.mse_loss(values, batch_ret)

                actor_loss = policy_loss - self.entropy_coef * entropy
                critic_loss = self.value_coef * value_loss

                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
                self.actor_optimizer.step()

                self.critic_optimizer.zero_grad()
                critic_loss.backward()
                nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
                self.critic_optimizer.step()

                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                total_entropy += entropy.item()
                update_count += 1

        if self.actor_scheduler is not None:
            self.actor_scheduler.step()
        if self.critic_scheduler is not None:
            self.critic_scheduler.step()

        self.reset_buffers()
        return {
            "policy_loss": total_policy_loss / max(update_count, 1),
            "value_loss": total_value_loss / max(update_count, 1),
            "entropy": total_entropy / max(update_count, 1),
        }

    def save(self, path: str):
        """Write network, optimizer and (if present) scheduler state to ``path``."""
        checkpoint = {
            "actor_state_dict": self.actor.state_dict(),
            "critic_state_dict": self.critic.state_dict(),
            "actor_optimizer_state_dict": self.actor_optimizer.state_dict(),
            "critic_optimizer_state_dict": self.critic_optimizer.state_dict(),
        }
        # Persist the schedule position so a resumed run continues the
        # annealing instead of restarting it.
        if self.actor_scheduler is not None:
            checkpoint["actor_scheduler_state_dict"] = self.actor_scheduler.state_dict()
        if self.critic_scheduler is not None:
            checkpoint["critic_scheduler_state_dict"] = self.critic_scheduler.state_dict()
        torch.save(checkpoint, path)

    def load(self, path: str):
        """Restore state written by ``save``.

        Optimizer state falls back to a legacy ``"optimizer_state_dict"``
        entry when the per-network entries are absent. Scheduler state is
        restored only when the checkpoint has it and this agent uses a
        scheduler, so checkpoints without scheduler entries still load.
        """
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects.
        # Only load checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.actor.load_state_dict(checkpoint["actor_state_dict"])
        self.critic.load_state_dict(checkpoint["critic_state_dict"])

        actor_optim_state = checkpoint.get("actor_optimizer_state_dict")
        critic_optim_state = checkpoint.get("critic_optimizer_state_dict")
        legacy_optim_state = checkpoint.get("optimizer_state_dict")
        if actor_optim_state is not None:
            self.actor_optimizer.load_state_dict(actor_optim_state)
        elif legacy_optim_state is not None:
            self.actor_optimizer.load_state_dict(legacy_optim_state)

        if critic_optim_state is not None:
            self.critic_optimizer.load_state_dict(critic_optim_state)
        elif legacy_optim_state is not None:
            self.critic_optimizer.load_state_dict(legacy_optim_state)

        scheduler_entries = (
            (self.actor_scheduler, "actor_scheduler_state_dict"),
            (self.critic_scheduler, "critic_scheduler_state_dict"),
        )
        for scheduler, key in scheduler_entries:
            scheduler_state = checkpoint.get(key)
            if scheduler is not None and scheduler_state is not None:
                scheduler.load_state_dict(scheduler_state)
|