# ctm-dqn/utils/reward_baseline.py
#
# 139 lines
# 4.5 KiB
# Python

"""Step-level no-control reward baseline exchange utilities."""
from __future__ import annotations
import csv
import os
import tempfile
import time
from pathlib import Path
from typing import Dict, Iterable, Mapping
# Header and column order of the per-episode baseline CSV files; the writer
# in this module emits exactly these fields, in this order.
BASELINE_COLUMNS = [
    "episode", "step", "seed", "sim_time",
    "reward", "mean_speed_kmh", "num_vehicles",
    "mainline_completed_count",
    "mainline_interval_travel_time_mean_s",
    "mainline_travel_time_cumulative_mean_s",
    "ttc_risk_rate",
]
def resolve_baseline_dir(config: Mapping[str, object], run_timestamp: str | None = None) -> str:
    """Return the absolute directory used for reward-baseline files.

    Candidates are tried in this order:

    1. ``config["runtime"]["reward_baseline_dir"]``
    2. ``config["environment"]["reward"]["baseline_dir"]``
    3. ``runs/<timestamp>/reward_baseline`` relative to the current working
       directory, where ``<timestamp>`` is *run_timestamp* if truthy, else
       ``config["runtime"]["run_timestamp"]``, else ``"default"``.

    Args:
        config: Configuration mapping. A non-mapping value is treated as an
            empty configuration.
        run_timestamp: Optional timestamp that overrides the one stored in
            the ``runtime`` section when building the default directory.

    Returns:
        The absolute path of the selected baseline directory.
    """
    if isinstance(config, Mapping):
        # Sections that are absent *or* explicitly null (e.g. ``runtime:`` left
        # empty in a YAML file) are all normalised to an empty dict, so the
        # lookups below never call ``.get`` on ``None``.
        runtime_cfg = config.get("runtime") or {}
        reward_cfg = (config.get("environment") or {}).get("reward") or {}
    else:
        runtime_cfg = {}
        reward_cfg = {}

    baseline_dir = str(
        runtime_cfg.get("reward_baseline_dir")
        or reward_cfg.get("baseline_dir")
        or ""
    ).strip()
    if baseline_dir:
        return os.path.abspath(baseline_dir)

    timestamp = run_timestamp or str(runtime_cfg.get("run_timestamp", "") or "default")
    return os.path.abspath(os.path.join("runs", timestamp, "reward_baseline"))
def episode_baseline_path(baseline_dir: str, episode: int) -> str:
    """Return the absolute CSV path for *episode* inside *baseline_dir*.

    The file is named ``episode_NNNN.csv``, with the episode number coerced
    to ``int`` and zero-padded to at least four characters.
    """
    root = os.path.abspath(baseline_dir)
    file_name = f"episode_{int(episode):04d}.csv"
    return os.path.join(root, file_name)
def write_episode_baseline(
    *,
    baseline_dir: str,
    episode: int,
    rows: Iterable[Mapping[str, object]],
) -> str:
    """Write the baseline CSV for one episode, replacing any previous version.

    The rows are first written to a temporary file created inside
    *baseline_dir* and then moved onto the final path with ``os.replace``, so
    the final path is only ever updated by a whole-file rename.

    Args:
        baseline_dir: Directory that holds the baseline files; created if it
            does not exist.
        episode: Episode number used to build the file name.
        rows: Row mappings. Only the keys listed in ``BASELINE_COLUMNS`` are
            written; keys missing from a row are written as empty cells.

    Returns:
        The absolute path of the CSV file that was written.
    """
    os.makedirs(baseline_dir, exist_ok=True)
    target_path = episode_baseline_path(baseline_dir, episode)
    # Deliberately no ``text=True``: the descriptor must be binary so that the
    # only newline handling is the ``newline=""`` text layer opened below.
    # On Windows a text-mode descriptor would expand the csv writer's "\r\n"
    # terminators a second time; on POSIX the flag makes no difference.
    fd, temp_path = tempfile.mkstemp(
        prefix=f"episode_{int(episode):04d}_",
        suffix=".tmp",
        dir=baseline_dir,
    )
    try:
        # "utf-8-sig" prepends a BOM; the reader in this module opens the
        # file with the same codec.
        with os.fdopen(fd, "w", newline="", encoding="utf-8-sig") as f:
            writer = csv.DictWriter(f, fieldnames=BASELINE_COLUMNS)
            writer.writeheader()
            for row in rows:
                writer.writerow({column: row.get(column, "") for column in BASELINE_COLUMNS})
        os.replace(temp_path, target_path)
    finally:
        # After a successful replace the temp file is already gone; it only
        # lingers when writing or replacing failed. Removing unconditionally
        # and tolerating "not found" avoids an exists()/remove() race.
        try:
            os.remove(temp_path)
        except FileNotFoundError:
            pass
    return target_path
class EpisodeBaselineWriter:
    """Collect the rows of one episode and rewrite its CSV on every append.

    Each call to :meth:`append` rewrites the complete episode file from all
    rows collected so far, rather than appending a line to it.
    """

    def __init__(self, *, baseline_dir: str, episode: int):
        """Remember the absolute target directory and the episode number."""
        self.baseline_dir = os.path.abspath(baseline_dir)
        self.episode = int(episode)
        # Every row appended so far, in insertion order.
        self.rows: list[Mapping[str, object]] = []

    def append(self, row: Mapping[str, object]) -> str:
        """Store a copy of *row*, rewrite the episode file, return its path."""
        # Copy so that later mutation of the caller's mapping cannot change
        # rows that were already recorded.
        snapshot = dict(row)
        self.rows.append(snapshot)
        return write_episode_baseline(
            baseline_dir=self.baseline_dir,
            episode=self.episode,
            rows=self.rows,
        )
def read_episode_baseline(path: str) -> Dict[int, Dict[str, float]]:
    """Load one episode baseline CSV into a ``{step: metrics}`` dictionary.

    Rows whose ``step`` cell cannot be turned into an integer are skipped.
    When several rows share a step, the last one wins. Metric cells that are
    missing or not numeric are returned as NaN.

    Args:
        path: Path of the CSV file to read.

    Returns:
        A mapping from step number to a dict with the keys
        ``mean_speed_kmh``, ``mainline_completed_count``,
        ``mainline_interval_travel_time_mean_s``,
        ``mainline_travel_time_cumulative_mean_s`` and ``ttc_risk_rate``.
    """
    metric_columns = (
        "mean_speed_kmh",
        "mainline_completed_count",
        "mainline_interval_travel_time_mean_s",
        "mainline_travel_time_cumulative_mean_s",
        "ttc_risk_rate",
    )
    baseline_by_step: Dict[int, Dict[str, float]] = {}
    with open(path, "r", newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            try:
                # Going through float() accepts cells such as "12.0".
                step = int(float(row.get("step", "")))
            except (TypeError, ValueError, OverflowError):
                # TypeError: cell absent (None) on a short row.
                # ValueError: blank, non-numeric or NaN cell.
                # OverflowError: "inf" / "-inf", which float() accepts but
                # int() cannot represent.
                continue
            baseline_by_step[step] = {
                column: _safe_float(row.get(column)) for column in metric_columns
            }
    return baseline_by_step
def wait_for_episode_baseline(
    *,
    baseline_dir: str,
    episode: int,
    min_step: int,
    timeout_s: float,
    poll_interval_s: float,
) -> Dict[int, Dict[str, float]]:
    """Block until the episode baseline file contains at least *min_step*.

    The episode file is polled until it exists and its highest step number is
    greater than or equal to *min_step*.

    Args:
        baseline_dir: Directory that holds the baseline files.
        episode: Episode number whose file is awaited.
        min_step: Lowest acceptable value for the highest step in the file.
        timeout_s: Maximum time to wait, in seconds; negative values count
            as zero. The file is always checked at least once.
        poll_interval_s: Pause between checks, in seconds; values below
            0.05 are raised to 0.05.

    Returns:
        The parsed baseline, as returned by ``read_episode_baseline``.

    Raises:
        TimeoutError: If the required step has not appeared in time.
    """
    path = episode_baseline_path(baseline_dir, episode)
    required_step = int(min_step)
    interval = max(float(poll_interval_s), 0.05)
    deadline = time.monotonic() + max(float(timeout_s), 0.0)
    while True:
        if os.path.isfile(path):
            baseline = read_episode_baseline(path)
            # The keys are step numbers, so max() is the latest step on disk.
            if baseline and max(baseline) >= required_step:
                return baseline
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise TimeoutError(
                f"Timed out waiting for reward baseline episode={episode} step>={min_step}: {path}"
            )
        # Never sleep past the deadline: a poll interval longer than the
        # timeout would otherwise delay the TimeoutError by a full interval.
        # The loop still performs one last check once the deadline is hit.
        time.sleep(min(interval, remaining))
def _safe_float(value: object) -> float:
    """Convert *value* to ``float``; return NaN if ``float()`` rejects it.

    Both ``TypeError`` (e.g. ``None``) and ``ValueError`` (e.g. an empty or
    non-numeric string) are mapped to NaN.
    """
    try:
        parsed = float(value)
    except (TypeError, ValueError):
        parsed = float("nan")
    return parsed