新增d3pg模型,更新热力图绘制

This commit is contained in:
Zihan Ye 2026-04-17 08:31:14 +08:00
parent 70de65d973
commit c0d0b3efd8
8 changed files with 248 additions and 16 deletions

175
agents/d3pg_agent.py Normal file
View File

@ -0,0 +1,175 @@
"""
D3PG: a pragmatic DDPG/TD3 hybrid baseline for motorway VSL.
Design:
- keep deterministic actor and simple continuous-action proxy from DDPG
- keep twin critics and delayed actor updates from TD3
- disable target policy smoothing because actions are ultimately snapped to
discrete speed-limit levels, so smoothed target actions add mismatch
without much benefit
"""
from __future__ import annotations
from typing import List, Sequence
import gymnasium as gym
import numpy as np
import torch.nn as nn
from gymnasium import spaces
from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.torch_layers import FlattenExtractor
from utils.sb3_manual import ensure_manual_logger, sync_manual_timesteps
class MultiDiscreteWrapper(gym.Env):
def __init__(self, state_dim: int, action_dims: Sequence[int]):
super().__init__()
self.state_dim = state_dim
self.action_dims = action_dims
self.num_zones = len(action_dims)
self.observation_space = spaces.Box(
low=-np.inf,
high=np.inf,
shape=(state_dim,),
dtype=np.float32,
)
self.action_space = spaces.Box(
low=0.0,
high=1.0,
shape=(self.num_zones,),
dtype=np.float32,
)
def reset(self, seed=None, options=None):
return np.zeros(self.state_dim, dtype=np.float32), {}
def step(self, action):
return np.zeros(self.state_dim, dtype=np.float32), 0.0, False, False, {}
def _resolve_activation_fn(name: str):
key = (name or "relu").strip().lower()
if key == "relu":
return nn.ReLU
if key == "silu":
return nn.SiLU
if key == "elu":
return nn.ELU
raise ValueError(f"Unsupported D3PG activation: {name}")
def _as_arch_list(value, default: List[int]) -> List[int]:
if value is None:
return list(default)
return [int(v) for v in value]
class D3PGAgent:
"""Twin-critic delayed deterministic policy gradient without target smoothing."""
def __init__(
self,
state_dim: int,
action_dims: list,
learning_rate: float = 3e-4,
buffer_size: int = 100000,
learning_starts: int = 100,
batch_size: int = 64,
tau: float = 0.005,
gamma: float = 0.99,
policy_delay: int = 2,
exploration_sigma: float = 0.1,
device: str = "cuda",
actor_hidden_dims: Sequence[int] | None = None,
critic_hidden_dims: Sequence[int] | None = None,
activation_fn: str = "relu",
):
self.state_dim = state_dim
self.action_dims = action_dims
self.num_zones = len(action_dims)
self.device = device
self.learning_starts = learning_starts
self.total_steps = 0
self.exploration_sigma = exploration_sigma
dummy_env = MultiDiscreteWrapper(state_dim, action_dims)
action_noise = NormalActionNoise(
mean=np.zeros(self.num_zones),
sigma=float(exploration_sigma) * np.ones(self.num_zones),
)
policy_kwargs = {
"net_arch": {
"pi": _as_arch_list(actor_hidden_dims, [256, 256]),
"qf": _as_arch_list(critic_hidden_dims, [256, 256]),
},
"activation_fn": _resolve_activation_fn(activation_fn),
"features_extractor_class": FlattenExtractor,
}
self.model = TD3(
"MlpPolicy",
env=dummy_env,
learning_rate=learning_rate,
buffer_size=buffer_size,
learning_starts=learning_starts,
batch_size=batch_size,
tau=tau,
gamma=gamma,
policy_delay=policy_delay,
target_policy_noise=0.0,
target_noise_clip=0.0,
action_noise=action_noise,
device=device,
verbose=0,
policy_kwargs=policy_kwargs,
)
ensure_manual_logger(self.model)
def select_action(self, state: np.ndarray, deterministic: bool = False):
if not deterministic and self.total_steps < self.learning_starts:
discrete_action = np.array(
[np.random.randint(self.action_dims[i]) for i in range(self.num_zones)],
dtype=np.int64,
)
return discrete_action, 0.0, 0.0
continuous_action, _ = self.model.predict(state, deterministic=deterministic)
if not deterministic:
noise = np.random.normal(0.0, self.exploration_sigma, size=self.num_zones)
continuous_action = np.clip(continuous_action + noise, 0.0, 1.0)
discrete_action = np.array(
[
int(cont * (self.action_dims[i] - 1) + 0.5)
for i, cont in enumerate(continuous_action)
]
)
discrete_action = np.clip(discrete_action, 0, [d - 1 for d in self.action_dims])
return discrete_action, 0.0, 0.0
def store_transition(self, state, action, reward, next_state, done):
self.total_steps += 1
sync_manual_timesteps(self.model, self.total_steps)
continuous_action = np.array(
[action[i] / (self.action_dims[i] - 1) for i in range(self.num_zones)],
dtype=np.float32,
)
self.model.replay_buffer.add(
state, next_state, continuous_action, reward, done, [{}]
)
def update(self):
if self.model.replay_buffer.size() < self.model.batch_size:
return {}
self.model.train(gradient_steps=1, batch_size=self.model.batch_size)
return {"updates": float(self.model._n_updates)}
def save(self, path: str):
self.model.save(path)
def load(self, path: str):
self.model = TD3.load(path, device=self.device)
ensure_manual_logger(self.model)
self.total_steps = int(getattr(self.model, "num_timesteps", 0))

View File

@ -239,6 +239,19 @@ agents:
actor_hidden_dims: [256, 256] actor_hidden_dims: [256, 256]
critic_hidden_dims: [256, 256] critic_hidden_dims: [256, 256]
d3pg:
learning_rate: 0.0003
gamma: 0.99
buffer_size: 20000
learning_starts: 200
batch_size: 128
tau: 0.005
policy_delay: 2
exploration_sigma: 0.15
activation_fn: "relu"
actor_hidden_dims: [256, 256]
critic_hidden_dims: [256, 256]
sac: sac:
learning_rate: 0.0003 learning_rate: 0.0003
gamma: 0.99 gamma: 0.99

View File

@ -22,6 +22,7 @@ import matplotlib.pyplot as plt
from agents.appo_agent import APPOAgent from agents.appo_agent import APPOAgent
from agents.dcmappo_agent import DCMAPPOAgent from agents.dcmappo_agent import DCMAPPOAgent
from agents.d3pg_agent import D3PGAgent
from agents.ddqn_agent import DDQNAgent from agents.ddqn_agent import DDQNAgent
from agents.ddpg_agent import DDPGAgent from agents.ddpg_agent import DDPGAgent
from agents.dqn_agent import DQNAgent from agents.dqn_agent import DQNAgent
@ -47,7 +48,7 @@ from utils.heatmap_plotting import (
from utils.run_dirs import find_shared_config_path, resolve_checkpoint_root from utils.run_dirs import find_shared_config_path, resolve_checkpoint_root
MODEL_ORDER = ["ppo", "gpro", "appo", "mappo", "tcamappo", "dcmappo", "dqn", "madqn", "ddqn", "qmix", "dcqmix", "ddpg", "sac", "td3", "sctd3"] MODEL_ORDER = ["ppo", "gpro", "appo", "mappo", "tcamappo", "dcmappo", "dqn", "madqn", "ddqn", "qmix", "dcqmix", "ddpg", "d3pg", "sac", "td3", "sctd3"]
BASELINE_NAME = "no_control" BASELINE_NAME = "no_control"
EVAL_ORDER = [BASELINE_NAME] + MODEL_ORDER EVAL_ORDER = [BASELINE_NAME] + MODEL_ORDER
MODEL_LABELS = { MODEL_LABELS = {
@ -64,6 +65,7 @@ MODEL_LABELS = {
"qmix": "QMIX", "qmix": "QMIX",
"dcqmix": "DC-QMIX", "dcqmix": "DC-QMIX",
"ddpg": "DDPG", "ddpg": "DDPG",
"d3pg": "D3PG",
"sac": "SAC", "sac": "SAC",
"td3": "TD3", "td3": "TD3",
"sctd3": "SC-TD3", "sctd3": "SC-TD3",
@ -94,7 +96,7 @@ def parse_args():
"--models", "--models",
nargs="*", nargs="*",
default=None, default=None,
help="Subset of models to evaluate, e.g. --models ppo gpro tcamappo dcmappo dqn madqn ddqn qmix dcqmix sac td3 sctd3", help="Subset of models to evaluate, e.g. --models ppo gpro tcamappo dcmappo dqn madqn ddqn qmix dcqmix ddpg d3pg sac td3 sctd3",
) )
parser.add_argument("--seed", type=int, default=42, help="Evaluation seed.") parser.add_argument("--seed", type=int, default=42, help="Evaluation seed.")
parser.add_argument( parser.add_argument(
@ -453,8 +455,9 @@ def build_agent(model_name: str, config: dict, env: SUMOEdgeVSLEnvironment):
return build_value_based_agent(QMIXAgent, agent_cfg, env) return build_value_based_agent(QMIXAgent, agent_cfg, env)
if model_name == "dcqmix": if model_name == "dcqmix":
return build_value_based_agent(DCQMIXAgent, agent_cfg, env) return build_value_based_agent(DCQMIXAgent, agent_cfg, env)
if model_name == "ddpg": if model_name in {"ddpg", "d3pg"}:
return DDPGAgent( agent_cls = D3PGAgent if model_name == "d3pg" else DDPGAgent
common_kwargs = dict(
state_dim=env.state_dim, state_dim=env.state_dim,
action_dims=[env.action_dim] * env.num_controlled_edges, action_dims=[env.action_dim] * env.num_controlled_edges,
learning_rate=agent_cfg.get("learning_rate", 3e-4), learning_rate=agent_cfg.get("learning_rate", 3e-4),
@ -463,12 +466,21 @@ def build_agent(model_name: str, config: dict, env: SUMOEdgeVSLEnvironment):
batch_size=agent_cfg.get("batch_size", 128), batch_size=agent_cfg.get("batch_size", 128),
tau=agent_cfg.get("tau", 0.005), tau=agent_cfg.get("tau", 0.005),
gamma=agent_cfg.get("gamma", 0.99), gamma=agent_cfg.get("gamma", 0.99),
exploration_sigma=agent_cfg.get("exploration_sigma", 0.15),
device=agent_cfg.get("device", "cuda"), device=agent_cfg.get("device", "cuda"),
actor_hidden_dims=agent_cfg.get("actor_hidden_dims"), actor_hidden_dims=agent_cfg.get("actor_hidden_dims"),
critic_hidden_dims=agent_cfg.get("critic_hidden_dims"), critic_hidden_dims=agent_cfg.get("critic_hidden_dims"),
activation_fn=agent_cfg.get("activation_fn", "relu"), activation_fn=agent_cfg.get("activation_fn", "relu"),
) )
if model_name == "d3pg":
common_kwargs.update(
policy_delay=agent_cfg.get("policy_delay", 2),
exploration_sigma=agent_cfg.get("exploration_sigma", 0.15),
)
else:
common_kwargs.update(
exploration_sigma=agent_cfg.get("exploration_sigma", 0.15),
)
return agent_cls(**common_kwargs)
if model_name == "sac": if model_name == "sac":
return SACAgent( return SACAgent(
state_dim=env.state_dim, state_dim=env.state_dim,

View File

@ -20,7 +20,7 @@ from envs.reward_system import REWARD_COMPONENT_COLUMNS, REWARD_COMPONENT_LABELS
from utils.run_dirs import find_latest_run_root, find_run_root_by_timestamp from utils.run_dirs import find_latest_run_root, find_run_root_by_timestamp
MODEL_ORDER = ["ppo", "gpro", "appo", "mappo", "tcamappo", "dcmappo", "dqn", "madqn", "ddqn", "qmix", "dcqmix", "ddpg", "sac", "td3", "sctd3"] MODEL_ORDER = ["ppo", "gpro", "appo", "mappo", "tcamappo", "dcmappo", "dqn", "madqn", "ddqn", "qmix", "dcqmix", "ddpg", "d3pg", "sac", "td3", "sctd3"]
MODEL_LABELS = { MODEL_LABELS = {
"ppo": "PPO", "ppo": "PPO",
"gpro": "GPRO-PPO", "gpro": "GPRO-PPO",
@ -34,6 +34,7 @@ MODEL_LABELS = {
"qmix": "QMIX", "qmix": "QMIX",
"dcqmix": "DC-QMIX", "dcqmix": "DC-QMIX",
"ddpg": "DDPG", "ddpg": "DDPG",
"d3pg": "D3PG",
"sac": "SAC", "sac": "SAC",
"td3": "TD3", "td3": "TD3",
"sctd3": "SC-TD3", "sctd3": "SC-TD3",
@ -51,6 +52,7 @@ MODEL_COLORS = {
"qmix": "#8dd3c7", "qmix": "#8dd3c7",
"dcqmix": "#2b8cbe", "dcqmix": "#2b8cbe",
"ddpg": "#9467bd", "ddpg": "#9467bd",
"d3pg": "#7fc97f",
"sac": "#e377c2", "sac": "#e377c2",
"td3": "#17becf", "td3": "#17becf",
"sctd3": "#bcbd22", "sctd3": "#bcbd22",
@ -61,7 +63,7 @@ EFFICIENCY_LABEL = REWARD_COMPONENT_LABELS.get(EFFICIENCY_COLUMN, "Running Effic
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="Plot live training progress from run logs.") parser = argparse.ArgumentParser(description="Plot live training progress from run logs.")
parser.add_argument("--model", default=None, help="Model name, e.g. ppo/gpro/appo/mappo/tcamappo/dcmappo/dqn/madqn/ddqn/qmix/dcqmix/ddpg/sac/td3/sctd3") parser.add_argument("--model", default=None, help="Model name, e.g. ppo/gpro/appo/mappo/tcamappo/dcmappo/dqn/madqn/ddqn/qmix/dcqmix/ddpg/d3pg/sac/td3/sctd3")
parser.add_argument( parser.add_argument(
"--all-models", "--all-models",
action="store_true", action="store_true",

View File

@ -4,6 +4,7 @@ from typing import Callable, Dict, List
from training.train_appo import train_sumo_appo from training.train_appo import train_sumo_appo
from training.train_dcmappo import train_sumo_dcmappo from training.train_dcmappo import train_sumo_dcmappo
from training.train_dcqmix import train_sumo_dcqmix from training.train_dcqmix import train_sumo_dcqmix
from training.train_d3pg import train_sumo_d3pg
from training.train_ddqn import train_sumo_ddqn from training.train_ddqn import train_sumo_ddqn
from training.train_ddpg import train_sumo_ddpg from training.train_ddpg import train_sumo_ddpg
from training.train_dqn import train_sumo_dqn from training.train_dqn import train_sumo_dqn
@ -19,8 +20,8 @@ from training.train_td3 import train_sumo_td3
# DEFAULT_MODELS: List[str] = ["ppo"] # DEFAULT_MODELS: List[str] = ["ppo"]
DEFAULT_MODELS: List[str] = ["ppo", "gpro", "mappo", "tcamappo", "dcmappo", "dqn", "madqn", "ddqn", "qmix", "dcqmix", "ddpg", "sac", "td3"] DEFAULT_MODELS: List[str] = ["ppo", "gpro", "mappo", "tcamappo", "dcmappo", "dqn", "madqn", "ddqn", "qmix", "dcqmix", "ddpg", "d3pg", "sac", "td3"]
ALL_MODELS: List[str] = ["ppo", "gpro", "appo", "mappo", "tcamappo", "dcmappo", "dqn", "madqn", "ddqn", "qmix", "dcqmix", "ddpg", "sac", "td3", "sctd3"] ALL_MODELS: List[str] = ["ppo", "gpro", "appo", "mappo", "tcamappo", "dcmappo", "dqn", "madqn", "ddqn", "qmix", "dcqmix", "ddpg", "d3pg", "sac", "td3", "sctd3"]
TRAINERS: Dict[str, Callable] = { TRAINERS: Dict[str, Callable] = {
@ -36,6 +37,7 @@ TRAINERS: Dict[str, Callable] = {
"qmix": train_sumo_qmix, "qmix": train_sumo_qmix,
"dcqmix": train_sumo_dcqmix, "dcqmix": train_sumo_dcqmix,
"ddpg": train_sumo_ddpg, "ddpg": train_sumo_ddpg,
"d3pg": train_sumo_d3pg,
"sac": train_sumo_sac, "sac": train_sumo_sac,
"td3": train_sumo_td3, "td3": train_sumo_td3,
"sctd3": train_sumo_sctd3, "sctd3": train_sumo_sctd3,

14
training/train_d3pg.py Normal file
View File

@ -0,0 +1,14 @@
from agents.d3pg_agent import D3PGAgent
from training.train_td3 import train_sumo_td3
def train_sumo_d3pg(log_dir=None, checkpoint_dir=None, run_timestamp=None):
return train_sumo_td3(
log_dir=log_dir,
checkpoint_dir=checkpoint_dir,
run_timestamp=run_timestamp,
model_name="d3pg",
config_key="d3pg",
display_name="D3PG",
agent_class=D3PGAgent,
)

View File

@ -82,13 +82,14 @@ def train_sumo_td3(
batch_size=agent_config.get("batch_size", 256), batch_size=agent_config.get("batch_size", 256),
tau=agent_config.get("tau", 0.005), tau=agent_config.get("tau", 0.005),
gamma=agent_config.get("gamma", 0.99), gamma=agent_config.get("gamma", 0.99),
policy_delay=agent_config.get("policy_delay", 2),
exploration_sigma=agent_config.get("exploration_sigma", 0.1), exploration_sigma=agent_config.get("exploration_sigma", 0.1),
device=agent_config.get("device", "cuda"), device=agent_config.get("device", "cuda"),
actor_hidden_dims=agent_config.get("actor_hidden_dims"), actor_hidden_dims=agent_config.get("actor_hidden_dims"),
critic_hidden_dims=agent_config.get("critic_hidden_dims"), critic_hidden_dims=agent_config.get("critic_hidden_dims"),
activation_fn=agent_config.get("activation_fn", "relu"), activation_fn=agent_config.get("activation_fn", "relu"),
) )
if "policy_delay" in agent_config:
common_kwargs["policy_delay"] = agent_config.get("policy_delay", 2)
if config_key == "sctd3": if config_key == "sctd3":
common_kwargs.update( common_kwargs.update(
edge_feature_dim=env.features_per_edge, edge_feature_dim=env.features_per_edge,

View File

@ -13,7 +13,7 @@ matplotlib.use("Agg")
from envs.reward_system import REWARD_COMPONENT_COLUMNS from envs.reward_system import REWARD_COMPONENT_COLUMNS
from utils.heatmap_plotting import ( from utils.heatmap_plotting import (
build_action_panel, build_action_panel,
build_density_panel, build_occupancy_panel,
build_speed_panel, build_speed_panel,
save_heatmap_panels, save_heatmap_panels,
) )
@ -179,6 +179,8 @@ def _build_detector_rows_from_xml(log_dir: str, episode: int) -> List[Dict]:
speed_weight_sum = 0.0 speed_weight_sum = 0.0
speed_weight_total = 0.0 speed_weight_total = 0.0
density_sum = 0.0 density_sum = 0.0
occupancy_sum = 0.0
occupancy_count = 0
for det_id in cell["detector_ids"]: for det_id in cell["detector_ids"]:
lane_metrics = interval_values.get(det_id) lane_metrics = interval_values.get(det_id)
@ -188,6 +190,11 @@ def _build_detector_rows_from_xml(log_dir: str, episode: int) -> List[Dict]:
speed_ms = lane_metrics["speed_ms"] speed_ms = lane_metrics["speed_ms"]
flow_vehph = max(lane_metrics["flow_vehph"], 0.0) flow_vehph = max(lane_metrics["flow_vehph"], 0.0)
n_veh_contrib = max(lane_metrics["n_veh_contrib"], 0.0) n_veh_contrib = max(lane_metrics["n_veh_contrib"], 0.0)
occupancy = max(lane_metrics["occupancy"], 0.0)
if np.isfinite(occupancy):
occupancy_sum += occupancy
occupancy_count += 1
if np.isfinite(speed_ms) and speed_ms > 0.0: if np.isfinite(speed_ms) and speed_ms > 0.0:
speed_weight = n_veh_contrib if n_veh_contrib > 0.0 else flow_vehph speed_weight = n_veh_contrib if n_veh_contrib > 0.0 else flow_vehph
@ -202,6 +209,11 @@ def _build_detector_rows_from_xml(log_dir: str, episode: int) -> List[Dict]:
if speed_weight_total > 0.0 if speed_weight_total > 0.0
else np.nan else np.nan
) )
occupancy_pct = (
occupancy_sum / occupancy_count
if occupancy_count > 0
else np.nan
)
detector_rows.append( detector_rows.append(
{ {
@ -213,6 +225,7 @@ def _build_detector_rows_from_xml(log_dir: str, episode: int) -> List[Dict]:
"pos_index": int(cell["pos_index"]), "pos_index": int(cell["pos_index"]),
"position_m": float(cell["position_m"]), "position_m": float(cell["position_m"]),
"measured_speed_kmh": float(measured_speed_kmh) if np.isfinite(measured_speed_kmh) else np.nan, "measured_speed_kmh": float(measured_speed_kmh) if np.isfinite(measured_speed_kmh) else np.nan,
"occupancy": float(occupancy_pct) if np.isfinite(occupancy_pct) else np.nan,
"density_vehpkm": float(density_sum), "density_vehpkm": float(density_sum),
} }
) )
@ -321,13 +334,13 @@ def _plot_episode_heatmap(
num_cells = len(ordered_cells) num_cells = len(ordered_cells)
cell_to_row = {cell_id: idx for idx, cell_id in enumerate(ordered_cells)} cell_to_row = {cell_id: idx for idx, cell_id in enumerate(ordered_cells)}
speed_grid = np.full((num_cells, num_steps), np.nan, dtype=np.float32) speed_grid = np.full((num_cells, num_steps), np.nan, dtype=np.float32)
density_grid = np.full((num_cells, num_steps), np.nan, dtype=np.float32) occupancy_grid = np.full((num_cells, num_steps), np.nan, dtype=np.float32)
for row in detector_rows: for row in detector_rows:
row_idx = cell_to_row[str(row["cell_id"])] row_idx = cell_to_row[str(row["cell_id"])]
col_idx = step_to_col[int(row["step"])] col_idx = step_to_col[int(row["step"])]
speed_grid[row_idx, col_idx] = _safe_float(row["measured_speed_kmh"]) speed_grid[row_idx, col_idx] = _safe_float(row["measured_speed_kmh"])
density_grid[row_idx, col_idx] = _safe_float(row["density_vehpkm"]) occupancy_grid[row_idx, col_idx] = _safe_float(row["occupancy"])
plots = [ plots = [
build_speed_panel( build_speed_panel(
@ -336,10 +349,10 @@ def _plot_episode_heatmap(
f"{title_prefix} Measured Speed (km/h)", f"{title_prefix} Measured Speed (km/h)",
"Detector Cell", "Detector Cell",
), ),
build_density_panel( build_occupancy_panel(
density_grid, occupancy_grid,
ordered_cells, ordered_cells,
f"{title_prefix} Density (veh/km)", f"{title_prefix} Occupancy (%)",
"Detector Cell", "Detector Cell",
), ),
] ]