"""Shared reward calculation for freeway VSL environments.""" from __future__ import annotations from dataclasses import dataclass from typing import Dict, Mapping, Sequence import numpy as np REWARD_COMPONENT_COLUMNS = ( "r_travel_time_improvement", "r_ttc_improvement", "r_improvement", ) REWARD_COMPONENT_LABELS = { "r_travel_time_improvement": "R_travel_time", "r_ttc_improvement": "R_ttc", "r_improvement": "R_total", } def init_reward_component_totals() -> Dict[str, float]: return {column: 0.0 for column in REWARD_COMPONENT_COLUMNS} def average_reward_components(totals: Mapping[str, float], steps: int) -> Dict[str, float]: denom = max(int(steps), 1) return {column: float(totals.get(column, 0.0)) / denom for column in REWARD_COMPONENT_COLUMNS} @dataclass(frozen=True) class RewardConfig: ttc_threshold_s: float = 2.3 bottleneck_window_size: int = 3 leader_gap_threshold_m: float = 100.0 mode: str = "paired_no_control" baseline_dir: str = "" baseline_key: str = "step" baseline_wait_timeout_s: float = 3600.0 baseline_poll_interval_s: float = 1.0 travel_time_weight: float = 0.5 ttc_weight: float = 0.5 travel_time_min_denominator_s: float = 60.0 @classmethod def from_dict( cls, raw_cfg: Mapping[str, object], *, speed_actions_ms: Sequence[float], ) -> "RewardConfig": _ = speed_actions_ms return cls( ttc_threshold_s=float(raw_cfg.get("ttc_threshold_s", 2.3)), bottleneck_window_size=max(1, int(raw_cfg.get("bottleneck_window_size", 3))), leader_gap_threshold_m=float(raw_cfg.get("leader_gap_threshold_m", 100.0)), mode=str(raw_cfg.get("mode", "paired_no_control")).strip().lower(), baseline_dir=str(raw_cfg.get("baseline_dir", "") or ""), baseline_key=str(raw_cfg.get("baseline_key", "step")).strip().lower(), baseline_wait_timeout_s=float(raw_cfg.get("baseline_wait_timeout_s", 3600.0)), baseline_poll_interval_s=float(raw_cfg.get("baseline_poll_interval_s", 1.0)), travel_time_weight=float(raw_cfg.get("travel_time_weight", 0.5)), ttc_weight=float(raw_cfg.get("ttc_weight", 0.5)), travel_time_min_denominator_s=max( float(raw_cfg.get("travel_time_min_denominator_s", 60.0)), 1e-6, ), ) class RewardCalculator: """Counterfactual reward relative to synchronized no-control episodes.""" def __init__( self, *, config: RewardConfig, controlled_edge_start_index: int, evaluation_mode: bool = False, ): self.config = config self.controlled_edge_start_index = int(controlled_edge_start_index) self.evaluation_mode = bool(evaluation_mode) self.baseline_by_step: Dict[int, Mapping[str, float]] = {} def set_step_baseline(self, baseline_by_step: Mapping[int, Mapping[str, float]]) -> None: self.baseline_by_step = dict(baseline_by_step or {}) def calculate( self, *, info: Dict, current_edge_speeds: np.ndarray, prev_edge_speeds: np.ndarray, episode_index: int, ) -> float: _ = current_edge_speeds, prev_edge_speeds, episode_index current_ttc_risk = _clip(float(info.get("ttc_risk_rate", 0.0)), 0.0, 1.0) current_travel_time = float( info.get("mainline_travel_time_cumulative_mean_s", np.nan) ) baseline = self._get_baseline(info) baseline_ttc_risk = _safe_baseline_float(baseline, "ttc_risk_rate") baseline_travel_time = _safe_baseline_float( baseline, "mainline_travel_time_cumulative_mean_s", ) travel_time_improvement = self._travel_time_improvement( current_travel_time=current_travel_time, baseline_travel_time=baseline_travel_time, ) ttc_improvement = self._ttc_improvement( current_ttc_risk=current_ttc_risk, baseline_ttc_risk=baseline_ttc_risk, ) reward = float( self.config.travel_time_weight * travel_time_improvement + self.config.ttc_weight * ttc_improvement ) info["reward_mode"] = self.config.mode info["safety_risk_norm"] = float(current_ttc_risk) info["ttc_threshold_s"] = float(self.config.ttc_threshold_s) info["baseline_mainline_travel_time_cumulative_mean_s"] = float(baseline_travel_time) info["baseline_ttc_risk_rate"] = float(baseline_ttc_risk) info["travel_time_relative_denominator_s"] = float( self._travel_time_denominator(baseline_travel_time) ) info["r_travel_time_improvement"] = float(travel_time_improvement) info["r_ttc_improvement"] = float(ttc_improvement) info["r_improvement"] = float(reward) if self.config.mode in {"paired_no_control", "episode_baseline"}: return reward return 0.0 def _get_baseline(self, info: Mapping[str, object]) -> Mapping[str, float] | None: lookup_key = "sim_time" if self.config.baseline_key == "sim_time" else "step" try: key_value = int(round(float(info.get(lookup_key, 0.0)))) except (TypeError, ValueError): key_value = 0 return self.baseline_by_step.get(key_value) def _travel_time_denominator(self, baseline_travel_time: float) -> float: if not np.isfinite(baseline_travel_time): return float("nan") return max(float(baseline_travel_time), self.config.travel_time_min_denominator_s) def _travel_time_improvement( self, *, current_travel_time: float, baseline_travel_time: float, ) -> float: denominator = self._travel_time_denominator(baseline_travel_time) if not np.isfinite(current_travel_time) or not np.isfinite(denominator): return 0.0 return _clip((baseline_travel_time - current_travel_time) / denominator, -1.0, 1.0) @staticmethod def _ttc_improvement(*, current_ttc_risk: float, baseline_ttc_risk: float) -> float: if not np.isfinite(baseline_ttc_risk): return 0.0 return _clip(baseline_ttc_risk - current_ttc_risk, -1.0, 1.0) def _safe_baseline_float(baseline: Mapping[str, float] | None, key: str) -> float: if baseline is None: return float("nan") try: return float(baseline.get(key, np.nan)) except (TypeError, ValueError): return float("nan") def _clip(value: float, lower: float, upper: float) -> float: return float(np.clip(value, lower, upper))