新增d3pg模型,更新热力图绘制
This commit is contained in:
parent
70de65d973
commit
c0d0b3efd8
|
|
@ -0,0 +1,175 @@
|
||||||
|
"""
|
||||||
|
D3PG: a pragmatic DDPG/TD3 hybrid baseline for motorway VSL.
|
||||||
|
|
||||||
|
Design:
|
||||||
|
- keep deterministic actor and simple continuous-action proxy from DDPG
|
||||||
|
- keep twin critics and delayed actor updates from TD3
|
||||||
|
- disable target policy smoothing because actions are ultimately snapped to
|
||||||
|
discrete speed-limit levels, so smoothed target actions add mismatch
|
||||||
|
without much benefit
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import List, Sequence
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
import torch.nn as nn
|
||||||
|
from gymnasium import spaces
|
||||||
|
from stable_baselines3 import TD3
|
||||||
|
from stable_baselines3.common.noise import NormalActionNoise
|
||||||
|
from stable_baselines3.common.torch_layers import FlattenExtractor
|
||||||
|
|
||||||
|
from utils.sb3_manual import ensure_manual_logger, sync_manual_timesteps
|
||||||
|
|
||||||
|
|
||||||
|
class MultiDiscreteWrapper(gym.Env):
|
||||||
|
def __init__(self, state_dim: int, action_dims: Sequence[int]):
|
||||||
|
super().__init__()
|
||||||
|
self.state_dim = state_dim
|
||||||
|
self.action_dims = action_dims
|
||||||
|
self.num_zones = len(action_dims)
|
||||||
|
self.observation_space = spaces.Box(
|
||||||
|
low=-np.inf,
|
||||||
|
high=np.inf,
|
||||||
|
shape=(state_dim,),
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
self.action_space = spaces.Box(
|
||||||
|
low=0.0,
|
||||||
|
high=1.0,
|
||||||
|
shape=(self.num_zones,),
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(self, seed=None, options=None):
|
||||||
|
return np.zeros(self.state_dim, dtype=np.float32), {}
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
return np.zeros(self.state_dim, dtype=np.float32), 0.0, False, False, {}
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_activation_fn(name: str):
|
||||||
|
key = (name or "relu").strip().lower()
|
||||||
|
if key == "relu":
|
||||||
|
return nn.ReLU
|
||||||
|
if key == "silu":
|
||||||
|
return nn.SiLU
|
||||||
|
if key == "elu":
|
||||||
|
return nn.ELU
|
||||||
|
raise ValueError(f"Unsupported D3PG activation: {name}")
|
||||||
|
|
||||||
|
|
||||||
|
def _as_arch_list(value, default: List[int]) -> List[int]:
|
||||||
|
if value is None:
|
||||||
|
return list(default)
|
||||||
|
return [int(v) for v in value]
|
||||||
|
|
||||||
|
|
||||||
|
class D3PGAgent:
|
||||||
|
"""Twin-critic delayed deterministic policy gradient without target smoothing."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
state_dim: int,
|
||||||
|
action_dims: list,
|
||||||
|
learning_rate: float = 3e-4,
|
||||||
|
buffer_size: int = 100000,
|
||||||
|
learning_starts: int = 100,
|
||||||
|
batch_size: int = 64,
|
||||||
|
tau: float = 0.005,
|
||||||
|
gamma: float = 0.99,
|
||||||
|
policy_delay: int = 2,
|
||||||
|
exploration_sigma: float = 0.1,
|
||||||
|
device: str = "cuda",
|
||||||
|
actor_hidden_dims: Sequence[int] | None = None,
|
||||||
|
critic_hidden_dims: Sequence[int] | None = None,
|
||||||
|
activation_fn: str = "relu",
|
||||||
|
):
|
||||||
|
self.state_dim = state_dim
|
||||||
|
self.action_dims = action_dims
|
||||||
|
self.num_zones = len(action_dims)
|
||||||
|
self.device = device
|
||||||
|
self.learning_starts = learning_starts
|
||||||
|
self.total_steps = 0
|
||||||
|
self.exploration_sigma = exploration_sigma
|
||||||
|
|
||||||
|
dummy_env = MultiDiscreteWrapper(state_dim, action_dims)
|
||||||
|
action_noise = NormalActionNoise(
|
||||||
|
mean=np.zeros(self.num_zones),
|
||||||
|
sigma=float(exploration_sigma) * np.ones(self.num_zones),
|
||||||
|
)
|
||||||
|
policy_kwargs = {
|
||||||
|
"net_arch": {
|
||||||
|
"pi": _as_arch_list(actor_hidden_dims, [256, 256]),
|
||||||
|
"qf": _as_arch_list(critic_hidden_dims, [256, 256]),
|
||||||
|
},
|
||||||
|
"activation_fn": _resolve_activation_fn(activation_fn),
|
||||||
|
"features_extractor_class": FlattenExtractor,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.model = TD3(
|
||||||
|
"MlpPolicy",
|
||||||
|
env=dummy_env,
|
||||||
|
learning_rate=learning_rate,
|
||||||
|
buffer_size=buffer_size,
|
||||||
|
learning_starts=learning_starts,
|
||||||
|
batch_size=batch_size,
|
||||||
|
tau=tau,
|
||||||
|
gamma=gamma,
|
||||||
|
policy_delay=policy_delay,
|
||||||
|
target_policy_noise=0.0,
|
||||||
|
target_noise_clip=0.0,
|
||||||
|
action_noise=action_noise,
|
||||||
|
device=device,
|
||||||
|
verbose=0,
|
||||||
|
policy_kwargs=policy_kwargs,
|
||||||
|
)
|
||||||
|
ensure_manual_logger(self.model)
|
||||||
|
|
||||||
|
def select_action(self, state: np.ndarray, deterministic: bool = False):
|
||||||
|
if not deterministic and self.total_steps < self.learning_starts:
|
||||||
|
discrete_action = np.array(
|
||||||
|
[np.random.randint(self.action_dims[i]) for i in range(self.num_zones)],
|
||||||
|
dtype=np.int64,
|
||||||
|
)
|
||||||
|
return discrete_action, 0.0, 0.0
|
||||||
|
|
||||||
|
continuous_action, _ = self.model.predict(state, deterministic=deterministic)
|
||||||
|
if not deterministic:
|
||||||
|
noise = np.random.normal(0.0, self.exploration_sigma, size=self.num_zones)
|
||||||
|
continuous_action = np.clip(continuous_action + noise, 0.0, 1.0)
|
||||||
|
|
||||||
|
discrete_action = np.array(
|
||||||
|
[
|
||||||
|
int(cont * (self.action_dims[i] - 1) + 0.5)
|
||||||
|
for i, cont in enumerate(continuous_action)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
discrete_action = np.clip(discrete_action, 0, [d - 1 for d in self.action_dims])
|
||||||
|
return discrete_action, 0.0, 0.0
|
||||||
|
|
||||||
|
def store_transition(self, state, action, reward, next_state, done):
|
||||||
|
self.total_steps += 1
|
||||||
|
sync_manual_timesteps(self.model, self.total_steps)
|
||||||
|
continuous_action = np.array(
|
||||||
|
[action[i] / (self.action_dims[i] - 1) for i in range(self.num_zones)],
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
self.model.replay_buffer.add(
|
||||||
|
state, next_state, continuous_action, reward, done, [{}]
|
||||||
|
)
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
if self.model.replay_buffer.size() < self.model.batch_size:
|
||||||
|
return {}
|
||||||
|
self.model.train(gradient_steps=1, batch_size=self.model.batch_size)
|
||||||
|
return {"updates": float(self.model._n_updates)}
|
||||||
|
|
||||||
|
def save(self, path: str):
|
||||||
|
self.model.save(path)
|
||||||
|
|
||||||
|
def load(self, path: str):
|
||||||
|
self.model = TD3.load(path, device=self.device)
|
||||||
|
ensure_manual_logger(self.model)
|
||||||
|
self.total_steps = int(getattr(self.model, "num_timesteps", 0))
|
||||||
|
|
@ -239,6 +239,19 @@ agents:
|
||||||
actor_hidden_dims: [256, 256]
|
actor_hidden_dims: [256, 256]
|
||||||
critic_hidden_dims: [256, 256]
|
critic_hidden_dims: [256, 256]
|
||||||
|
|
||||||
|
d3pg:
|
||||||
|
learning_rate: 0.0003
|
||||||
|
gamma: 0.99
|
||||||
|
buffer_size: 20000
|
||||||
|
learning_starts: 200
|
||||||
|
batch_size: 128
|
||||||
|
tau: 0.005
|
||||||
|
policy_delay: 2
|
||||||
|
exploration_sigma: 0.15
|
||||||
|
activation_fn: "relu"
|
||||||
|
actor_hidden_dims: [256, 256]
|
||||||
|
critic_hidden_dims: [256, 256]
|
||||||
|
|
||||||
sac:
|
sac:
|
||||||
learning_rate: 0.0003
|
learning_rate: 0.0003
|
||||||
gamma: 0.99
|
gamma: 0.99
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ import matplotlib.pyplot as plt
|
||||||
|
|
||||||
from agents.appo_agent import APPOAgent
|
from agents.appo_agent import APPOAgent
|
||||||
from agents.dcmappo_agent import DCMAPPOAgent
|
from agents.dcmappo_agent import DCMAPPOAgent
|
||||||
|
from agents.d3pg_agent import D3PGAgent
|
||||||
from agents.ddqn_agent import DDQNAgent
|
from agents.ddqn_agent import DDQNAgent
|
||||||
from agents.ddpg_agent import DDPGAgent
|
from agents.ddpg_agent import DDPGAgent
|
||||||
from agents.dqn_agent import DQNAgent
|
from agents.dqn_agent import DQNAgent
|
||||||
|
|
@ -47,7 +48,7 @@ from utils.heatmap_plotting import (
|
||||||
from utils.run_dirs import find_shared_config_path, resolve_checkpoint_root
|
from utils.run_dirs import find_shared_config_path, resolve_checkpoint_root
|
||||||
|
|
||||||
|
|
||||||
MODEL_ORDER = ["ppo", "gpro", "appo", "mappo", "tcamappo", "dcmappo", "dqn", "madqn", "ddqn", "qmix", "dcqmix", "ddpg", "sac", "td3", "sctd3"]
|
MODEL_ORDER = ["ppo", "gpro", "appo", "mappo", "tcamappo", "dcmappo", "dqn", "madqn", "ddqn", "qmix", "dcqmix", "ddpg", "d3pg", "sac", "td3", "sctd3"]
|
||||||
BASELINE_NAME = "no_control"
|
BASELINE_NAME = "no_control"
|
||||||
EVAL_ORDER = [BASELINE_NAME] + MODEL_ORDER
|
EVAL_ORDER = [BASELINE_NAME] + MODEL_ORDER
|
||||||
MODEL_LABELS = {
|
MODEL_LABELS = {
|
||||||
|
|
@ -64,6 +65,7 @@ MODEL_LABELS = {
|
||||||
"qmix": "QMIX",
|
"qmix": "QMIX",
|
||||||
"dcqmix": "DC-QMIX",
|
"dcqmix": "DC-QMIX",
|
||||||
"ddpg": "DDPG",
|
"ddpg": "DDPG",
|
||||||
|
"d3pg": "D3PG",
|
||||||
"sac": "SAC",
|
"sac": "SAC",
|
||||||
"td3": "TD3",
|
"td3": "TD3",
|
||||||
"sctd3": "SC-TD3",
|
"sctd3": "SC-TD3",
|
||||||
|
|
@ -94,7 +96,7 @@ def parse_args():
|
||||||
"--models",
|
"--models",
|
||||||
nargs="*",
|
nargs="*",
|
||||||
default=None,
|
default=None,
|
||||||
help="Subset of models to evaluate, e.g. --models ppo gpro tcamappo dcmappo dqn madqn ddqn qmix dcqmix sac td3 sctd3",
|
help="Subset of models to evaluate, e.g. --models ppo gpro tcamappo dcmappo dqn madqn ddqn qmix dcqmix ddpg d3pg sac td3 sctd3",
|
||||||
)
|
)
|
||||||
parser.add_argument("--seed", type=int, default=42, help="Evaluation seed.")
|
parser.add_argument("--seed", type=int, default=42, help="Evaluation seed.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
@ -453,8 +455,9 @@ def build_agent(model_name: str, config: dict, env: SUMOEdgeVSLEnvironment):
|
||||||
return build_value_based_agent(QMIXAgent, agent_cfg, env)
|
return build_value_based_agent(QMIXAgent, agent_cfg, env)
|
||||||
if model_name == "dcqmix":
|
if model_name == "dcqmix":
|
||||||
return build_value_based_agent(DCQMIXAgent, agent_cfg, env)
|
return build_value_based_agent(DCQMIXAgent, agent_cfg, env)
|
||||||
if model_name == "ddpg":
|
if model_name in {"ddpg", "d3pg"}:
|
||||||
return DDPGAgent(
|
agent_cls = D3PGAgent if model_name == "d3pg" else DDPGAgent
|
||||||
|
common_kwargs = dict(
|
||||||
state_dim=env.state_dim,
|
state_dim=env.state_dim,
|
||||||
action_dims=[env.action_dim] * env.num_controlled_edges,
|
action_dims=[env.action_dim] * env.num_controlled_edges,
|
||||||
learning_rate=agent_cfg.get("learning_rate", 3e-4),
|
learning_rate=agent_cfg.get("learning_rate", 3e-4),
|
||||||
|
|
@ -463,12 +466,21 @@ def build_agent(model_name: str, config: dict, env: SUMOEdgeVSLEnvironment):
|
||||||
batch_size=agent_cfg.get("batch_size", 128),
|
batch_size=agent_cfg.get("batch_size", 128),
|
||||||
tau=agent_cfg.get("tau", 0.005),
|
tau=agent_cfg.get("tau", 0.005),
|
||||||
gamma=agent_cfg.get("gamma", 0.99),
|
gamma=agent_cfg.get("gamma", 0.99),
|
||||||
exploration_sigma=agent_cfg.get("exploration_sigma", 0.15),
|
|
||||||
device=agent_cfg.get("device", "cuda"),
|
device=agent_cfg.get("device", "cuda"),
|
||||||
actor_hidden_dims=agent_cfg.get("actor_hidden_dims"),
|
actor_hidden_dims=agent_cfg.get("actor_hidden_dims"),
|
||||||
critic_hidden_dims=agent_cfg.get("critic_hidden_dims"),
|
critic_hidden_dims=agent_cfg.get("critic_hidden_dims"),
|
||||||
activation_fn=agent_cfg.get("activation_fn", "relu"),
|
activation_fn=agent_cfg.get("activation_fn", "relu"),
|
||||||
)
|
)
|
||||||
|
if model_name == "d3pg":
|
||||||
|
common_kwargs.update(
|
||||||
|
policy_delay=agent_cfg.get("policy_delay", 2),
|
||||||
|
exploration_sigma=agent_cfg.get("exploration_sigma", 0.15),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
common_kwargs.update(
|
||||||
|
exploration_sigma=agent_cfg.get("exploration_sigma", 0.15),
|
||||||
|
)
|
||||||
|
return agent_cls(**common_kwargs)
|
||||||
if model_name == "sac":
|
if model_name == "sac":
|
||||||
return SACAgent(
|
return SACAgent(
|
||||||
state_dim=env.state_dim,
|
state_dim=env.state_dim,
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ from envs.reward_system import REWARD_COMPONENT_COLUMNS, REWARD_COMPONENT_LABELS
|
||||||
from utils.run_dirs import find_latest_run_root, find_run_root_by_timestamp
|
from utils.run_dirs import find_latest_run_root, find_run_root_by_timestamp
|
||||||
|
|
||||||
|
|
||||||
MODEL_ORDER = ["ppo", "gpro", "appo", "mappo", "tcamappo", "dcmappo", "dqn", "madqn", "ddqn", "qmix", "dcqmix", "ddpg", "sac", "td3", "sctd3"]
|
MODEL_ORDER = ["ppo", "gpro", "appo", "mappo", "tcamappo", "dcmappo", "dqn", "madqn", "ddqn", "qmix", "dcqmix", "ddpg", "d3pg", "sac", "td3", "sctd3"]
|
||||||
MODEL_LABELS = {
|
MODEL_LABELS = {
|
||||||
"ppo": "PPO",
|
"ppo": "PPO",
|
||||||
"gpro": "GPRO-PPO",
|
"gpro": "GPRO-PPO",
|
||||||
|
|
@ -34,6 +34,7 @@ MODEL_LABELS = {
|
||||||
"qmix": "QMIX",
|
"qmix": "QMIX",
|
||||||
"dcqmix": "DC-QMIX",
|
"dcqmix": "DC-QMIX",
|
||||||
"ddpg": "DDPG",
|
"ddpg": "DDPG",
|
||||||
|
"d3pg": "D3PG",
|
||||||
"sac": "SAC",
|
"sac": "SAC",
|
||||||
"td3": "TD3",
|
"td3": "TD3",
|
||||||
"sctd3": "SC-TD3",
|
"sctd3": "SC-TD3",
|
||||||
|
|
@ -51,6 +52,7 @@ MODEL_COLORS = {
|
||||||
"qmix": "#8dd3c7",
|
"qmix": "#8dd3c7",
|
||||||
"dcqmix": "#2b8cbe",
|
"dcqmix": "#2b8cbe",
|
||||||
"ddpg": "#9467bd",
|
"ddpg": "#9467bd",
|
||||||
|
"d3pg": "#7fc97f",
|
||||||
"sac": "#e377c2",
|
"sac": "#e377c2",
|
||||||
"td3": "#17becf",
|
"td3": "#17becf",
|
||||||
"sctd3": "#bcbd22",
|
"sctd3": "#bcbd22",
|
||||||
|
|
@ -61,7 +63,7 @@ EFFICIENCY_LABEL = REWARD_COMPONENT_LABELS.get(EFFICIENCY_COLUMN, "Running Effic
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description="Plot live training progress from run logs.")
|
parser = argparse.ArgumentParser(description="Plot live training progress from run logs.")
|
||||||
parser.add_argument("--model", default=None, help="Model name, e.g. ppo/gpro/appo/mappo/tcamappo/dcmappo/dqn/madqn/ddqn/qmix/dcqmix/ddpg/sac/td3/sctd3")
|
parser.add_argument("--model", default=None, help="Model name, e.g. ppo/gpro/appo/mappo/tcamappo/dcmappo/dqn/madqn/ddqn/qmix/dcqmix/ddpg/d3pg/sac/td3/sctd3")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--all-models",
|
"--all-models",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ from typing import Callable, Dict, List
|
||||||
from training.train_appo import train_sumo_appo
|
from training.train_appo import train_sumo_appo
|
||||||
from training.train_dcmappo import train_sumo_dcmappo
|
from training.train_dcmappo import train_sumo_dcmappo
|
||||||
from training.train_dcqmix import train_sumo_dcqmix
|
from training.train_dcqmix import train_sumo_dcqmix
|
||||||
|
from training.train_d3pg import train_sumo_d3pg
|
||||||
from training.train_ddqn import train_sumo_ddqn
|
from training.train_ddqn import train_sumo_ddqn
|
||||||
from training.train_ddpg import train_sumo_ddpg
|
from training.train_ddpg import train_sumo_ddpg
|
||||||
from training.train_dqn import train_sumo_dqn
|
from training.train_dqn import train_sumo_dqn
|
||||||
|
|
@ -19,8 +20,8 @@ from training.train_td3 import train_sumo_td3
|
||||||
|
|
||||||
|
|
||||||
# DEFAULT_MODELS: List[str] = ["ppo"]
|
# DEFAULT_MODELS: List[str] = ["ppo"]
|
||||||
DEFAULT_MODELS: List[str] = ["ppo", "gpro", "mappo", "tcamappo", "dcmappo", "dqn", "madqn", "ddqn", "qmix", "dcqmix", "ddpg", "sac", "td3"]
|
DEFAULT_MODELS: List[str] = ["ppo", "gpro", "mappo", "tcamappo", "dcmappo", "dqn", "madqn", "ddqn", "qmix", "dcqmix", "ddpg", "d3pg", "sac", "td3"]
|
||||||
ALL_MODELS: List[str] = ["ppo", "gpro", "appo", "mappo", "tcamappo", "dcmappo", "dqn", "madqn", "ddqn", "qmix", "dcqmix", "ddpg", "sac", "td3", "sctd3"]
|
ALL_MODELS: List[str] = ["ppo", "gpro", "appo", "mappo", "tcamappo", "dcmappo", "dqn", "madqn", "ddqn", "qmix", "dcqmix", "ddpg", "d3pg", "sac", "td3", "sctd3"]
|
||||||
|
|
||||||
|
|
||||||
TRAINERS: Dict[str, Callable] = {
|
TRAINERS: Dict[str, Callable] = {
|
||||||
|
|
@ -36,6 +37,7 @@ TRAINERS: Dict[str, Callable] = {
|
||||||
"qmix": train_sumo_qmix,
|
"qmix": train_sumo_qmix,
|
||||||
"dcqmix": train_sumo_dcqmix,
|
"dcqmix": train_sumo_dcqmix,
|
||||||
"ddpg": train_sumo_ddpg,
|
"ddpg": train_sumo_ddpg,
|
||||||
|
"d3pg": train_sumo_d3pg,
|
||||||
"sac": train_sumo_sac,
|
"sac": train_sumo_sac,
|
||||||
"td3": train_sumo_td3,
|
"td3": train_sumo_td3,
|
||||||
"sctd3": train_sumo_sctd3,
|
"sctd3": train_sumo_sctd3,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
from agents.d3pg_agent import D3PGAgent
|
||||||
|
from training.train_td3 import train_sumo_td3
|
||||||
|
|
||||||
|
|
||||||
|
def train_sumo_d3pg(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||||
|
return train_sumo_td3(
|
||||||
|
log_dir=log_dir,
|
||||||
|
checkpoint_dir=checkpoint_dir,
|
||||||
|
run_timestamp=run_timestamp,
|
||||||
|
model_name="d3pg",
|
||||||
|
config_key="d3pg",
|
||||||
|
display_name="D3PG",
|
||||||
|
agent_class=D3PGAgent,
|
||||||
|
)
|
||||||
|
|
@ -82,13 +82,14 @@ def train_sumo_td3(
|
||||||
batch_size=agent_config.get("batch_size", 256),
|
batch_size=agent_config.get("batch_size", 256),
|
||||||
tau=agent_config.get("tau", 0.005),
|
tau=agent_config.get("tau", 0.005),
|
||||||
gamma=agent_config.get("gamma", 0.99),
|
gamma=agent_config.get("gamma", 0.99),
|
||||||
policy_delay=agent_config.get("policy_delay", 2),
|
|
||||||
exploration_sigma=agent_config.get("exploration_sigma", 0.1),
|
exploration_sigma=agent_config.get("exploration_sigma", 0.1),
|
||||||
device=agent_config.get("device", "cuda"),
|
device=agent_config.get("device", "cuda"),
|
||||||
actor_hidden_dims=agent_config.get("actor_hidden_dims"),
|
actor_hidden_dims=agent_config.get("actor_hidden_dims"),
|
||||||
critic_hidden_dims=agent_config.get("critic_hidden_dims"),
|
critic_hidden_dims=agent_config.get("critic_hidden_dims"),
|
||||||
activation_fn=agent_config.get("activation_fn", "relu"),
|
activation_fn=agent_config.get("activation_fn", "relu"),
|
||||||
)
|
)
|
||||||
|
if "policy_delay" in agent_config:
|
||||||
|
common_kwargs["policy_delay"] = agent_config.get("policy_delay", 2)
|
||||||
if config_key == "sctd3":
|
if config_key == "sctd3":
|
||||||
common_kwargs.update(
|
common_kwargs.update(
|
||||||
edge_feature_dim=env.features_per_edge,
|
edge_feature_dim=env.features_per_edge,
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ matplotlib.use("Agg")
|
||||||
from envs.reward_system import REWARD_COMPONENT_COLUMNS
|
from envs.reward_system import REWARD_COMPONENT_COLUMNS
|
||||||
from utils.heatmap_plotting import (
|
from utils.heatmap_plotting import (
|
||||||
build_action_panel,
|
build_action_panel,
|
||||||
build_density_panel,
|
build_occupancy_panel,
|
||||||
build_speed_panel,
|
build_speed_panel,
|
||||||
save_heatmap_panels,
|
save_heatmap_panels,
|
||||||
)
|
)
|
||||||
|
|
@ -179,6 +179,8 @@ def _build_detector_rows_from_xml(log_dir: str, episode: int) -> List[Dict]:
|
||||||
speed_weight_sum = 0.0
|
speed_weight_sum = 0.0
|
||||||
speed_weight_total = 0.0
|
speed_weight_total = 0.0
|
||||||
density_sum = 0.0
|
density_sum = 0.0
|
||||||
|
occupancy_sum = 0.0
|
||||||
|
occupancy_count = 0
|
||||||
|
|
||||||
for det_id in cell["detector_ids"]:
|
for det_id in cell["detector_ids"]:
|
||||||
lane_metrics = interval_values.get(det_id)
|
lane_metrics = interval_values.get(det_id)
|
||||||
|
|
@ -188,6 +190,11 @@ def _build_detector_rows_from_xml(log_dir: str, episode: int) -> List[Dict]:
|
||||||
speed_ms = lane_metrics["speed_ms"]
|
speed_ms = lane_metrics["speed_ms"]
|
||||||
flow_vehph = max(lane_metrics["flow_vehph"], 0.0)
|
flow_vehph = max(lane_metrics["flow_vehph"], 0.0)
|
||||||
n_veh_contrib = max(lane_metrics["n_veh_contrib"], 0.0)
|
n_veh_contrib = max(lane_metrics["n_veh_contrib"], 0.0)
|
||||||
|
occupancy = max(lane_metrics["occupancy"], 0.0)
|
||||||
|
|
||||||
|
if np.isfinite(occupancy):
|
||||||
|
occupancy_sum += occupancy
|
||||||
|
occupancy_count += 1
|
||||||
|
|
||||||
if np.isfinite(speed_ms) and speed_ms > 0.0:
|
if np.isfinite(speed_ms) and speed_ms > 0.0:
|
||||||
speed_weight = n_veh_contrib if n_veh_contrib > 0.0 else flow_vehph
|
speed_weight = n_veh_contrib if n_veh_contrib > 0.0 else flow_vehph
|
||||||
|
|
@ -202,6 +209,11 @@ def _build_detector_rows_from_xml(log_dir: str, episode: int) -> List[Dict]:
|
||||||
if speed_weight_total > 0.0
|
if speed_weight_total > 0.0
|
||||||
else np.nan
|
else np.nan
|
||||||
)
|
)
|
||||||
|
occupancy_pct = (
|
||||||
|
occupancy_sum / occupancy_count
|
||||||
|
if occupancy_count > 0
|
||||||
|
else np.nan
|
||||||
|
)
|
||||||
|
|
||||||
detector_rows.append(
|
detector_rows.append(
|
||||||
{
|
{
|
||||||
|
|
@ -213,6 +225,7 @@ def _build_detector_rows_from_xml(log_dir: str, episode: int) -> List[Dict]:
|
||||||
"pos_index": int(cell["pos_index"]),
|
"pos_index": int(cell["pos_index"]),
|
||||||
"position_m": float(cell["position_m"]),
|
"position_m": float(cell["position_m"]),
|
||||||
"measured_speed_kmh": float(measured_speed_kmh) if np.isfinite(measured_speed_kmh) else np.nan,
|
"measured_speed_kmh": float(measured_speed_kmh) if np.isfinite(measured_speed_kmh) else np.nan,
|
||||||
|
"occupancy": float(occupancy_pct) if np.isfinite(occupancy_pct) else np.nan,
|
||||||
"density_vehpkm": float(density_sum),
|
"density_vehpkm": float(density_sum),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
@ -321,13 +334,13 @@ def _plot_episode_heatmap(
|
||||||
num_cells = len(ordered_cells)
|
num_cells = len(ordered_cells)
|
||||||
cell_to_row = {cell_id: idx for idx, cell_id in enumerate(ordered_cells)}
|
cell_to_row = {cell_id: idx for idx, cell_id in enumerate(ordered_cells)}
|
||||||
speed_grid = np.full((num_cells, num_steps), np.nan, dtype=np.float32)
|
speed_grid = np.full((num_cells, num_steps), np.nan, dtype=np.float32)
|
||||||
density_grid = np.full((num_cells, num_steps), np.nan, dtype=np.float32)
|
occupancy_grid = np.full((num_cells, num_steps), np.nan, dtype=np.float32)
|
||||||
|
|
||||||
for row in detector_rows:
|
for row in detector_rows:
|
||||||
row_idx = cell_to_row[str(row["cell_id"])]
|
row_idx = cell_to_row[str(row["cell_id"])]
|
||||||
col_idx = step_to_col[int(row["step"])]
|
col_idx = step_to_col[int(row["step"])]
|
||||||
speed_grid[row_idx, col_idx] = _safe_float(row["measured_speed_kmh"])
|
speed_grid[row_idx, col_idx] = _safe_float(row["measured_speed_kmh"])
|
||||||
density_grid[row_idx, col_idx] = _safe_float(row["density_vehpkm"])
|
occupancy_grid[row_idx, col_idx] = _safe_float(row["occupancy"])
|
||||||
|
|
||||||
plots = [
|
plots = [
|
||||||
build_speed_panel(
|
build_speed_panel(
|
||||||
|
|
@ -336,10 +349,10 @@ def _plot_episode_heatmap(
|
||||||
f"{title_prefix} Measured Speed (km/h)",
|
f"{title_prefix} Measured Speed (km/h)",
|
||||||
"Detector Cell",
|
"Detector Cell",
|
||||||
),
|
),
|
||||||
build_density_panel(
|
build_occupancy_panel(
|
||||||
density_grid,
|
occupancy_grid,
|
||||||
ordered_cells,
|
ordered_cells,
|
||||||
f"{title_prefix} Density (veh/km)",
|
f"{title_prefix} Occupancy (%)",
|
||||||
"Detector Cell",
|
"Detector Cell",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue