ctm-dqn/agents/ddpg_agent.py

146 lines
5.1 KiB
Python

"""DDPG agent using Stable-Baselines3 for VSL control."""
from __future__ import annotations
from typing import List, Sequence
import gymnasium as gym
import numpy as np
import torch.nn as nn
from gymnasium import spaces
from stable_baselines3 import DDPG
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.torch_layers import FlattenExtractor
from utils.sb3_manual import ensure_manual_logger, sync_manual_timesteps
class MultiDiscreteWrapper(gym.Env):
    """Placeholder environment that only declares observation/action spaces.

    The observation space is an unbounded float32 vector of length
    ``state_dim``; the action space is a float32 box in ``[0, 1]`` with one
    component per entry of ``action_dims``. ``reset`` and ``step`` return
    constant values and do not simulate anything.
    """

    def __init__(self, state_dim, action_dims):
        super().__init__()
        self.state_dim = state_dim
        self.action_dims = action_dims
        self.num_zones = len(action_dims)

        obs_shape = (state_dim,)
        act_shape = (self.num_zones,)
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=obs_shape, dtype=np.float32
        )
        # One continuous component in [0, 1] per zone.
        self.action_space = spaces.Box(
            low=0.0, high=1.0, shape=act_shape, dtype=np.float32
        )

    def _blank_observation(self):
        """Return a fresh all-zero float32 observation of length ``state_dim``."""
        return np.zeros(self.state_dim, dtype=np.float32)

    def reset(self, seed=None, options=None):
        """Return an all-zero observation and an empty info dict.

        ``seed`` and ``options`` are accepted for API compatibility and ignored.
        """
        info = {}
        return self._blank_observation(), info

    def step(self, action):
        """Ignore ``action`` and return a constant transition.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` with a zero
            observation, reward ``0.0``, both flags ``False`` and an empty info
            dict.
        """
        reward = 0.0
        terminated = False
        truncated = False
        return self._blank_observation(), reward, terminated, truncated, {}
def _resolve_activation_fn(name: str):
    """Translate an activation name into the matching ``torch.nn`` module class.

    The lookup ignores case and surrounding whitespace. A falsy ``name``
    (``None`` or an empty string) is treated as ``"relu"``.

    Args:
        name: One of ``"relu"``, ``"silu"`` or ``"elu"``.

    Returns:
        The activation class itself (not an instance).

    Raises:
        ValueError: If the name is not one of the supported activations.
    """
    supported = {
        "relu": nn.ReLU,
        "silu": nn.SiLU,
        "elu": nn.ELU,
    }
    normalized = (name or "relu").strip().lower()
    activation_cls = supported.get(normalized)
    if activation_cls is None:
        raise ValueError(f"Unsupported DDPG activation: {name}")
    return activation_cls
def _as_arch_list(value, default: List[int]) -> List[int]:
    """Return hidden-layer sizes as a new list of ints.

    Args:
        value: Iterable of layer sizes, or ``None`` to use ``default``.
        default: Layer sizes used when ``value`` is ``None``.

    Returns:
        A copy of ``default`` when ``value`` is ``None``; otherwise a list
        with every element of ``value`` passed through ``int``.
    """
    if value is not None:
        return list(map(int, value))
    # Copy so callers can never mutate the shared default.
    return list(default)
class DDPGAgent:
"""DDPG agent wrapper."""
def __init__(
self,
state_dim: int,
action_dims: list,
learning_rate: float = 3e-4,
buffer_size: int = 100000,
learning_starts: int = 100,
batch_size: int = 64,
tau: float = 0.005,
gamma: float = 0.99,
exploration_sigma: float = 0.1,
device: str = "cuda",
actor_hidden_dims: Sequence[int] | None = None,
critic_hidden_dims: Sequence[int] | None = None,
activation_fn: str = "relu",
):
self.action_dims = action_dims
self.num_zones = len(action_dims)
self.device = device
self.learning_starts = learning_starts
self.total_steps = 0
self.exploration_sigma = exploration_sigma
dummy_env = MultiDiscreteWrapper(state_dim, action_dims)
action_noise = NormalActionNoise(
mean=np.zeros(self.num_zones),
sigma=float(exploration_sigma) * np.ones(self.num_zones),
)
policy_kwargs = {
"net_arch": {
"pi": _as_arch_list(actor_hidden_dims, [256, 256]),
"qf": _as_arch_list(critic_hidden_dims, [256, 256]),
},
"activation_fn": _resolve_activation_fn(activation_fn),
"features_extractor_class": FlattenExtractor,
}
self.model = DDPG(
"MlpPolicy",
env=dummy_env,
learning_rate=learning_rate,
buffer_size=buffer_size,
learning_starts=learning_starts,
batch_size=batch_size,
tau=tau,
gamma=gamma,
action_noise=action_noise,
device=device,
verbose=0,
policy_kwargs=policy_kwargs,
)
ensure_manual_logger(self.model)
def select_action(self, state: np.ndarray, deterministic: bool = False):
if not deterministic and self.total_steps < self.learning_starts:
discrete_action = np.array([
np.random.randint(self.action_dims[i]) for i in range(self.num_zones)
], dtype=np.int64)
return discrete_action, 0.0, 0.0
continuous_action, _ = self.model.predict(state, deterministic=deterministic)
if not deterministic:
noise = np.random.normal(0.0, self.exploration_sigma, size=self.num_zones)
continuous_action = np.clip(continuous_action + noise, 0.0, 1.0)
discrete_action = np.array([
int(cont * (self.action_dims[i] - 1) + 0.5)
for i, cont in enumerate(continuous_action)
])
discrete_action = np.clip(discrete_action, 0, [d - 1 for d in self.action_dims])
return discrete_action, 0.0, 0.0
def store_transition(self, state, action, reward, next_state, done):
self.total_steps += 1
sync_manual_timesteps(self.model, self.total_steps)
continuous_action = np.array([
action[i] / (self.action_dims[i] - 1)
for i in range(self.num_zones)
], dtype=np.float32)
self.model.replay_buffer.add(state, next_state, continuous_action, reward, done, [{}])
def update(self):
if self.model.replay_buffer.size() < self.model.batch_size:
return {}
self.model.train(gradient_steps=1, batch_size=self.model.batch_size)
return {"updates": float(self.model._n_updates)}
def save(self, path: str):
self.model.save(path)
def load(self, path: str):
self.model = DDPG.load(path, device=self.device)
ensure_manual_logger(self.model)
self.total_steps = int(getattr(self.model, "num_timesteps", 0))