# ctm-dqn/envs/reward_system.py
#
# 184 lines
# 6.6 KiB
# Python

"""Shared reward calculation for freeway VSL environments."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Mapping, Sequence
import numpy as np
# Names of the per-step reward terms, in reporting order. These are the keys
# that the totals/averages helpers in this module iterate over.
REWARD_COMPONENT_COLUMNS = (
    "r_travel_time_improvement",
    "r_ttc_improvement",
    "r_improvement",
)

# Short display label for each reward component column.
REWARD_COMPONENT_LABELS = {
    "r_travel_time_improvement": "R_travel_time",
    "r_ttc_improvement": "R_ttc",
    "r_improvement": "R_total",
}
def init_reward_component_totals() -> Dict[str, float]:
    """Return a fresh accumulator with every reward component set to 0.0.

    The keys are the entries of ``REWARD_COMPONENT_COLUMNS``, in that order.
    A new dict is created on every call.
    """
    return dict.fromkeys(REWARD_COMPONENT_COLUMNS, 0.0)
def average_reward_components(totals: Mapping[str, float], steps: int) -> Dict[str, float]:
    """Convert accumulated reward components into per-step averages.

    Args:
        totals: Accumulated value per component. Components missing from the
            mapping are treated as 0.0.
        steps: Number of accumulated steps. Values below 1 are treated as 1
            so the division is always defined.

    Returns:
        A dict with one entry per name in ``REWARD_COMPONENT_COLUMNS``.
    """
    step_count = max(int(steps), 1)
    averages: Dict[str, float] = {}
    for name in REWARD_COMPONENT_COLUMNS:
        averages[name] = float(totals.get(name, 0.0)) / step_count
    return averages
@dataclass(frozen=True)
class RewardConfig:
    """Immutable settings for the reward calculation.

    Build it directly, or from a raw configuration mapping with
    ``RewardConfig.from_dict``.
    """

    # Echoed into the step ``info`` dict by the reward calculator.
    ttc_threshold_s: float = 2.3
    # The next two fields are not read in this module; presumably consumed by
    # the environment that produces the TTC metrics -- verify against callers.
    bottleneck_window_size: int = 3
    leader_gap_threshold_m: float = 100.0
    # Reward mode name; the calculator returns a non-zero reward only for
    # "paired_no_control" and "episode_baseline".
    mode: str = "paired_no_control"
    # Baseline location and polling settings. Not read in this module;
    # presumably used by whatever loads the baseline -- verify against callers.
    baseline_dir: str = ""
    # Which ``info`` field indexes the baseline: "sim_time", otherwise "step".
    baseline_key: str = "step"
    baseline_wait_timeout_s: float = 3600.0
    baseline_poll_interval_s: float = 1.0
    # Weights of the two improvement terms in the total reward.
    travel_time_weight: float = 0.5
    ttc_weight: float = 0.5
    # Lower bound on the denominator of the relative travel-time improvement.
    travel_time_min_denominator_s: float = 60.0

    @classmethod
    def from_dict(
        cls,
        raw_cfg: Mapping[str, object],
        *,
        speed_actions_ms: Sequence[float],
    ) -> "RewardConfig":
        """Build a config from a raw mapping, filling in defaults.

        Missing keys fall back to the same defaults as the dataclass fields.
        Values are coerced to their field types; ``mode`` and
        ``baseline_key`` are stripped and lower-cased; a falsy
        ``baseline_dir`` becomes ``""``; ``bottleneck_window_size`` is raised
        to at least 1 and ``travel_time_min_denominator_s`` to at least 1e-6.

        Args:
            raw_cfg: Mapping holding any subset of the field names.
            speed_actions_ms: Accepted for interface compatibility; unused.

        Returns:
            The populated ``RewardConfig``.
        """
        del speed_actions_ms  # intentionally unused

        lookup = raw_cfg.get
        # Entries are evaluated top to bottom, in field declaration order.
        values = {
            "ttc_threshold_s": float(lookup("ttc_threshold_s", 2.3)),
            "bottleneck_window_size": max(1, int(lookup("bottleneck_window_size", 3))),
            "leader_gap_threshold_m": float(lookup("leader_gap_threshold_m", 100.0)),
            "mode": str(lookup("mode", "paired_no_control")).strip().lower(),
            "baseline_dir": str(lookup("baseline_dir", "") or ""),
            "baseline_key": str(lookup("baseline_key", "step")).strip().lower(),
            "baseline_wait_timeout_s": float(lookup("baseline_wait_timeout_s", 3600.0)),
            "baseline_poll_interval_s": float(lookup("baseline_poll_interval_s", 1.0)),
            "travel_time_weight": float(lookup("travel_time_weight", 0.5)),
            "ttc_weight": float(lookup("ttc_weight", 0.5)),
            "travel_time_min_denominator_s": max(
                float(lookup("travel_time_min_denominator_s", 60.0)),
                1e-6,
            ),
        }
        return cls(**values)
class RewardCalculator:
    """Counterfactual reward relative to synchronized no-control episodes.

    Each step's reward is a weighted sum of two improvement terms, both
    measured against a per-step baseline record installed with
    ``set_step_baseline``:

    * travel-time improvement: relative reduction of
      ``mainline_travel_time_cumulative_mean_s``, clipped to [-1, 1];
    * TTC improvement: reduction of ``ttc_risk_rate``, clipped to [-1, 1].

    A term whose inputs are missing or non-finite contributes 0.0, so a
    non-finite measurement cannot turn the reward into NaN.
    """

    def __init__(
        self,
        *,
        config: RewardConfig,
        controlled_edge_start_index: int,
        evaluation_mode: bool = False,
    ):
        """Store the configuration and start with an empty baseline.

        Args:
            config: Reward settings (weights, mode, baseline key, ...).
            controlled_edge_start_index: Stored as an int; not read by any
                method of this class.
            evaluation_mode: Stored as a bool; not read by any method of
                this class.
        """
        self.config = config
        self.controlled_edge_start_index = int(controlled_edge_start_index)
        self.evaluation_mode = bool(evaluation_mode)
        self.baseline_by_step: Dict[int, Mapping[str, float]] = {}

    def set_step_baseline(self, baseline_by_step: Mapping[int, Mapping[str, float]]) -> None:
        """Replace the baseline table with a shallow copy of the given mapping.

        A falsy argument (``None`` or an empty mapping) clears the table.
        """
        self.baseline_by_step = dict(baseline_by_step or {})

    def calculate(
        self,
        *,
        info: Dict,
        current_edge_speeds: np.ndarray,
        prev_edge_speeds: np.ndarray,
        episode_index: int,
    ) -> float:
        """Compute the step reward and record its components in ``info``.

        Args:
            info: Step metrics. Reads ``ttc_risk_rate`` (default 0.0),
                ``mainline_travel_time_cumulative_mean_s`` (default NaN) and
                the baseline lookup key (``step`` or ``sim_time``). Mutated
                in place: ``reward_mode``, ``safety_risk_norm``,
                ``ttc_threshold_s``, the two ``baseline_*`` values,
                ``travel_time_relative_denominator_s`` and the three
                ``r_*`` reward components are written to it.
            current_edge_speeds: Accepted for interface compatibility; unused.
            prev_edge_speeds: Accepted for interface compatibility; unused.
            episode_index: Accepted for interface compatibility; unused.

        Returns:
            The weighted reward when ``config.mode`` is ``"paired_no_control"``
            or ``"episode_baseline"``; 0.0 for any other mode (``info`` still
            receives the computed components).
        """
        _ = current_edge_speeds, prev_edge_speeds, episode_index

        current_ttc_risk = _clip(float(info.get("ttc_risk_rate", 0.0)), 0.0, 1.0)
        current_travel_time = float(
            info.get("mainline_travel_time_cumulative_mean_s", np.nan)
        )

        # Baseline values come back as NaN when no record exists for this step.
        baseline = self._get_baseline(info)
        baseline_ttc_risk = _safe_baseline_float(baseline, "ttc_risk_rate")
        baseline_travel_time = _safe_baseline_float(
            baseline,
            "mainline_travel_time_cumulative_mean_s",
        )

        travel_time_improvement = self._travel_time_improvement(
            current_travel_time=current_travel_time,
            baseline_travel_time=baseline_travel_time,
        )
        ttc_improvement = self._ttc_improvement(
            current_ttc_risk=current_ttc_risk,
            baseline_ttc_risk=baseline_ttc_risk,
        )
        reward = float(
            self.config.travel_time_weight * travel_time_improvement
            + self.config.ttc_weight * ttc_improvement
        )

        info["reward_mode"] = self.config.mode
        info["safety_risk_norm"] = float(current_ttc_risk)
        info["ttc_threshold_s"] = float(self.config.ttc_threshold_s)
        info["baseline_mainline_travel_time_cumulative_mean_s"] = float(baseline_travel_time)
        info["baseline_ttc_risk_rate"] = float(baseline_ttc_risk)
        info["travel_time_relative_denominator_s"] = float(
            self._travel_time_denominator(baseline_travel_time)
        )
        info["r_travel_time_improvement"] = float(travel_time_improvement)
        info["r_ttc_improvement"] = float(ttc_improvement)
        info["r_improvement"] = float(reward)

        if self.config.mode in {"paired_no_control", "episode_baseline"}:
            return reward
        return 0.0

    def _get_baseline(self, info: Mapping[str, object]) -> Mapping[str, float] | None:
        """Return the baseline record matching this step, or ``None``.

        The lookup key is ``info["sim_time"]`` when ``config.baseline_key`` is
        ``"sim_time"`` and ``info["step"]`` otherwise, rounded to the nearest
        integer. A missing or unusable value falls back to key 0.
        """
        lookup_key = "sim_time" if self.config.baseline_key == "sim_time" else "step"
        try:
            key_value = int(round(float(info.get(lookup_key, 0.0))))
        except (TypeError, ValueError, OverflowError):
            # round() raises ValueError for NaN and OverflowError for +/-inf;
            # both get the same fallback as a non-numeric value.
            key_value = 0
        return self.baseline_by_step.get(key_value)

    def _travel_time_denominator(self, baseline_travel_time: float) -> float:
        """Return the denominator of the relative travel-time improvement.

        This is the baseline travel time, floored at
        ``config.travel_time_min_denominator_s``; NaN when the baseline is
        not finite.
        """
        if not np.isfinite(baseline_travel_time):
            return float("nan")
        return max(float(baseline_travel_time), self.config.travel_time_min_denominator_s)

    def _travel_time_improvement(
        self,
        *,
        current_travel_time: float,
        baseline_travel_time: float,
    ) -> float:
        """Return the relative travel-time reduction, clipped to [-1, 1].

        Positive when the current travel time is below the baseline. Returns
        0.0 when either the current value or the baseline is not finite.
        """
        denominator = self._travel_time_denominator(baseline_travel_time)
        if not np.isfinite(current_travel_time) or not np.isfinite(denominator):
            return 0.0
        return _clip((baseline_travel_time - current_travel_time) / denominator, -1.0, 1.0)

    @staticmethod
    def _ttc_improvement(*, current_ttc_risk: float, baseline_ttc_risk: float) -> float:
        """Return the TTC-risk reduction versus baseline, clipped to [-1, 1].

        Positive when the current risk is below the baseline. Returns 0.0
        when either input is not finite, matching the guard used for the
        travel-time term, so a NaN risk reading cannot make the reward NaN.
        """
        if not (np.isfinite(current_ttc_risk) and np.isfinite(baseline_ttc_risk)):
            return 0.0
        return _clip(baseline_ttc_risk - current_ttc_risk, -1.0, 1.0)
def _safe_baseline_float(baseline: Mapping[str, float] | None, key: str) -> float:
if baseline is None:
return float("nan")
try:
return float(baseline.get(key, np.nan))
except (TypeError, ValueError):
return float("nan")
def _clip(value: float, lower: float, upper: float) -> float:
return float(np.clip(value, lower, upper))