"""
Structured Corridor TD3 (SC-TD3) for motorway VSL control.
"""

from __future__ import annotations

from typing import List, Sequence

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from gymnasium import spaces
from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

from utils.sb3_manual import ensure_manual_logger, sync_manual_timesteps


class MultiDiscreteWrapper(gym.Env):
    """Wrap a MultiDiscrete action space as a continuous Box for TD3.

    This is a placeholder environment: it only exposes the observation and
    action spaces that TD3 needs at construction time. ``reset`` and ``step``
    return all-zero observations and never simulate anything.

    Args:
        state_dim: Length of the flat observation vector.
        action_dims: Number of discrete options per zone; one entry per zone.
    """

    def __init__(self, state_dim: int, action_dims: Sequence[int]):
        super().__init__()
        self.state_dim = state_dim
        self.action_dims = action_dims
        self.num_zones = len(action_dims)

        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(state_dim,), dtype=np.float32
        )
        # One continuous value in [0, 1] per zone.
        self.action_space = spaces.Box(
            low=0.0, high=1.0, shape=(self.num_zones,), dtype=np.float32
        )

    def reset(self, seed=None, options=None):
        """Return an all-zero observation and an empty info dict."""
        # Forward the seed so the Gymnasium seeding contract is honoured.
        super().reset(seed=seed)
        return np.zeros(self.state_dim, dtype=np.float32), {}

    def step(self, action):
        """Return a zero observation, zero reward and no termination."""
        return np.zeros(self.state_dim, dtype=np.float32), 0.0, False, False, {}


class EdgeLocalMixer(nn.Module):
    """Lightweight local mixer that preserves strong per-edge identity.

    Applies a depthwise + pointwise 1-D convolution along the edge axis and a
    position-wise feed-forward network, each added back as a residual scaled
    by 0.25.

    Args:
        hidden_dim: Size of the per-edge latent vector (last input dimension).
        kernel_size: Width of the depthwise convolution along the edge axis.
    """

    def __init__(self, hidden_dim: int, kernel_size: int = 3):
        super().__init__()
        padding = (kernel_size - 1) // 2
        self.norm = nn.LayerNorm(hidden_dim)
        self.depthwise = nn.Conv1d(
            hidden_dim,
            hidden_dim,
            kernel_size=kernel_size,
            padding=padding,
            groups=hidden_dim,
        )
        self.pointwise = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1)
        self.ffn = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.SiLU(),
            nn.Linear(hidden_dim * 2, hidden_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mix a ``(batch, edges, hidden_dim)`` tensor; the shape is kept
        when ``kernel_size`` is odd."""
        # Conv1d expects channels on dim 1, so swap the edge and hidden axes.
        y = self.norm(x).transpose(1, 2)
        y = F.silu(self.depthwise(y))
        y = self.pointwise(y).transpose(1, 2)
        # Keep spatial mixing weak to avoid washing out local bottlenecks.
        x = x + 0.25 * y
        x = x + 0.25 * self.ffn(x)
        return x


class EdgeStructuredExtractor(BaseFeaturesExtractor):
    """
    Feature extractor tailored to motorway VSL.

    State layout is assumed to be:
    - per-edge traffic features: [speed_norm, occ_norm, flow_norm] * num_edges
    - per-edge current limit feature: [limit_norm] * num_edges
    - global features: [time_progress, sin_t, cos_t, last_reward]

    The per-edge sections span ``total_edge_count`` edges; only the slice
    ``[controlled_start_index, controlled_start_index + num_edges)`` is
    encoded.

    Args:
        observation_space: Flat Box observation space.
        num_edges: Number of controlled edges that are encoded.
        edge_feature_dim: Traffic features per edge (excluding the limit).
        global_feature_dim: Number of trailing global features.
        total_edge_count: Edges present in the observation; defaults to
            ``num_edges``.
        controlled_start_index: Index of the first controlled edge.
        edge_hidden_dim: Width of the per-edge latent.
        global_hidden_dim: Width of the global latent.
        spatial_blocks: Number of stacked ``EdgeLocalMixer`` blocks.
        kernel_size: Convolution width used by each mixer.
        features_dim: Size of the output feature vector.

    Raises:
        ValueError: If the controlled slice is empty, starts at a negative
            index, exceeds ``total_edge_count``, or the observation size does
            not match the assumed layout.
    """

    def __init__(
        self,
        observation_space: spaces.Box,
        num_edges: int,
        edge_feature_dim: int = 3,
        global_feature_dim: int = 4,
        total_edge_count: int | None = None,
        controlled_start_index: int = 0,
        edge_hidden_dim: int = 16,
        global_hidden_dim: int = 32,
        spatial_blocks: int = 1,
        kernel_size: int = 3,
        features_dim: int = 128,
    ):
        super().__init__(observation_space, features_dim)
        self.num_edges = num_edges
        self.total_edge_count = int(
            total_edge_count if total_edge_count is not None else num_edges
        )
        self.controlled_start_index = int(controlled_start_index)
        self.controlled_end_index = self.controlled_start_index + self.num_edges

        if self.num_edges < 1:
            raise ValueError("num_edges must be at least 1")
        # A negative start would silently select the wrong slice in forward().
        if self.controlled_start_index < 0:
            raise ValueError("controlled_start_index must be non-negative")
        if self.controlled_end_index > self.total_edge_count:
            raise ValueError("controlled action slice exceeds total edge count")

        self.edge_feature_dim = edge_feature_dim
        self.edge_input_dim = edge_feature_dim + 1  # traffic features + limit
        self.edge_hidden_dim = edge_hidden_dim
        self.global_feature_dim = global_feature_dim

        # Fail early with a readable message instead of an opaque shape error
        # inside forward() when the observation does not match the layout.
        expected_obs_dim = (
            self.total_edge_count * self.edge_input_dim + global_feature_dim
        )
        actual_obs_dim = int(np.prod(observation_space.shape))
        if actual_obs_dim != expected_obs_dim:
            raise ValueError(
                "observation size does not match the assumed state layout: "
                f"expected {expected_obs_dim}, got {actual_obs_dim}"
            )

        self.edge_encoder = nn.Sequential(
            nn.Linear(self.edge_input_dim, edge_hidden_dim),
            nn.LayerNorm(edge_hidden_dim),
            nn.SiLU(),
        )
        self.delta_encoder = nn.Sequential(
            nn.Linear(edge_feature_dim, edge_hidden_dim),
            nn.LayerNorm(edge_hidden_dim),
            nn.SiLU(),
        )
        self.local_mixers = nn.ModuleList(
            [
                EdgeLocalMixer(
                    hidden_dim=edge_hidden_dim,
                    kernel_size=kernel_size,
                )
                for _ in range(spatial_blocks)
            ]
        )
        # Global branch sees the global features plus mean and max edge inputs.
        self.global_encoder = nn.Sequential(
            nn.Linear(global_feature_dim + self.edge_input_dim * 2, global_hidden_dim),
            nn.LayerNorm(global_hidden_dim),
            nn.SiLU(),
        )
        self.edge_gate = nn.Linear(global_hidden_dim, edge_hidden_dim)
        self.edge_bias = nn.Linear(global_hidden_dim, edge_hidden_dim)

        raw_input_dim = num_edges * self.edge_input_dim
        delta_input_dim = num_edges * edge_feature_dim
        fused_input_dim = (
            num_edges * edge_hidden_dim
            + raw_input_dim
            + delta_input_dim
            + global_hidden_dim
        )
        self.output_proj = nn.Sequential(
            nn.Linear(fused_input_dim, features_dim),
            nn.LayerNorm(features_dim),
            nn.SiLU(),
        )
        self._init_weights()

    def _init_weights(self):
        """Orthogonal init for Linear layers, Kaiming-normal for Conv1d."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
                nn.init.constant_(module.bias, 0.0)
            elif isinstance(module, nn.Conv1d):
                nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Encode ``(batch, obs_dim)`` observations into ``(batch, features_dim)``."""
        batch_size = observations.shape[0]
        edge_traffic_dim = self.total_edge_count * self.edge_feature_dim
        edge_limit_dim = self.total_edge_count
        start, end = self.controlled_start_index, self.controlled_end_index

        # Split the flat observation into its three sections and keep only
        # the controlled edges from the per-edge sections.
        edge_traffic = observations[:, :edge_traffic_dim].view(
            batch_size, self.total_edge_count, self.edge_feature_dim
        )
        edge_traffic = edge_traffic[:, start:end, :]
        edge_limits = observations[
            :, edge_traffic_dim : edge_traffic_dim + edge_limit_dim
        ].view(batch_size, self.total_edge_count, 1)
        edge_limits = edge_limits[:, start:end, :]
        global_features = observations[:, edge_traffic_dim + edge_limit_dim :]

        edge_inputs = torch.cat([edge_traffic, edge_limits], dim=-1)
        edge_latent = self.edge_encoder(edge_inputs)

        # Difference of each edge's traffic features to the previous edge in
        # the slice; the first edge has no predecessor and keeps zeros.
        edge_deltas = torch.zeros_like(edge_traffic)
        edge_deltas[:, 1:, :] = edge_traffic[:, 1:, :] - edge_traffic[:, :-1, :]
        delta_latent = self.delta_encoder(edge_deltas)
        edge_latent = edge_latent + 0.5 * delta_latent

        for mixer in self.local_mixers:
            edge_latent = mixer(edge_latent)

        edge_inputs_mean = edge_inputs.mean(dim=1)
        edge_inputs_max = edge_inputs.max(dim=1).values
        global_latent = self.global_encoder(
            torch.cat([global_features, edge_inputs_mean, edge_inputs_max], dim=-1)
        )

        # Modulate every edge latent with a gate and bias derived from the
        # global context; tanh bounds the multiplier to (0, 2).
        gate = torch.tanh(self.edge_gate(global_latent)).unsqueeze(1)
        bias = torch.tanh(self.edge_bias(global_latent)).unsqueeze(1)
        edge_latent = edge_latent * (1.0 + gate) + bias

        fused = torch.cat(
            [
                edge_latent.reshape(batch_size, -1),
                edge_inputs.reshape(batch_size, -1),
                edge_deltas.reshape(batch_size, -1),
                global_latent,
            ],
            dim=-1,
        )
        return self.output_proj(fused)


def _resolve_activation_fn(name: str):
    """Map an activation name to its ``torch.nn`` class.

    The name is matched case-insensitively after stripping whitespace; an
    empty or ``None`` name falls back to ReLU.

    Raises:
        ValueError: If the name is not one of ``relu``, ``silu`` or ``elu``.
    """
    activations = {"relu": nn.ReLU, "silu": nn.SiLU, "elu": nn.ELU}
    key = (name or "relu").strip().lower()
    try:
        return activations[key]
    except KeyError:
        raise ValueError(f"Unsupported SC-TD3 activation: {name}") from None


def _as_arch_list(value, default: List[int]) -> List[int]:
    """Return ``value`` as a list of ints, or a copy of ``default`` if None."""
    if value is None:
        return list(default)
    return [int(v) for v in value]


class SCTD3Agent:
    """Structured Corridor TD3 agent wrapper.

    Wraps a Stable-Baselines3 ``TD3`` model that is driven manually through
    ``select_action`` / ``store_transition`` / ``update`` instead of
    ``TD3.learn``. The policy acts in a continuous [0, 1] space per zone,
    which is mapped to and from the discrete per-zone options.

    Args:
        state_dim: Length of the flat observation vector.
        action_dims: Number of discrete options per zone.
        learning_rate, buffer_size, learning_starts, batch_size, tau, gamma,
        policy_delay: Passed through to ``TD3``.
        exploration_sigma: Std-dev of the Gaussian exploration noise.
        device: Torch device string passed to ``TD3``.
        actor_hidden_dims: Actor MLP widths; defaults to ``[256, 256]``.
        critic_hidden_dims: Critic MLP widths; defaults to ``[256, 256]``.
        edge_feature_dim, total_edge_count, controlled_start_index,
        extractor_*: Passed through to ``EdgeStructuredExtractor``.
        activation_fn: Name of the MLP activation (``relu``/``silu``/``elu``).
    """

    def __init__(
        self,
        state_dim: int,
        action_dims: list,
        learning_rate: float = 3e-4,
        buffer_size: int = 100000,
        learning_starts: int = 100,
        batch_size: int = 64,
        tau: float = 0.005,
        gamma: float = 0.99,
        policy_delay: int = 2,
        exploration_sigma: float = 0.1,
        device: str = "cuda",
        actor_hidden_dims: Sequence[int] | None = None,
        critic_hidden_dims: Sequence[int] | None = None,
        edge_feature_dim: int = 3,
        total_edge_count: int | None = None,
        controlled_start_index: int = 0,
        extractor_feature_dim: int = 128,
        extractor_edge_hidden_dim: int = 16,
        extractor_global_hidden_dim: int = 32,
        extractor_spatial_blocks: int = 1,
        extractor_kernel_size: int = 3,
        activation_fn: str = "relu",
    ):
        self.state_dim = state_dim
        self.action_dims = action_dims
        self.num_zones = len(action_dims)
        self.device = device
        self.learning_starts = learning_starts
        self.total_steps = 0
        self.exploration_sigma = exploration_sigma

        # TD3 only needs the spaces from this environment; it is never stepped
        # by this class.
        dummy_env = MultiDiscreteWrapper(state_dim, action_dims)
        action_noise = NormalActionNoise(
            mean=np.zeros(self.num_zones),
            sigma=float(exploration_sigma) * np.ones(self.num_zones),
        )

        policy_kwargs = {
            "net_arch": {
                "pi": _as_arch_list(actor_hidden_dims, [256, 256]),
                "qf": _as_arch_list(critic_hidden_dims, [256, 256]),
            },
            "activation_fn": _resolve_activation_fn(activation_fn),
            "features_extractor_class": EdgeStructuredExtractor,
            "features_extractor_kwargs": {
                "num_edges": self.num_zones,
                "edge_feature_dim": edge_feature_dim,
                "global_feature_dim": 4,
                "total_edge_count": (
                    total_edge_count
                    if total_edge_count is not None
                    else self.num_zones
                ),
                "controlled_start_index": controlled_start_index,
                "edge_hidden_dim": extractor_edge_hidden_dim,
                "global_hidden_dim": extractor_global_hidden_dim,
                "spatial_blocks": extractor_spatial_blocks,
                "kernel_size": extractor_kernel_size,
                "features_dim": extractor_feature_dim,
            },
            "share_features_extractor": False,
        }

        self.model = TD3(
            "MlpPolicy",
            env=dummy_env,
            learning_rate=learning_rate,
            buffer_size=buffer_size,
            learning_starts=learning_starts,
            batch_size=batch_size,
            tau=tau,
            gamma=gamma,
            policy_delay=policy_delay,
            action_noise=action_noise,
            device=device,
            verbose=0,
            policy_kwargs=policy_kwargs,
        )
        ensure_manual_logger(self.model)

    def select_action(self, state: np.ndarray, deterministic: bool = False):
        """Pick one discrete option per zone for ``state``.

        While fewer than ``learning_starts`` transitions have been stored and
        ``deterministic`` is False, options are drawn uniformly at random.
        Otherwise the policy output (plus Gaussian noise when not
        deterministic) is rounded to the nearest option.

        Returns:
            ``(discrete_action, 0.0, 0.0)``; the two trailing values are
            constant placeholders (presumably kept for a shared agent
            interface; confirm against callers).
        """
        if not deterministic and self.total_steps < self.learning_starts:
            random_action = np.array(
                [np.random.randint(dim) for dim in self.action_dims],
                dtype=np.int64,
            )
            return random_action, 0.0, 0.0

        continuous_action, _ = self.model.predict(state, deterministic=deterministic)
        if not deterministic:
            noise = np.random.normal(0.0, self.exploration_sigma, size=self.num_zones)
            continuous_action = np.clip(continuous_action + noise, 0.0, 1.0)

        # Map [0, 1] onto option indices 0..dim-1, rounding half up.
        discrete_action = np.array(
            [
                int(cont * (dim - 1) + 0.5)
                for cont, dim in zip(continuous_action, self.action_dims)
            ],
            dtype=np.int64,
        )
        discrete_action = np.clip(
            discrete_action, 0, [dim - 1 for dim in self.action_dims]
        )
        return discrete_action, 0.0, 0.0

    def store_transition(self, state, action, reward, next_state, done):
        """Add one transition (with a discrete ``action``) to the replay buffer."""
        self.total_steps += 1
        sync_manual_timesteps(self.model, self.total_steps)

        # Normalize each option index to [0, 1]. A zone with a single option
        # has no range to normalize over, so guard the divisor against zero.
        unit_action = np.array(
            [action[i] / max(dim - 1, 1) for i, dim in enumerate(self.action_dims)],
            dtype=np.float32,
        )
        # SB3 off-policy algorithms keep replay-buffer actions in the actor's
        # squashed [-1, 1] range, not in the environment's action range, so
        # rescale before storing to match what the actor and critic see.
        buffer_action = self.model.policy.scale_action(unit_action)
        self.model.replay_buffer.add(
            state, next_state, buffer_action, reward, done, [{}]
        )

    def update(self):
        """Run one gradient step once the buffer holds at least one batch.

        Returns:
            ``{}`` if the buffer is still too small, otherwise a dict with the
            model's update counter under ``"updates"``.
        """
        if self.model.replay_buffer.size() < self.model.batch_size:
            return {}
        self.model.train(gradient_steps=1, batch_size=self.model.batch_size)
        return {"updates": float(self.model._n_updates)}

    def save(self, path: str):
        """Save the underlying TD3 model to ``path``."""
        self.model.save(path)

    def load(self, path: str):
        """Replace the model with one loaded from ``path`` and resync steps."""
        self.model = TD3.load(path, device=self.device)
        ensure_manual_logger(self.model)
        self.total_steps = int(getattr(self.model, "num_timesteps", 0))