"""Step-level no-control reward baseline exchange utilities."""

from __future__ import annotations

import csv
import os
import tempfile
import time
from pathlib import Path
from typing import Dict, Iterable, Mapping

# Header (and column order) of every per-episode baseline CSV file.
BASELINE_COLUMNS = [
    "episode",
    "step",
    "seed",
    "sim_time",
    "reward",
    "mean_speed_kmh",
    "num_vehicles",
    "mainline_completed_count",
    "mainline_interval_travel_time_mean_s",
    "mainline_travel_time_cumulative_mean_s",
    "ttc_risk_rate",
]


def resolve_baseline_dir(config: Mapping[str, object], run_timestamp: str | None = None) -> str:
    """Return the absolute directory used for reward baseline files.

    Lookup order:

    1. ``config["runtime"]["reward_baseline_dir"]``
    2. ``config["environment"]["reward"]["baseline_dir"]``
    3. ``runs/<timestamp>/reward_baseline`` relative to the current working
       directory, where ``<timestamp>`` is *run_timestamp*, else
       ``config["runtime"]["run_timestamp"]``, else ``"default"``.

    Args:
        config: Configuration mapping. A non-mapping value, or sections that
            are missing or ``None``, are treated as empty.
        run_timestamp: Optional timestamp that takes precedence over the one
            stored in the runtime section.

    Returns:
        Absolute path of the baseline directory (the directory is not created).
    """
    if isinstance(config, Mapping):
        # ``or {}`` keeps an explicit ``None`` section (e.g. an empty YAML key)
        # from breaking the ``.get`` calls below.
        runtime_cfg = config.get("runtime") or {}
        reward_cfg = (config.get("environment") or {}).get("reward") or {}
    else:
        runtime_cfg = {}
        reward_cfg = {}

    baseline_dir = str(
        runtime_cfg.get("reward_baseline_dir") or reward_cfg.get("baseline_dir") or ""
    ).strip()
    if baseline_dir:
        return os.path.abspath(baseline_dir)

    timestamp = run_timestamp or str(runtime_cfg.get("run_timestamp", "") or "default")
    return os.path.abspath(os.path.join("runs", timestamp, "reward_baseline"))


def episode_baseline_path(baseline_dir: str, episode: int) -> str:
    """Return the absolute CSV path for *episode*: ``episode_NNNN.csv`` in *baseline_dir*."""
    return os.path.join(os.path.abspath(baseline_dir), f"episode_{int(episode):04d}.csv")


def write_episode_baseline(
    *,
    baseline_dir: str,
    episode: int,
    rows: Iterable[Mapping[str, object]],
) -> str:
    """Write all *rows* of one episode to its baseline CSV file.

    The data is first written to a temporary file inside *baseline_dir* and
    then moved over the target with ``os.replace``, so the target path is
    swapped in one rename instead of being rewritten in place.

    Args:
        baseline_dir: Directory for the CSV file; created if it does not exist.
        episode: Episode number used to build the file name.
        rows: Row mappings. Only keys listed in ``BASELINE_COLUMNS`` are
            written; missing keys become empty cells.

    Returns:
        Absolute path of the written CSV file.
    """
    os.makedirs(baseline_dir, exist_ok=True)
    target_path = episode_baseline_path(baseline_dir, episode)
    fd, temp_path = tempfile.mkstemp(
        prefix=f"episode_{int(episode):04d}_",
        suffix=".tmp",
        dir=baseline_dir,
        text=True,
    )
    try:
        # "utf-8-sig" prepends a BOM; read_episode_baseline opens the file
        # with the same codec.
        with os.fdopen(fd, "w", newline="", encoding="utf-8-sig") as f:
            writer = csv.DictWriter(f, fieldnames=BASELINE_COLUMNS)
            writer.writeheader()
            for row in rows:
                writer.writerow({column: row.get(column, "") for column in BASELINE_COLUMNS})
        os.replace(temp_path, target_path)
    finally:
        # After a successful replace the temp file is already gone; it only
        # remains when writing or replacing failed.
        try:
            os.remove(temp_path)
        except FileNotFoundError:
            pass
    return target_path


class EpisodeBaselineWriter:
    """Accumulates the rows of one episode and rewrites its CSV on every append."""

    def __init__(self, *, baseline_dir: str, episode: int):
        """Remember the absolute target directory and episode; start with no rows."""
        self.baseline_dir = os.path.abspath(baseline_dir)
        self.episode = int(episode)
        self.rows: list[Mapping[str, object]] = []

    def append(self, row: Mapping[str, object]) -> str:
        """Store a copy of *row* and rewrite the whole episode file.

        Returns:
            Absolute path of the rewritten CSV file.
        """
        # Copy so later mutation of the caller's mapping cannot change rows
        # that were already recorded.
        self.rows.append(dict(row))
        return write_episode_baseline(
            baseline_dir=self.baseline_dir,
            episode=self.episode,
            rows=self.rows,
        )


def read_episode_baseline(path: str) -> Dict[int, Dict[str, float]]:
    """Read a baseline CSV into a ``{step: metrics}`` dictionary.

    Rows whose ``step`` cell is missing, non-numeric, NaN or infinite are
    skipped. Metric cells that cannot be parsed become ``nan``. When a step
    occurs more than once, the last row wins.

    Args:
        path: Path of the CSV file to read.

    Returns:
        Mapping from integer step to the metric values of that step.
    """
    baseline_by_step: Dict[int, Dict[str, float]] = {}
    with open(path, "r", newline="", encoding="utf-8-sig") as f:
        reader = csv.DictReader(f)
        for row in reader:
            try:
                step = int(float(row.get("step", "")))
            except (TypeError, ValueError, OverflowError):
                # OverflowError: int() of an infinite float ("inf" cell).
                continue
            baseline_by_step[step] = {
                "mean_speed_kmh": _safe_float(row.get("mean_speed_kmh")),
                "mainline_completed_count": _safe_float(row.get("mainline_completed_count")),
                "mainline_interval_travel_time_mean_s": _safe_float(
                    row.get("mainline_interval_travel_time_mean_s")
                ),
                "mainline_travel_time_cumulative_mean_s": _safe_float(
                    row.get("mainline_travel_time_cumulative_mean_s")
                ),
                "ttc_risk_rate": _safe_float(row.get("ttc_risk_rate")),
            }
    return baseline_by_step


def wait_for_episode_baseline(
    *,
    baseline_dir: str,
    episode: int,
    min_step: int,
    timeout_s: float,
    poll_interval_s: float,
) -> Dict[int, Dict[str, float]]:
    """Poll the episode baseline file until it contains a step >= *min_step*.

    Args:
        baseline_dir: Directory that holds the baseline CSV files.
        episode: Episode number to wait for.
        min_step: The highest step in the file must reach at least this value.
        timeout_s: Maximum time to wait in seconds; negative values count as 0.
            The file is always checked at least once.
        poll_interval_s: Pause between checks in seconds; values below 0.05
            are raised to 0.05.

    Returns:
        The parsed baseline as returned by ``read_episode_baseline``.

    Raises:
        TimeoutError: If the requested step is not available before the
            deadline.
    """
    path = episode_baseline_path(baseline_dir, episode)
    required_step = int(min_step)
    interval = max(float(poll_interval_s), 0.05)
    deadline = time.monotonic() + max(float(timeout_s), 0.0)
    while True:
        if os.path.isfile(path):
            try:
                baseline = read_episode_baseline(path)
            except FileNotFoundError:
                # The file vanished between the existence check and the open;
                # treat it as "not ready yet" and keep polling.
                baseline = {}
            if baseline and max(baseline) >= required_step:
                return baseline
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise TimeoutError(
                f"Timed out waiting for reward baseline episode={episode} step>={min_step}: {path}"
            )
        # Never sleep past the deadline, so a long poll interval cannot
        # stretch the effective timeout.
        time.sleep(min(interval, remaining))


def _safe_float(value: object) -> float:
    """Convert *value* to ``float``; return ``nan`` when it cannot be parsed."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return float("nan")