124 lines
4.1 KiB
Python
124 lines
4.1 KiB
Python
"""Shared reward configuration and calculation for freeway VSL environments."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Dict, Mapping, Sequence
|
|
|
|
import numpy as np
|
|
|
|
|
|
# Keys under which RewardCalculator.calculate stores its per-step reward
# components in ``info``; also the keys used by init_reward_component_totals
# and average_reward_components.
REWARD_COMPONENT_COLUMNS = ("r_efficiency", "r_safety", "r_utility")

# Display names for each reward component, keyed (in the same order) by column.
REWARD_COMPONENT_LABELS = dict(
    zip(REWARD_COMPONENT_COLUMNS, ("R_efficiency", "R_safety", "R_utility"))
)
|
|
|
|
|
|
def clip01(value: float) -> float:
    """Clamp ``value`` into the closed interval [0, 1] and return a Python float.

    Delegates to ``np.clip``, so a NaN input is not clamped and comes back as NaN.
    """
    lower_bound, upper_bound = 0.0, 1.0
    clipped = np.clip(value, lower_bound, upper_bound)
    return float(clipped)
|
|
|
|
|
|
def init_reward_component_totals() -> Dict[str, float]:
    """Return a fresh accumulator mapping every reward component column to 0.0.

    A new dict is created on each call, so callers may mutate it freely.
    """
    return dict.fromkeys(REWARD_COMPONENT_COLUMNS, 0.0)
|
|
|
|
|
|
def average_reward_components(totals: Mapping[str, float], steps: int) -> Dict[str, float]:
    """Divide accumulated reward components by the number of steps.

    Args:
        totals: Running sums keyed by reward component column; columns missing
            from the mapping are treated as 0.0.
        steps: Step count; coerced with ``int`` and floored at 1 so that zero
            or negative counts never cause a division by zero.

    Returns:
        A new dict with one float per entry of ``REWARD_COMPONENT_COLUMNS``.
    """
    divisor = int(steps)
    if divisor < 1:
        divisor = 1

    averaged: Dict[str, float] = {}
    for name in REWARD_COMPONENT_COLUMNS:
        averaged[name] = float(totals.get(name, 0.0)) / divisor
    return averaged
|
|
|
|
|
|
@dataclass(frozen=True)
class RewardConfig:
    """Immutable hyper-parameters for the freeway VSL reward.

    The field defaults below are the single source of truth: ``from_dict``
    falls back to them for any key that is missing (or mapped to ``None``) in
    the raw configuration mapping.
    """

    # Rate of the saturating efficiency curve ``1 - exp(-alpha * v_norm)``;
    # RewardCalculator clamps negative values to 0.
    efficiency_alpha: float = 2.19
    # Decay rate of the safety curve ``exp(-beta * risk)``; negatives clamped to 0.
    safety_beta: float = 9.19
    # Unnormalized preference exponents; RewardCalculator rescales the pair so
    # they sum to 1 (negatives treated as 0).
    efficiency_exponent: float = 0.50
    safety_exponent: float = 0.50
    # Time-to-collision threshold (seconds, per the name). RewardCalculator only
    # echoes it into ``info``; the TTC risk rate itself arrives precomputed.
    ttc_threshold_s: float = 2.3
    # Not read by this module's visible code -- presumably consumed by the
    # environment; confirm. ``from_dict`` clamps it to >= 1.
    bottleneck_window_size: int = 3
    # Speed used to normalize the observed mean speed into [0, 1]
    # (presumably m/s, 33.33 ~= 120 km/h -- confirm).
    v_limit: float = 33.33
    # Not read by this module's visible code (metres, per the name); confirm usage.
    leader_gap_threshold_m: float = 100.0

    @classmethod
    def from_dict(
        cls,
        raw_cfg: Mapping[str, object],
        *,
        speed_actions_ms: Sequence[float],
    ) -> "RewardConfig":
        """Build a config from a raw mapping, falling back to the field defaults.

        Args:
            raw_cfg: Mapping of field names to raw values. Present values are
                coerced with ``float``/``int``; keys that are absent or mapped to
                ``None`` (e.g. an empty YAML entry) use the dataclass default.
            speed_actions_ms: Accepted for call-site compatibility; not used.

        Returns:
            A new ``RewardConfig``. ``bottleneck_window_size`` is clamped to >= 1.

        Raises:
            TypeError, ValueError: if a present value cannot be coerced.
        """
        del speed_actions_ms  # part of the factory signature; unused here

        def _raw(name: str) -> object:
            # Dataclass fields with plain defaults keep that default as a class
            # attribute, so the fallback is read from the field declaration
            # instead of being duplicated here (where it could drift).
            value = raw_cfg.get(name)
            return getattr(cls, name) if value is None else value

        return cls(
            efficiency_alpha=float(_raw("efficiency_alpha")),
            safety_beta=float(_raw("safety_beta")),
            efficiency_exponent=float(_raw("efficiency_exponent")),
            safety_exponent=float(_raw("safety_exponent")),
            ttc_threshold_s=float(_raw("ttc_threshold_s")),
            bottleneck_window_size=max(1, int(_raw("bottleneck_window_size"))),
            v_limit=float(_raw("v_limit")),
            leader_gap_threshold_m=float(_raw("leader_gap_threshold_m")),
        )
|
|
|
|
|
|
class RewardCalculator:
    """Multiplicative MAUT (multi-attribute utility) reward for freeway VSL control.

    Each step combines two attribute scores:

    * efficiency ``1 - exp(-alpha * v_norm)``, where ``v_norm`` is the mean
      speed divided by ``config.v_limit`` and clipped to [0, 1] (0 when no
      vehicles are present);
    * safety ``exp(-beta * risk)``, where ``risk`` is the clipped
      ``ttc_risk_rate`` taken from the step ``info``;

    and returns their weighted geometric product
    ``eff ** lambda_eff * safe ** lambda_safe``, with the two preference
    exponents rescaled to sum to 1.
    """

    def __init__(
        self,
        *,
        config: RewardConfig,
        controlled_edge_start_index: int,
        evaluation_mode: bool = False,
    ):
        """Store the reward configuration and coerce the scalar options.

        Args:
            config: Reward hyper-parameters.
            controlled_edge_start_index: Stored after ``int`` coercion; not
                read by this class's own methods.
            evaluation_mode: Stored after ``bool`` coercion; not read by this
                class's own methods.
        """
        self.config = config
        self.controlled_edge_start_index = int(controlled_edge_start_index)
        self.evaluation_mode = bool(evaluation_mode)

    def _normalized_preference_exponents(self) -> tuple[float, float]:
        """Return the (efficiency, safety) exponents rescaled to sum to 1.

        Negative exponents are treated as 0. If both are (near) zero, the
        weights fall back to an even 0.5 / 0.5 split.
        """
        eff_weight, safe_weight = (
            max(float(raw), 0.0)
            for raw in (self.config.efficiency_exponent, self.config.safety_exponent)
        )
        weight_sum = eff_weight + safe_weight
        if weight_sum <= 1e-8:
            return 0.5, 0.5
        return eff_weight / weight_sum, safe_weight / weight_sum

    def calculate(
        self,
        *,
        info: Dict,
        current_edge_speeds: np.ndarray,
        prev_edge_speeds: np.ndarray,
        episode_index: int,
    ) -> float:
        """Compute the step reward and record its components in ``info``.

        Args:
            info: Step info mapping. Reads ``mean_speed`` (default 0.0),
                ``num_vehicles`` (default 0) and ``ttc_risk_rate`` (default
                0.0). Mutated in place: the three reward components, the
                normalized inputs, the echoed ``ttc_threshold_s`` and the
                normalized preference weights are written back as floats.
            current_edge_speeds: Unused; kept for interface compatibility.
            prev_edge_speeds: Unused; kept for interface compatibility.
            episode_index: Unused; kept for interface compatibility.

        Returns:
            The multiplicative utility (also stored as ``info["r_utility"]``).
        """
        del current_edge_speeds, prev_edge_speeds, episode_index
        cfg = self.config

        # Efficiency attribute: saturating curve of the clipped speed ratio.
        # With no vehicles on the network the ratio is forced to 0.
        speed = max(float(info.get("mean_speed", 0.0)), 0.0)
        vehicle_count = max(int(info.get("num_vehicles", 0)), 0)
        if vehicle_count > 0:
            # Floor the divisor so a zero/negative v_limit cannot divide by zero.
            speed_ratio = clip01(speed / max(cfg.v_limit, 1e-6))
        else:
            speed_ratio = 0.0
        alpha = max(cfg.efficiency_alpha, 0.0)
        eff_score = 1.0 - float(np.exp(-alpha * speed_ratio))

        # Safety attribute: exponential decay in the clipped TTC risk rate.
        risk = clip01(float(info.get("ttc_risk_rate", 0.0)))
        beta = max(cfg.safety_beta, 0.0)
        safety_score = float(np.exp(-beta * risk))

        # Weighted geometric product; flooring each factor at 1e-8 keeps the
        # utility strictly positive even when one attribute scores 0.
        lambda_eff, lambda_safe = self._normalized_preference_exponents()
        utility = float(
            np.power(max(eff_score, 1e-8), lambda_eff)
            * np.power(max(safety_score, 1e-8), lambda_safe)
        )

        diagnostics = (
            ("r_efficiency", eff_score),
            ("r_safety", safety_score),
            ("r_utility", utility),
            ("efficiency_norm", speed_ratio),
            ("safety_risk_norm", risk),
            ("ttc_threshold_s", cfg.ttc_threshold_s),
            ("efficiency_lambda", lambda_eff),
            ("safety_lambda", lambda_safe),
        )
        for key, value in diagnostics:
            info[key] = float(value)
        return utility
|