ctm-dqn/agents/sctd3_agent.py

342 lines
12 KiB
Python

"""
Structured Corridor TD3 (SC-TD3) for motorway VSL control.
"""
from __future__ import annotations
from typing import List, Sequence
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from gymnasium import spaces
from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from utils.sb3_manual import ensure_manual_logger, sync_manual_timesteps
class MultiDiscreteWrapper(gym.Env):
    """Wrap a MultiDiscrete action space as a continuous Box for TD3.

    Placeholder environment: it only declares the spaces. Each of the
    ``len(action_dims)`` zones becomes one continuous action component in
    ``[0, 1]``. ``reset`` and ``step`` produce no real dynamics (all-zero
    observations, zero reward, never done).

    Args:
        state_dim: Length of the flat float32 observation vector.
        action_dims: Number of discrete levels per zone; only its length is
            used here, to size the action space.
    """

    def __init__(self, state_dim: int, action_dims: Sequence[int]):
        super().__init__()
        self.state_dim = state_dim
        self.action_dims = action_dims
        self.num_zones = len(action_dims)
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(state_dim,), dtype=np.float32
        )
        self.action_space = spaces.Box(
            low=0.0, high=1.0, shape=(self.num_zones,), dtype=np.float32
        )

    def reset(self, seed=None, options=None):
        """Return an all-zero observation and an empty info dict."""
        # Honour ``seed`` by seeding ``self.np_random`` as the Gymnasium
        # ``Env.reset`` contract requires of subclasses.
        super().reset(seed=seed)
        return np.zeros(self.state_dim, dtype=np.float32), {}

    def step(self, action):
        """Ignore ``action``; return zero obs, zero reward, not terminated/truncated."""
        return np.zeros(self.state_dim, dtype=np.float32), 0.0, False, False, {}
class EdgeLocalMixer(nn.Module):
    """Lightweight local mixer that preserves strong per-edge identity.

    Applies, with small residual weights:
      1. a pre-LayerNorm depthwise + pointwise 1D convolution along the edge
         axis (SiLU after the depthwise conv), then
      2. a pre-LayerNorm feed-forward network (hidden width ``2 * hidden_dim``).

    Input and output shape: ``(batch, num_edges, hidden_dim)``.

    Args:
        hidden_dim: Channel width of the per-edge latent.
        kernel_size: Convolution window along the edge axis (>= 1). Odd and
            even sizes are both supported; the edge count is preserved.
        residual_scale: Weight applied to each residual branch (default 0.25).

    Raises:
        ValueError: if ``kernel_size < 1``.
    """

    def __init__(self, hidden_dim: int, kernel_size: int = 3, residual_scale: float = 0.25):
        super().__init__()
        if kernel_size < 1:
            raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")
        self.residual_scale = float(residual_scale)
        self.norm = nn.LayerNorm(hidden_dim)
        # ``padding="same"`` keeps the sequence length for every kernel size.
        # A symmetric ``(k - 1) // 2`` pad drops one position for even k,
        # which would make the residual addition in ``forward`` fail.
        self.depthwise = nn.Conv1d(
            hidden_dim,
            hidden_dim,
            kernel_size=kernel_size,
            padding="same",
            groups=hidden_dim,
        )
        self.pointwise = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1)
        self.ffn = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.SiLU(),
            nn.Linear(hidden_dim * 2, hidden_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mix neighbouring edges; ``x`` is ``(batch, num_edges, hidden_dim)``."""
        # Conv1d expects (batch, channels, length): move edges to the last axis.
        y = self.norm(x).transpose(1, 2)
        y = F.silu(self.depthwise(y))
        y = self.pointwise(y).transpose(1, 2)
        # Keep spatial mixing weak to avoid washing out local bottlenecks.
        x = x + self.residual_scale * y
        x = x + self.residual_scale * self.ffn(x)
        return x
class EdgeStructuredExtractor(BaseFeaturesExtractor):
    """
    Feature extractor tailored to motorway VSL.

    State layout is assumed to be (one flat vector per sample):
      - per-edge traffic features: [speed_norm, occ_norm, flow_norm] * num_edges
      - per-edge current limit feature: [limit_norm] * num_edges
      - global features: [time_progress, sin_t, cos_t, last_reward]

    Only the sizes are enforced: the observation must have exactly
    ``num_edges * (edge_feature_dim + 1) + global_feature_dim`` elements.
    The meaning of each slot is a convention of the environment.

    Pipeline:
      1. Per edge, encode ``[traffic..., limit]`` and add (x0.5) an encoding
         of the traffic difference to the previous edge in the vector (zero
         for the first edge).
      2. Run ``spatial_blocks`` :class:`EdgeLocalMixer` layers over edges.
      3. Encode global features together with the mean and max of the raw
         per-edge inputs, and use that latent to apply a feature-wise affine
         modulation ``latent * (1 + tanh(gate)) + tanh(bias)`` to every edge.
      4. Concatenate the flattened edge latents, raw edge inputs, edge
         differences and the global latent, then project to ``features_dim``.

    Raises:
        ValueError: if the observation size does not match the layout above.
    """

    def __init__(
        self,
        observation_space: spaces.Box,
        num_edges: int,
        edge_feature_dim: int = 3,
        global_feature_dim: int = 4,
        edge_hidden_dim: int = 16,
        global_hidden_dim: int = 32,
        spatial_blocks: int = 1,
        kernel_size: int = 3,
        features_dim: int = 128,
    ):
        super().__init__(observation_space, features_dim)
        self.num_edges = num_edges
        self.edge_feature_dim = edge_feature_dim
        # Traffic features plus the single per-edge limit feature.
        self.edge_input_dim = edge_feature_dim + 1
        self.edge_hidden_dim = edge_hidden_dim
        self.global_feature_dim = global_feature_dim

        # Fail fast: ``forward`` slices the flat observation by these sizes,
        # and any mismatch would otherwise surface as an opaque shape error.
        expected_obs_dim = num_edges * self.edge_input_dim + global_feature_dim
        obs_dim = int(np.prod(observation_space.shape))
        if obs_dim != expected_obs_dim:
            raise ValueError(
                "EdgeStructuredExtractor expects an observation of size "
                f"{expected_obs_dim} (num_edges={num_edges} * "
                f"(edge_feature_dim={edge_feature_dim} + 1) + "
                f"global_feature_dim={global_feature_dim}), got {obs_dim}"
            )

        self.edge_encoder = nn.Sequential(
            nn.Linear(self.edge_input_dim, edge_hidden_dim),
            nn.LayerNorm(edge_hidden_dim),
            nn.SiLU(),
        )
        self.delta_encoder = nn.Sequential(
            nn.Linear(edge_feature_dim, edge_hidden_dim),
            nn.LayerNorm(edge_hidden_dim),
            nn.SiLU(),
        )
        self.local_mixers = nn.ModuleList(
            [
                EdgeLocalMixer(
                    hidden_dim=edge_hidden_dim,
                    kernel_size=kernel_size,
                )
                for _ in range(spatial_blocks)
            ]
        )
        # Input: global features + per-edge-input mean + per-edge-input max.
        self.global_encoder = nn.Sequential(
            nn.Linear(global_feature_dim + self.edge_input_dim * 2, global_hidden_dim),
            nn.LayerNorm(global_hidden_dim),
            nn.SiLU(),
        )
        self.edge_gate = nn.Linear(global_hidden_dim, edge_hidden_dim)
        self.edge_bias = nn.Linear(global_hidden_dim, edge_hidden_dim)

        raw_input_dim = num_edges * self.edge_input_dim
        delta_input_dim = num_edges * edge_feature_dim
        fused_input_dim = num_edges * edge_hidden_dim + raw_input_dim + delta_input_dim + global_hidden_dim
        self.output_proj = nn.Sequential(
            nn.Linear(fused_input_dim, features_dim),
            nn.LayerNorm(features_dim),
            nn.SiLU(),
        )
        self._init_weights()

    def _init_weights(self):
        """Orthogonal (gain sqrt 2) Linear weights, Kaiming-normal Conv1d weights, zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
                nn.init.constant_(module.bias, 0.0)
            elif isinstance(module, nn.Conv1d):
                nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Map ``(batch, obs_dim)`` observations to ``(batch, features_dim)`` features."""
        batch_size = observations.shape[0]
        edge_traffic_dim = self.num_edges * self.edge_feature_dim
        edge_limit_dim = self.num_edges
        # ``reshape`` (not ``view``) so non-contiguous inputs are accepted too.
        edge_traffic = observations[:, :edge_traffic_dim].reshape(
            batch_size, self.num_edges, self.edge_feature_dim
        )
        edge_limits = observations[
            :, edge_traffic_dim : edge_traffic_dim + edge_limit_dim
        ].reshape(batch_size, self.num_edges, 1)
        global_features = observations[:, edge_traffic_dim + edge_limit_dim :]

        # (batch, num_edges, edge_feature_dim + 1)
        edge_inputs = torch.cat([edge_traffic, edge_limits], dim=-1)
        edge_latent = self.edge_encoder(edge_inputs)

        # Difference to the previous edge along the edge axis; first edge stays 0.
        edge_deltas = torch.zeros_like(edge_traffic)
        edge_deltas[:, 1:, :] = edge_traffic[:, 1:, :] - edge_traffic[:, :-1, :]
        delta_latent = self.delta_encoder(edge_deltas)
        edge_latent = edge_latent + 0.5 * delta_latent

        for mixer in self.local_mixers:
            edge_latent = mixer(edge_latent)

        # Corridor-wide summary of the raw edge inputs feeds the global branch.
        edge_inputs_mean = edge_inputs.mean(dim=1)
        edge_inputs_max = edge_inputs.max(dim=1).values
        global_latent = self.global_encoder(
            torch.cat([global_features, edge_inputs_mean, edge_inputs_max], dim=-1)
        )

        # Globally conditioned feature-wise modulation, broadcast over edges.
        gate = torch.tanh(self.edge_gate(global_latent)).unsqueeze(1)
        bias = torch.tanh(self.edge_bias(global_latent)).unsqueeze(1)
        edge_latent = edge_latent * (1.0 + gate) + bias

        # Raw inputs and deltas are concatenated alongside the latents so the
        # projection sees unsmoothed per-edge values directly.
        fused = torch.cat(
            [
                edge_latent.reshape(batch_size, -1),
                edge_inputs.reshape(batch_size, -1),
                edge_deltas.reshape(batch_size, -1),
                global_latent,
            ],
            dim=-1,
        )
        return self.output_proj(fused)
def _resolve_activation_fn(name: str):
    """Translate an activation name into the matching ``torch.nn`` module class.

    Leading/trailing whitespace and letter case are ignored. A falsy ``name``
    (``None`` or ``""``) selects ReLU.

    Raises:
        ValueError: if the name is not ``relu``, ``silu`` or ``elu``.
    """
    registry = {
        "relu": nn.ReLU,
        "silu": nn.SiLU,
        "elu": nn.ELU,
    }
    normalized = (name if name else "relu").strip().lower()
    activation_cls = registry.get(normalized)
    if activation_cls is None:
        raise ValueError(f"Unsupported SC-TD3 activation: {name}")
    return activation_cls
def _as_arch_list(value, default: List[int]) -> List[int]:
    """Return a new list of layer widths.

    When ``value`` is given, every entry is coerced with ``int``. Otherwise a
    shallow copy of ``default`` is returned unchanged, so callers never alias
    the default list.
    """
    if value is not None:
        return list(map(int, value))
    return list(default)
class SCTD3Agent:
    """Structured Corridor TD3 agent wrapper.

    Wraps a Stable-Baselines3 ``TD3`` model. Its actor and critic each get
    their own :class:`EdgeStructuredExtractor` (``share_features_extractor``
    is False). The model acts in a continuous ``Box(0, 1, (num_zones,))``
    action space. This wrapper converts between that space and one integer
    level ``0 .. action_dims[i] - 1`` per zone.

    Training is driven manually through :meth:`select_action`,
    :meth:`store_transition` and :meth:`update` (which calls ``model.train``),
    rather than via ``TD3.learn``.

    Args:
        state_dim: Length of the flat observation vector.
        action_dims: Number of discrete levels for each zone. Its length is
            also passed to the extractor as ``num_edges``.
        learning_rate, buffer_size, learning_starts, batch_size, tau, gamma,
        policy_delay, device: Forwarded to ``TD3``. ``learning_starts`` also
            sets how many stored transitions use uniform random actions in
            :meth:`select_action`.
        exploration_sigma: Std-dev of the Gaussian noise added to actor
            actions (in ``[0, 1]`` units) when not acting deterministically.
        actor_hidden_dims, critic_hidden_dims: MLP widths after the
            extractor; they default to ``[128, 64]`` and ``[192, 128]``.
        edge_feature_dim, extractor_*: Forwarded to
            :class:`EdgeStructuredExtractor`.
        activation_fn: ``"relu"``, ``"silu"`` or ``"elu"``.
    """

    def __init__(
        self,
        state_dim: int,
        action_dims: list,
        learning_rate: float = 3e-4,
        buffer_size: int = 100000,
        learning_starts: int = 100,
        batch_size: int = 64,
        tau: float = 0.005,
        gamma: float = 0.99,
        policy_delay: int = 2,
        exploration_sigma: float = 0.1,
        device: str = "cuda",
        actor_hidden_dims: Sequence[int] | None = None,
        critic_hidden_dims: Sequence[int] | None = None,
        edge_feature_dim: int = 3,
        extractor_feature_dim: int = 128,
        extractor_edge_hidden_dim: int = 16,
        extractor_global_hidden_dim: int = 32,
        extractor_spatial_blocks: int = 1,
        extractor_kernel_size: int = 3,
        activation_fn: str = "silu",
    ):
        self.state_dim = state_dim
        self.action_dims = action_dims
        self.num_zones = len(action_dims)
        self.device = device
        self.learning_starts = learning_starts
        self.total_steps = 0
        self.exploration_sigma = exploration_sigma
        dummy_env = MultiDiscreteWrapper(state_dim, action_dims)
        # Driven by ``exploration_sigma`` rather than a hard-coded 0.1. SB3 only
        # consults this noise inside ``TD3.learn``; note that it is applied in
        # SB3's rescaled [-1, 1] action space.
        action_noise = NormalActionNoise(
            mean=np.zeros(self.num_zones),
            sigma=exploration_sigma * np.ones(self.num_zones),
        )
        policy_kwargs = {
            "net_arch": {
                "pi": _as_arch_list(actor_hidden_dims, [128, 64]),
                "qf": _as_arch_list(critic_hidden_dims, [192, 128]),
            },
            "activation_fn": _resolve_activation_fn(activation_fn),
            "features_extractor_class": EdgeStructuredExtractor,
            "features_extractor_kwargs": {
                "num_edges": self.num_zones,
                "edge_feature_dim": edge_feature_dim,
                "global_feature_dim": 4,
                "edge_hidden_dim": extractor_edge_hidden_dim,
                "global_hidden_dim": extractor_global_hidden_dim,
                "spatial_blocks": extractor_spatial_blocks,
                "kernel_size": extractor_kernel_size,
                "features_dim": extractor_feature_dim,
            },
            "share_features_extractor": False,
        }
        self.model = TD3(
            "MlpPolicy",
            env=dummy_env,
            learning_rate=learning_rate,
            buffer_size=buffer_size,
            learning_starts=learning_starts,
            batch_size=batch_size,
            tau=tau,
            gamma=gamma,
            policy_delay=policy_delay,
            action_noise=action_noise,
            device=device,
            verbose=0,
            policy_kwargs=policy_kwargs,
        )
        ensure_manual_logger(self.model)

    def select_action(self, state: np.ndarray, deterministic: bool = False):
        """Choose one discrete level per zone.

        While fewer than ``learning_starts`` transitions have been stored (and
        ``deterministic`` is False), levels are drawn uniformly at random.
        Otherwise the model's action in ``[0, 1]`` is used. When not
        deterministic, it is first perturbed with Gaussian noise of std
        ``exploration_sigma`` and clipped to ``[0, 1]``. The action is then
        rounded to the nearest level.

        Returns:
            ``(discrete_action, 0.0, 0.0)``. The two trailing zeros are
            constant placeholders (presumably a shared agent interface's
            log-prob/value slots; confirm against callers).
        """
        if not deterministic and self.total_steps < self.learning_starts:
            # Warm-up: uniform random level per zone.
            discrete_action = np.array(
                [np.random.randint(self.action_dims[i]) for i in range(self.num_zones)],
                dtype=np.int64,
            )
            return discrete_action, 0.0, 0.0
        # SB3's predict() returns the actor output unscaled into the Box(0, 1) space.
        continuous_action, _ = self.model.predict(state, deterministic=deterministic)
        if not deterministic:
            noise = np.random.normal(0.0, self.exploration_sigma, size=self.num_zones)
            continuous_action = np.clip(continuous_action + noise, 0.0, 1.0)
        # Round to the nearest level (halves round up); the clip bounds the result.
        discrete_action = np.array(
            [
                int(cont * (self.action_dims[i] - 1) + 0.5)
                for i, cont in enumerate(continuous_action)
            ]
        )
        discrete_action = np.clip(discrete_action, 0, [d - 1 for d in self.action_dims])
        return discrete_action, 0.0, 0.0

    def store_transition(self, state, action, reward, next_state, done):
        """Add one transition to the TD3 replay buffer and advance the step count.

        ``action`` holds one discrete level per zone. It is mapped back to the
        env-facing ``[0, 1]`` range and then rescaled to ``[-1, 1]`` with
        ``policy.scale_action``. SB3's TD3 actor emits tanh-squashed actions in
        ``[-1, 1]``, and ``TD3.train`` feeds buffer actions straight into the
        critic. Storing un-rescaled ``[0, 1]`` actions would make the critic
        learn Q over a different action space than the one the actor loss and
        target use.
        """
        self.total_steps += 1
        sync_manual_timesteps(self.model, self.total_steps)
        # ``max(..., 1)`` avoids dividing by zero for zones with a single level.
        unit_action = np.array(
            [action[i] / max(self.action_dims[i] - 1, 1) for i in range(self.num_zones)],
            dtype=np.float32,
        )
        buffer_action = np.clip(self.model.policy.scale_action(unit_action), -1.0, 1.0)
        self.model.replay_buffer.add(
            state, next_state, buffer_action, reward, done, [{}]
        )

    def update(self):
        """Run one TD3 gradient step once the buffer holds at least one batch.

        Returns:
            ``{}`` if the buffer is still too small. Otherwise
            ``{"updates": <total number of updates so far, as float>}``.
        """
        if self.model.replay_buffer.size() < self.model.batch_size:
            return {}
        self.model.train(gradient_steps=1, batch_size=self.model.batch_size)
        return {"updates": float(self.model._n_updates)}

    def save(self, path: str):
        """Save the underlying TD3 model to ``path``."""
        self.model.save(path)

    def load(self, path: str):
        """Replace the model with one loaded from ``path`` and restore the step count.

        NOTE(review): SB3 does not persist the replay buffer in ``save``, so the
        loaded model starts with an empty buffer.
        """
        self.model = TD3.load(path, device=self.device)
        ensure_manual_logger(self.model)
        self.total_steps = int(getattr(self.model, "num_timesteps", 0))