ctm-dqn/agents/d3pg_agent.py

179 lines
5.9 KiB
Python

"""
D3PG: a pragmatic DDPG/TD3 hybrid baseline for motorway VSL.
Design:
- keep deterministic actor and simple continuous-action proxy from DDPG
- keep twin critics and delayed actor updates from TD3
- disable target policy smoothing because actions are ultimately snapped to
discrete speed-limit levels, so smoothed target actions add mismatch
without much benefit
"""
from __future__ import annotations
from typing import List, Sequence
import gymnasium as gym
import numpy as np
import torch.nn as nn
from gymnasium import spaces
from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.torch_layers import FlattenExtractor
from utils.sb3_manual import ensure_manual_logger, sync_manual_timesteps
class MultiDiscreteWrapper(gym.Env):
    """Placeholder environment whose only job is to give SB3's TD3 its spaces.

    Despite the name, the action space is a continuous ``Box`` in ``[0, 1]``
    with one dimension per zone (``len(action_dims)``). ``reset``/``step``
    only return all-zero observations; no simulation is advanced here.
    """

    def __init__(self, state_dim: int, action_dims: Sequence[int]):
        super().__init__()
        self.state_dim = state_dim
        # Stored as given; the Box action space below does not encode the
        # per-zone level counts, only the number of zones.
        self.action_dims = action_dims
        self.num_zones = len(action_dims)
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(state_dim,),
            dtype=np.float32,
        )
        # One continuous proxy action per zone, bounded to [0, 1].
        self.action_space = spaces.Box(
            low=0.0,
            high=1.0,
            shape=(self.num_zones,),
            dtype=np.float32,
        )

    def reset(self, seed=None, options=None):
        """Return an all-zero observation and an empty info dict.

        ``seed`` is forwarded to ``gym.Env.reset`` so the environment's
        ``np_random`` generator is seeded, as the Gymnasium API expects of
        every ``reset`` implementation.
        """
        super().reset(seed=seed)
        return np.zeros(self.state_dim, dtype=np.float32), {}

    def step(self, action):
        """Ignore ``action``; return a zero observation, reward 0.0, not terminated/truncated."""
        return np.zeros(self.state_dim, dtype=np.float32), 0.0, False, False, {}
def _resolve_activation_fn(name: str):
    """Translate an activation name from config into a ``torch.nn`` module class.

    Matching ignores surrounding whitespace and letter case; a falsy ``name``
    (``None`` or ``""``) falls back to ``"relu"``.

    Raises:
        ValueError: if the normalized name is not ``relu``, ``silu`` or ``elu``.
    """
    lookup = {
        "relu": nn.ReLU,
        "silu": nn.SiLU,
        "elu": nn.ELU,
    }
    normalized = (name if name else "relu").strip().lower()
    activation_cls = lookup.get(normalized)
    if activation_cls is None:
        raise ValueError(f"Unsupported D3PG activation: {name}")
    return activation_cls
def _as_arch_list(value, default: List[int]) -> List[int]:
    """Normalize a hidden-layer spec into a new list of layer widths.

    A ``None`` *value* yields a shallow copy of *default* (entries untouched);
    any other iterable has each entry coerced with ``int``.
    """
    if value is not None:
        return list(map(int, value))
    return list(default)
class D3PGAgent:
"""Twin-critic delayed deterministic policy gradient without target smoothing."""
def __init__(
self,
state_dim: int,
action_dims: list,
seed: int | None = None,
learning_rate: float = 3e-4,
buffer_size: int = 100000,
learning_starts: int = 100,
batch_size: int = 64,
tau: float = 0.005,
gamma: float = 0.99,
policy_delay: int = 2,
exploration_sigma: float = 0.1,
device: str = "cuda",
actor_hidden_dims: Sequence[int] | None = None,
critic_hidden_dims: Sequence[int] | None = None,
activation_fn: str = "relu",
):
self.state_dim = state_dim
self.action_dims = action_dims
self.num_zones = len(action_dims)
self.device = device
self.learning_starts = learning_starts
self.total_steps = 0
self.exploration_sigma = exploration_sigma
self.seed = seed
dummy_env = MultiDiscreteWrapper(state_dim, action_dims)
action_noise = NormalActionNoise(
mean=np.zeros(self.num_zones),
sigma=float(exploration_sigma) * np.ones(self.num_zones),
)
policy_kwargs = {
"net_arch": {
"pi": _as_arch_list(actor_hidden_dims, [256, 256]),
"qf": _as_arch_list(critic_hidden_dims, [256, 256]),
},
"activation_fn": _resolve_activation_fn(activation_fn),
"features_extractor_class": FlattenExtractor,
}
self.model = TD3(
"MlpPolicy",
env=dummy_env,
seed=seed,
learning_rate=learning_rate,
buffer_size=buffer_size,
learning_starts=learning_starts,
batch_size=batch_size,
tau=tau,
gamma=gamma,
policy_delay=policy_delay,
target_policy_noise=0.0,
target_noise_clip=0.0,
action_noise=action_noise,
device=device,
verbose=0,
policy_kwargs=policy_kwargs,
)
ensure_manual_logger(self.model)
def select_action(self, state: np.ndarray, deterministic: bool = False):
if not deterministic and self.total_steps < self.learning_starts:
discrete_action = np.array(
[np.random.randint(self.action_dims[i]) for i in range(self.num_zones)],
dtype=np.int64,
)
return discrete_action, 0.0, 0.0
continuous_action, _ = self.model.predict(state, deterministic=deterministic)
if not deterministic:
noise = np.random.normal(0.0, self.exploration_sigma, size=self.num_zones)
continuous_action = np.clip(continuous_action + noise, 0.0, 1.0)
discrete_action = np.array(
[
int(cont * (self.action_dims[i] - 1) + 0.5)
for i, cont in enumerate(continuous_action)
]
)
discrete_action = np.clip(discrete_action, 0, [d - 1 for d in self.action_dims])
return discrete_action, 0.0, 0.0
def store_transition(self, state, action, reward, next_state, done):
self.total_steps += 1
sync_manual_timesteps(self.model, self.total_steps)
continuous_action = np.array(
[action[i] / (self.action_dims[i] - 1) for i in range(self.num_zones)],
dtype=np.float32,
)
self.model.replay_buffer.add(
state, next_state, continuous_action, reward, done, [{}]
)
def update(self):
if self.model.replay_buffer.size() < self.model.batch_size:
return {}
self.model.train(gradient_steps=1, batch_size=self.model.batch_size)
return {"updates": float(self.model._n_updates)}
def save(self, path: str):
self.model.save(path)
def load(self, path: str):
self.model = TD3.load(path, device=self.device)
ensure_manual_logger(self.model)
self.total_steps = int(getattr(self.model, "num_timesteps", 0))