184 lines
6.6 KiB
Python
184 lines
6.6 KiB
Python
"""Shared reward calculation for freeway VSL environments."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Dict, Mapping, Sequence
|
|
|
|
import numpy as np
|
|
|
|
|
|
# Names of the per-step reward components, in reporting order. Each name is
# also the key under which RewardCalculator.calculate stores the value in `info`.
REWARD_COMPONENT_COLUMNS = (
    "r_travel_time_improvement",  # travel-time term
    "r_ttc_improvement",  # time-to-collision risk term
    "r_improvement",  # weighted sum of the two terms above
)
|
|
|
|
# Short display label for each reward component column.
REWARD_COMPONENT_LABELS = {
    "r_travel_time_improvement": "R_travel_time",
    "r_ttc_improvement": "R_ttc",
    "r_improvement": "R_total",
}
|
|
|
|
|
|
def init_reward_component_totals() -> Dict[str, float]:
    """Return a new accumulator with every reward component column set to 0.0.

    A fresh dict is built on each call, so callers may mutate the result freely.
    """
    return dict.fromkeys(REWARD_COMPONENT_COLUMNS, 0.0)
|
|
|
|
|
|
def average_reward_components(totals: Mapping[str, float], steps: int) -> Dict[str, float]:
    """Divide accumulated reward components by the number of steps.

    Args:
        totals: Accumulated value per component column; columns missing from
            the mapping are treated as 0.0.
        steps: Number of steps the totals cover. Values below 1 are replaced
            by 1 so the division is always defined.

    Returns:
        A dict with one entry per column in ``REWARD_COMPONENT_COLUMNS``.
    """
    divisor = int(steps)
    if divisor < 1:
        divisor = 1

    averages: Dict[str, float] = {}
    for name in REWARD_COMPONENT_COLUMNS:
        averages[name] = float(totals.get(name, 0.0)) / divisor
    return averages
|
|
|
|
|
|
@dataclass(frozen=True)
class RewardConfig:
    """Immutable settings for the baseline-relative reward.

    The defaults declared here are the same fallback values that
    :meth:`from_dict` applies when a key is absent from the raw mapping.
    """

    # Copied into ``info["ttc_threshold_s"]`` by RewardCalculator.calculate.
    ttc_threshold_s: float = 2.3
    # Not read by the code in this module; presumably consumed elsewhere.
    bottleneck_window_size: int = 3
    # Not read by the code in this module; presumably consumed elsewhere.
    leader_gap_threshold_m: float = 100.0
    # RewardCalculator.calculate returns the computed reward only for
    # "paired_no_control" and "episode_baseline"; any other mode yields 0.0.
    mode: str = "paired_no_control"
    # Not read by the code in this module; presumably the baseline location.
    baseline_dir: str = ""
    # "sim_time" selects ``info["sim_time"]`` as the baseline lookup key;
    # every other value selects ``info["step"]``.
    baseline_key: str = "step"
    # Not read by the code in this module; presumably used while loading baselines.
    baseline_wait_timeout_s: float = 3600.0
    baseline_poll_interval_s: float = 1.0
    # Weights of the two improvement terms in the combined reward.
    travel_time_weight: float = 0.5
    ttc_weight: float = 0.5
    # Lower bound of the denominator used for the relative travel-time term.
    travel_time_min_denominator_s: float = 60.0

    @classmethod
    def from_dict(
        cls,
        raw_cfg: Mapping[str, object],
        *,
        speed_actions_ms: Sequence[float],
    ) -> "RewardConfig":
        """Build a config from a raw mapping, coercing types and filling defaults.

        Args:
            raw_cfg: Mapping of option name to raw value. Missing keys fall
                back to the field defaults.
            speed_actions_ms: Accepted for interface compatibility; unused.

        Returns:
            A ``RewardConfig`` where ``mode`` and ``baseline_key`` are stripped
            and lower-cased, ``bottleneck_window_size`` is at least 1, a falsy
            ``baseline_dir`` becomes ``""`` and
            ``travel_time_min_denominator_s`` is at least ``1e-6``.

        Raises:
            TypeError, ValueError: If a numeric option cannot be converted by
                ``float()`` / ``int()``.
        """
        del speed_actions_ms  # intentionally unused

        option = raw_cfg.get

        # Values are coerced in field order so the first invalid option is
        # the one that raises.
        ttc_threshold = float(option("ttc_threshold_s", 2.3))
        window_size = max(1, int(option("bottleneck_window_size", 3)))
        leader_gap = float(option("leader_gap_threshold_m", 100.0))
        mode_name = str(option("mode", "paired_no_control")).strip().lower()
        baseline_directory = str(option("baseline_dir", "") or "")
        key_name = str(option("baseline_key", "step")).strip().lower()
        wait_timeout = float(option("baseline_wait_timeout_s", 3600.0))
        poll_interval = float(option("baseline_poll_interval_s", 1.0))
        weight_travel_time = float(option("travel_time_weight", 0.5))
        weight_ttc = float(option("ttc_weight", 0.5))
        min_denominator = max(
            float(option("travel_time_min_denominator_s", 60.0)),
            1e-6,
        )

        return cls(
            ttc_threshold_s=ttc_threshold,
            bottleneck_window_size=window_size,
            leader_gap_threshold_m=leader_gap,
            mode=mode_name,
            baseline_dir=baseline_directory,
            baseline_key=key_name,
            baseline_wait_timeout_s=wait_timeout,
            baseline_poll_interval_s=poll_interval,
            travel_time_weight=weight_travel_time,
            ttc_weight=weight_ttc,
            travel_time_min_denominator_s=min_denominator,
        )
|
|
|
|
|
|
class RewardCalculator:
    """Counterfactual reward relative to synchronized no-control episodes.

    The reward is a weighted sum of two improvement terms, each measured
    against a baseline record looked up by step (or simulation time):
    relative travel-time improvement and time-to-collision risk improvement.
    A term whose inputs are missing or non-finite contributes 0.0, so the
    reward itself stays finite.
    """

    def __init__(
        self,
        *,
        config: RewardConfig,
        controlled_edge_start_index: int,
        evaluation_mode: bool = False,
    ):
        """Store the configuration; baselines are supplied via ``set_step_baseline``.

        Args:
            config: Reward settings (weights, mode, baseline key, ...).
            controlled_edge_start_index: Stored as an int; not read by the
                methods of this class.
            evaluation_mode: Stored as a bool; not read by the methods of
                this class.
        """
        self.config = config
        self.controlled_edge_start_index = int(controlled_edge_start_index)
        self.evaluation_mode = bool(evaluation_mode)
        self.baseline_by_step: Dict[int, Mapping[str, float]] = {}

    def set_step_baseline(self, baseline_by_step: Mapping[int, Mapping[str, float]]) -> None:
        """Replace the baseline table with a shallow copy of ``baseline_by_step``.

        ``None`` or an empty mapping clears the table.
        """
        self.baseline_by_step = dict(baseline_by_step or {})

    def calculate(
        self,
        *,
        info: Dict,
        current_edge_speeds: np.ndarray,
        prev_edge_speeds: np.ndarray,
        episode_index: int,
    ) -> float:
        """Compute the reward for one step and record its components in ``info``.

        Args:
            info: Step metrics. Reads ``"ttc_risk_rate"`` (default 0.0, clipped
                to [0, 1]), ``"mainline_travel_time_cumulative_mean_s"``
                (default NaN) and the baseline lookup key. Mutated in place:
                reward mode, baseline values, the travel-time denominator and
                the three ``r_*`` components are written back.
            current_edge_speeds: Unused; kept for interface compatibility.
            prev_edge_speeds: Unused; kept for interface compatibility.
            episode_index: Unused; kept for interface compatibility.

        Returns:
            The weighted reward when ``config.mode`` is ``"paired_no_control"``
            or ``"episode_baseline"``; otherwise 0.0 (``info`` is still filled).
        """
        # Parameters retained for the caller-facing signature only.
        _ = current_edge_speeds, prev_edge_speeds, episode_index

        current_ttc_risk = _clip(float(info.get("ttc_risk_rate", 0.0)), 0.0, 1.0)
        current_travel_time = float(
            info.get("mainline_travel_time_cumulative_mean_s", np.nan)
        )

        # Missing baselines surface as NaN and zero out the affected term.
        baseline = self._get_baseline(info)
        baseline_ttc_risk = _safe_baseline_float(baseline, "ttc_risk_rate")
        baseline_travel_time = _safe_baseline_float(
            baseline,
            "mainline_travel_time_cumulative_mean_s",
        )

        travel_time_improvement = self._travel_time_improvement(
            current_travel_time=current_travel_time,
            baseline_travel_time=baseline_travel_time,
        )
        ttc_improvement = self._ttc_improvement(
            current_ttc_risk=current_ttc_risk,
            baseline_ttc_risk=baseline_ttc_risk,
        )
        reward = float(
            self.config.travel_time_weight * travel_time_improvement
            + self.config.ttc_weight * ttc_improvement
        )

        info["reward_mode"] = self.config.mode
        info["safety_risk_norm"] = float(current_ttc_risk)
        info["ttc_threshold_s"] = float(self.config.ttc_threshold_s)
        info["baseline_mainline_travel_time_cumulative_mean_s"] = float(baseline_travel_time)
        info["baseline_ttc_risk_rate"] = float(baseline_ttc_risk)
        info["travel_time_relative_denominator_s"] = float(
            self._travel_time_denominator(baseline_travel_time)
        )
        info["r_travel_time_improvement"] = float(travel_time_improvement)
        info["r_ttc_improvement"] = float(ttc_improvement)
        info["r_improvement"] = float(reward)

        if self.config.mode in {"paired_no_control", "episode_baseline"}:
            return reward
        return 0.0

    def _get_baseline(self, info: Mapping[str, object]) -> Mapping[str, float] | None:
        """Return the baseline record matching this step, or ``None`` if absent.

        The key is ``info["sim_time"]`` when ``config.baseline_key`` is
        ``"sim_time"``, otherwise ``info["step"]``, rounded to the nearest
        int. A missing or unparseable value falls back to key 0.
        """
        lookup_key = "sim_time" if self.config.baseline_key == "sim_time" else "step"
        try:
            key_value = int(round(float(info.get(lookup_key, 0.0))))
        except (TypeError, ValueError, OverflowError):
            # round() raises ValueError for NaN and OverflowError for +/-inf;
            # both are handled like any other unparseable key.
            key_value = 0
        return self.baseline_by_step.get(key_value)

    def _travel_time_denominator(self, baseline_travel_time: float) -> float:
        """Return the denominator for the relative travel-time term.

        This is the baseline travel time floored at
        ``config.travel_time_min_denominator_s``, or NaN when the baseline is
        not finite.
        """
        if not np.isfinite(baseline_travel_time):
            return float("nan")
        return max(float(baseline_travel_time), self.config.travel_time_min_denominator_s)

    def _travel_time_improvement(
        self,
        *,
        current_travel_time: float,
        baseline_travel_time: float,
    ) -> float:
        """Return ``(baseline - current) / denominator`` clipped to [-1, 1].

        Returns 0.0 when the current travel time or the baseline is not finite.
        """
        denominator = self._travel_time_denominator(baseline_travel_time)
        if not np.isfinite(current_travel_time) or not np.isfinite(denominator):
            return 0.0
        return _clip((baseline_travel_time - current_travel_time) / denominator, -1.0, 1.0)

    @staticmethod
    def _ttc_improvement(*, current_ttc_risk: float, baseline_ttc_risk: float) -> float:
        """Return ``baseline - current`` TTC risk clipped to [-1, 1].

        Returns 0.0 when either value is not finite, so a NaN risk rate
        cannot turn the reward into NaN.
        """
        if not np.isfinite(baseline_ttc_risk) or not np.isfinite(current_ttc_risk):
            return 0.0
        return _clip(baseline_ttc_risk - current_ttc_risk, -1.0, 1.0)
|
|
|
|
|
|
def _safe_baseline_float(baseline: Mapping[str, float] | None, key: str) -> float:
|
|
if baseline is None:
|
|
return float("nan")
|
|
try:
|
|
return float(baseline.get(key, np.nan))
|
|
except (TypeError, ValueError):
|
|
return float("nan")
|
|
|
|
|
|
def _clip(value: float, lower: float, upper: float) -> float:
|
|
return float(np.clip(value, lower, upper))
|