# ctm-dqn/agents/sac_agent.py
#
# 167 lines
# 5.7 KiB
# Python

"""SAC agent using Stable-Baselines3 for MultiDiscrete-like VSL control."""
from __future__ import annotations
from typing import List, Sequence
import gymnasium as gym
import numpy as np
import torch.nn as nn
from gymnasium import spaces
from stable_baselines3 import SAC
from stable_baselines3.common.torch_layers import FlattenExtractor
from utils.sb3_manual import ensure_manual_logger, sync_manual_timesteps
class MultiDiscreteWrapper(gym.Env):
    """Placeholder env exposing per-zone discrete choices as a continuous Box.

    The environment carries no dynamics: ``reset`` and ``step`` always return
    an all-zero observation. It exists so that an algorithm constructed with
    it can read ``observation_space`` and ``action_space``.

    Args:
        state_dim: Length of the flat observation vector.
        action_dims: Number of discrete choices per zone; one continuous
            action component is created for each entry.
    """

    def __init__(self, state_dim: int, action_dims: Sequence[int]):
        super().__init__()
        self.state_dim = state_dim
        self.action_dims = action_dims
        self.num_zones = len(action_dims)
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(state_dim,), dtype=np.float32
        )
        # Use [-1, 1] so the replay buffer stores SAC's native squashed action scale.
        self.action_space = spaces.Box(
            low=-1.0, high=1.0, shape=(self.num_zones,), dtype=np.float32
        )

    def reset(self, seed=None, options=None):
        """Return an all-zero observation and an empty info dict.

        ``seed`` is forwarded to ``gym.Env.reset`` so the environment's own
        random generator is seeded as the Gymnasium API expects; ``options``
        is accepted for API compatibility and is not used.
        """
        super().reset(seed=seed)
        return np.zeros(self.state_dim, dtype=np.float32), {}

    def step(self, action):
        """Ignore ``action`` and return a zero observation with zero reward.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` with
            constant values ``(zeros, 0.0, False, False, {})``.
        """
        return np.zeros(self.state_dim, dtype=np.float32), 0.0, False, False, {}
def _resolve_activation_fn(name: str):
    """Map an activation name to the matching ``torch.nn`` module class.

    The lookup is case-insensitive and ignores surrounding whitespace. A
    falsy ``name`` (``None`` or ``""``) selects ReLU.

    Args:
        name: One of ``"relu"``, ``"silu"`` or ``"elu"``.

    Returns:
        The activation class itself (not an instance).

    Raises:
        ValueError: If the normalized name is not one of the supported keys.
    """
    supported = {
        "relu": nn.ReLU,
        "silu": nn.SiLU,
        "elu": nn.ELU,
    }
    requested = name if name else "relu"
    activation_cls = supported.get(requested.strip().lower())
    if activation_cls is None:
        raise ValueError(f"Unsupported SAC activation: {name}")
    return activation_cls
def _as_arch_list(value, default: List[int]) -> List[int]:
    """Return the layer sizes as a fresh list of ints.

    Args:
        value: Iterable of layer widths, or ``None`` to fall back to
            ``default``.
        default: Layer widths used when ``value`` is ``None``.

    Returns:
        A new list: the entries of ``value`` each passed through ``int``,
        or a shallow copy of ``default`` (entries left as they are).
    """
    if value is not None:
        return list(map(int, value))
    return list(default)
class SACAgent:
    """SAC agent wrapper that is stepped manually instead of via ``SAC.learn``.

    The wrapped Stable-Baselines3 ``SAC`` model works on a continuous action
    vector in ``[-1, 1]`` with one component per zone. This class converts
    between that vector and per-zone discrete indices, writes transitions
    straight into the model's replay buffer, and triggers single gradient
    steps on request.
    """

    def __init__(
        self,
        state_dim: int,
        action_dims: list,
        seed: int | None = None,
        learning_rate: float = 3e-4,
        buffer_size: int = 100000,
        learning_starts: int = 100,
        batch_size: int = 64,
        tau: float = 0.005,
        gamma: float = 0.99,
        ent_coef: str | float = "auto",
        target_entropy: str | float = "auto",
        target_update_interval: int = 1,
        log_std_init: float = -3.0,
        device: str = "cuda",
        actor_hidden_dims: Sequence[int] | None = None,
        critic_hidden_dims: Sequence[int] | None = None,
        activation_fn: str = "relu",
    ):
        """Build the SAC model on top of a placeholder environment.

        Args:
            state_dim: Length of the flat observation vector.
            action_dims: Number of discrete choices for each zone.
            seed: Seed forwarded to ``SAC``.
            learning_rate, buffer_size, batch_size, tau, gamma, ent_coef,
                target_entropy, target_update_interval, device: Forwarded
                unchanged to ``SAC``.
            learning_starts: Forwarded to ``SAC`` and also used by
                ``select_action`` as the length of the random warm-up phase.
            log_std_init: Passed to the policy via ``policy_kwargs``.
            actor_hidden_dims: Actor layer widths; ``[256, 256]`` if ``None``.
            critic_hidden_dims: Critic layer widths; ``[256, 256]`` if ``None``.
            activation_fn: Activation name: ``"relu"``, ``"silu"`` or ``"elu"``.

        Raises:
            ValueError: If any zone has fewer than one action, or the
                activation name is not supported.
        """
        # Fail early: a zone without any action cannot be sampled or encoded.
        if any(dim < 1 for dim in action_dims):
            raise ValueError(
                f"Every entry of action_dims must be >= 1, got {list(action_dims)}"
            )

        self.state_dim = state_dim
        self.action_dims = action_dims
        self.num_zones = len(action_dims)
        self.device = device
        self.learning_starts = learning_starts
        self.total_steps = 0
        self.seed = seed

        dummy_env = MultiDiscreteWrapper(state_dim, action_dims)
        policy_kwargs = {
            "net_arch": {
                "pi": _as_arch_list(actor_hidden_dims, [256, 256]),
                "qf": _as_arch_list(critic_hidden_dims, [256, 256]),
            },
            "activation_fn": _resolve_activation_fn(activation_fn),
            "features_extractor_class": FlattenExtractor,
            "log_std_init": float(log_std_init),
        }
        self.model = SAC(
            "MlpPolicy",
            env=dummy_env,
            seed=seed,
            learning_rate=learning_rate,
            buffer_size=buffer_size,
            learning_starts=learning_starts,
            batch_size=batch_size,
            tau=tau,
            gamma=gamma,
            ent_coef=ent_coef,
            target_entropy=target_entropy,
            target_update_interval=target_update_interval,
            device=device,
            verbose=0,
            policy_kwargs=policy_kwargs,
        )
        # Project helper; presumably attaches a logger so ``train`` can be
        # called without ``learn`` -- confirm in utils.sb3_manual.
        ensure_manual_logger(self.model)

    def _scaled_to_discrete(self, scaled_action: np.ndarray) -> np.ndarray:
        """Convert a ``[-1, 1]`` action vector into per-zone indices.

        Each component is mapped to ``[0, 1]`` (values outside the range are
        clipped), scaled to ``[0, d - 1]`` and rounded half-up.

        Returns:
            ``int64`` array with one index per zone, each in ``[0, d - 1]``.
        """
        normalized_action = np.clip((scaled_action + 1.0) * 0.5, 0.0, 1.0)
        discrete_action = np.array(
            [
                # Adding 0.5 before truncation rounds the non-negative value
                # to the nearest index.
                int(cont * (self.action_dims[i] - 1) + 0.5)
                for i, cont in enumerate(normalized_action)
            ],
            dtype=np.int64,
        )
        highest_index = [d - 1 for d in self.action_dims]
        return np.clip(discrete_action, 0, highest_index)

    def _discrete_to_scaled(self, action: np.ndarray) -> np.ndarray:
        """Convert per-zone indices into the ``[-1, 1]`` action vector.

        Index ``0`` maps to ``-1.0`` and index ``d - 1`` maps to ``1.0``. A
        zone with a single action (``d == 1``) maps its only index to
        ``-1.0`` rather than dividing by zero.

        Returns:
            ``float32`` array with one component per zone.
        """
        normalized_action = np.array(
            [
                # max(..., 1) keeps single-action zones finite (0 / 1 == 0).
                action[i] / max(d - 1, 1)
                for i, d in enumerate(self.action_dims)
            ],
            dtype=np.float32,
        )
        return normalized_action * 2.0 - 1.0

    def select_action(self, state: np.ndarray, deterministic: bool = False):
        """Pick a discrete action for ``state``.

        While fewer than ``learning_starts`` transitions have been stored and
        ``deterministic`` is false, each zone's index is drawn uniformly at
        random; otherwise the policy is queried.

        Returns:
            ``(discrete_action, 0.0, 0.0)``. The two floats are constant
            placeholders (presumably kept so the return shape matches other
            agents -- verify against callers).
        """
        in_warmup = self.total_steps < self.learning_starts
        if not deterministic and in_warmup:
            # One draw per zone, in zone order, from numpy's global generator.
            discrete_action = np.array(
                [np.random.randint(d) for d in self.action_dims],
                dtype=np.int64,
            )
            return discrete_action, 0.0, 0.0

        scaled_action, _ = self.model.predict(state, deterministic=deterministic)
        return self._scaled_to_discrete(scaled_action), 0.0, 0.0

    def store_transition(self, state, action, reward, next_state, done):
        """Record one transition in the model's replay buffer.

        ``action`` is the discrete per-zone action; it is stored in the
        ``[-1, 1]`` scale. The step counter is incremented and pushed to the
        model through ``sync_manual_timesteps``.
        """
        # Convert first so a malformed action does not advance the counters.
        scaled_action = self._discrete_to_scaled(action)
        self.total_steps += 1
        sync_manual_timesteps(self.model, self.total_steps)
        self.model.replay_buffer.add(
            state, next_state, scaled_action, reward, done, [{}]
        )

    def update(self):
        """Run one gradient step when the buffer holds at least one batch.

        Returns:
            ``{}`` when the buffer is still smaller than ``batch_size``,
            otherwise ``{"updates": <model update count as float>}``.
        """
        # NOTE(review): training is gated on buffer size only, not on
        # ``learning_starts``; with the defaults (batch 64, warm-up 100)
        # gradient steps begin during the random warm-up. Confirm intended.
        if self.model.replay_buffer.size() < self.model.batch_size:
            return {}
        self.model.train(gradient_steps=1, batch_size=self.model.batch_size)
        return {"updates": float(self.model._n_updates)}

    def save(self, path: str):
        """Save the wrapped SAC model to ``path`` via ``SAC.save``."""
        self.model.save(path)

    def load(self, path: str):
        """Replace the wrapped model with the one stored at ``path``.

        The step counter is restored from the loaded model's
        ``num_timesteps`` (``0`` if the attribute is missing).
        """
        # NOTE(review): replay-buffer contents are presumably not restored by
        # ``SAC.load`` -- verify if resuming training matters.
        self.model = SAC.load(path, device=self.device)
        ensure_manual_logger(self.model)
        self.total_steps = int(getattr(self.model, "num_timesteps", 0))