""" TD3 Agent using Stable-Baselines3 Adapted for MultiDiscrete VSL control. """ import numpy as np import torch from stable_baselines3 import TD3 from stable_baselines3.common.noise import NormalActionNoise import gymnasium as gym from gymnasium import spaces class MultiDiscreteWrapper(gym.Env): """Wrap a MultiDiscrete action space as a continuous Box for TD3.""" def __init__(self, state_dim, action_dims): super().__init__() self.state_dim = state_dim self.action_dims = action_dims self.num_zones = len(action_dims) self.observation_space = spaces.Box( low=-np.inf, high=np.inf, shape=(state_dim,), dtype=np.float32 ) self.action_space = spaces.Box( low=0.0, high=1.0, shape=(self.num_zones,), dtype=np.float32 ) def reset(self, seed=None, options=None): return np.zeros(self.state_dim, dtype=np.float32), {} def step(self, action): return np.zeros(self.state_dim, dtype=np.float32), 0.0, False, False, {} class TD3Agent: """TD3 agent wrapper.""" def __init__( self, state_dim: int, action_dims: list, learning_rate: float = 3e-4, buffer_size: int = 100000, learning_starts: int = 100, batch_size: int = 64, tau: float = 0.005, gamma: float = 0.99, policy_delay: int = 2, exploration_sigma: float = 0.1, device: str = "cuda", ): self.state_dim = state_dim self.action_dims = action_dims self.num_zones = len(action_dims) self.device = device self.learning_starts = learning_starts self.total_steps = 0 self.exploration_sigma = exploration_sigma dummy_env = MultiDiscreteWrapper(state_dim, action_dims) action_noise = NormalActionNoise( mean=np.zeros(self.num_zones), sigma=0.1 * np.ones(self.num_zones), ) self.model = TD3( "MlpPolicy", env=dummy_env, learning_rate=learning_rate, buffer_size=buffer_size, learning_starts=learning_starts, batch_size=batch_size, tau=tau, gamma=gamma, policy_delay=policy_delay, action_noise=action_noise, device=device, verbose=0, ) def select_action(self, state: np.ndarray, deterministic: bool = False): if not deterministic and self.total_steps < self.learning_starts: discrete_action = np.array( [np.random.randint(self.action_dims[i]) for i in range(self.num_zones)], dtype=np.int64, ) return discrete_action, 0.0, 0.0 continuous_action, _ = self.model.predict(state, deterministic=deterministic) if not deterministic: noise = np.random.normal(0.0, self.exploration_sigma, size=self.num_zones) continuous_action = np.clip(continuous_action + noise, 0.0, 1.0) discrete_action = np.array( [int(cont * (self.action_dims[i] - 1) + 0.5) for i, cont in enumerate(continuous_action)] ) discrete_action = np.clip(discrete_action, 0, [d - 1 for d in self.action_dims]) return discrete_action, 0.0, 0.0 def store_transition(self, state, action, reward, next_state, done): self.total_steps += 1 continuous_action = np.array( [action[i] / (self.action_dims[i] - 1) for i in range(self.num_zones)], dtype=np.float32, ) self.model.replay_buffer.add( state, next_state, continuous_action, reward, done, [{}] ) def update(self): if self.model.replay_buffer.size() < self.model.batch_size: return {} self.model.learn(total_timesteps=1, reset_num_timesteps=False, log_interval=None) return {"actor_loss": 0.0, "critic_loss": 0.0} def save(self, path: str): self.model.save(path) def load(self, path: str): self.model = TD3.load(path, device=self.device)