"""DDPG agent using Stable-Baselines3 for VSL control.""" from __future__ import annotations from typing import List, Sequence import gymnasium as gym import numpy as np import torch.nn as nn from gymnasium import spaces from stable_baselines3 import DDPG from stable_baselines3.common.noise import NormalActionNoise from stable_baselines3.common.torch_layers import FlattenExtractor from utils.sb3_manual import ensure_manual_logger, sync_manual_timesteps class MultiDiscreteWrapper(gym.Env): def __init__(self, state_dim, action_dims): super().__init__() self.state_dim = state_dim self.action_dims = action_dims self.num_zones = len(action_dims) self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(state_dim,), dtype=np.float32) self.action_space = spaces.Box(low=0.0, high=1.0, shape=(self.num_zones,), dtype=np.float32) def reset(self, seed=None, options=None): return np.zeros(self.state_dim, dtype=np.float32), {} def step(self, action): return np.zeros(self.state_dim, dtype=np.float32), 0.0, False, False, {} def _resolve_activation_fn(name: str): key = (name or "relu").strip().lower() if key == "relu": return nn.ReLU if key == "silu": return nn.SiLU if key == "elu": return nn.ELU raise ValueError(f"Unsupported DDPG activation: {name}") def _as_arch_list(value, default: List[int]) -> List[int]: if value is None: return list(default) return [int(v) for v in value] class DDPGAgent: """DDPG agent wrapper.""" def __init__( self, state_dim: int, action_dims: list, learning_rate: float = 3e-4, buffer_size: int = 100000, learning_starts: int = 100, batch_size: int = 64, tau: float = 0.005, gamma: float = 0.99, exploration_sigma: float = 0.1, device: str = "cuda", actor_hidden_dims: Sequence[int] | None = None, critic_hidden_dims: Sequence[int] | None = None, activation_fn: str = "relu", ): self.action_dims = action_dims self.num_zones = len(action_dims) self.device = device self.learning_starts = learning_starts self.total_steps = 0 self.exploration_sigma = exploration_sigma dummy_env = MultiDiscreteWrapper(state_dim, action_dims) action_noise = NormalActionNoise( mean=np.zeros(self.num_zones), sigma=float(exploration_sigma) * np.ones(self.num_zones), ) policy_kwargs = { "net_arch": { "pi": _as_arch_list(actor_hidden_dims, [256, 256]), "qf": _as_arch_list(critic_hidden_dims, [256, 256]), }, "activation_fn": _resolve_activation_fn(activation_fn), "features_extractor_class": FlattenExtractor, } self.model = DDPG( "MlpPolicy", env=dummy_env, learning_rate=learning_rate, buffer_size=buffer_size, learning_starts=learning_starts, batch_size=batch_size, tau=tau, gamma=gamma, action_noise=action_noise, device=device, verbose=0, policy_kwargs=policy_kwargs, ) ensure_manual_logger(self.model) def select_action(self, state: np.ndarray, deterministic: bool = False): if not deterministic and self.total_steps < self.learning_starts: discrete_action = np.array([ np.random.randint(self.action_dims[i]) for i in range(self.num_zones) ], dtype=np.int64) return discrete_action, 0.0, 0.0 continuous_action, _ = self.model.predict(state, deterministic=deterministic) if not deterministic: noise = np.random.normal(0.0, self.exploration_sigma, size=self.num_zones) continuous_action = np.clip(continuous_action + noise, 0.0, 1.0) discrete_action = np.array([ int(cont * (self.action_dims[i] - 1) + 0.5) for i, cont in enumerate(continuous_action) ]) discrete_action = np.clip(discrete_action, 0, [d - 1 for d in self.action_dims]) return discrete_action, 0.0, 0.0 def store_transition(self, state, action, reward, next_state, done): self.total_steps += 1 sync_manual_timesteps(self.model, self.total_steps) continuous_action = np.array([ action[i] / (self.action_dims[i] - 1) for i in range(self.num_zones) ], dtype=np.float32) self.model.replay_buffer.add(state, next_state, continuous_action, reward, done, [{}]) def update(self): if self.model.replay_buffer.size() < self.model.batch_size: return {} self.model.train(gradient_steps=1, batch_size=self.model.batch_size) return {"updates": float(self.model._n_updates)} def save(self, path: str): self.model.save(path) def load(self, path: str): self.model = DDPG.load(path, device=self.device) ensure_manual_logger(self.model) self.total_steps = int(getattr(self.model, "num_timesteps", 0))