139 lines
4.5 KiB
Python
139 lines
4.5 KiB
Python
"""Step-level no-control reward baseline exchange utilities."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import csv
|
|
import os
|
|
import tempfile
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Dict, Iterable, Mapping
|
|
|
|
|
|
# Column order of every per-episode baseline CSV. It is also the field list
# handed to csv.DictWriter when writing, so it fixes both the header and the
# order of values in each row.
BASELINE_COLUMNS = [
    "episode",
    "step",
    "seed",
    "sim_time",
    "reward",
    "mean_speed_kmh",
    "num_vehicles",
    "mainline_completed_count",
    "mainline_interval_travel_time_mean_s",
    "mainline_travel_time_cumulative_mean_s",
    "ttc_risk_rate",
]
|
|
|
|
|
|
def _config_section(parent: object, key: str) -> Mapping[str, object]:
    """Return ``parent[key]`` when both are mappings, otherwise an empty dict.

    Config files often contain empty sections (a YAML ``runtime:`` key with no
    body loads as ``None``). Treating any non-mapping value as "no settings"
    keeps the ``.get`` lookups in ``resolve_baseline_dir`` from raising
    ``AttributeError``.
    """
    if not isinstance(parent, Mapping):
        return {}
    section = parent.get(key)
    return section if isinstance(section, Mapping) else {}


def resolve_baseline_dir(config: Mapping[str, object], run_timestamp: str | None = None) -> str:
    """Return the absolute directory used to exchange reward-baseline CSVs.

    Resolution order (first non-empty value wins):

    1. ``config["runtime"]["reward_baseline_dir"]``
    2. ``config["environment"]["reward"]["baseline_dir"]``
    3. ``runs/<timestamp>/reward_baseline`` relative to the current working
       directory. ``<timestamp>`` is *run_timestamp*, else
       ``config["runtime"]["run_timestamp"]``, else ``"default"``.

    Explicit directory values are converted with ``str`` and stripped of
    surrounding whitespace. Missing, ``None`` or non-mapping config sections
    are treated as empty rather than raising.

    Args:
        config: Run configuration mapping. A non-mapping is treated as empty.
        run_timestamp: Optional override for the timestamp directory name.

    Returns:
        Absolute path of the baseline directory. The directory is not created.
    """
    runtime_cfg = _config_section(config, "runtime")
    reward_cfg = _config_section(_config_section(config, "environment"), "reward")

    baseline_dir = str(
        runtime_cfg.get("reward_baseline_dir")
        or reward_cfg.get("baseline_dir")
        or ""
    ).strip()
    if baseline_dir:
        return os.path.abspath(baseline_dir)

    timestamp = run_timestamp or str(runtime_cfg.get("run_timestamp", "") or "default")
    return os.path.abspath(os.path.join("runs", timestamp, "reward_baseline"))
|
|
|
|
|
|
def episode_baseline_path(baseline_dir: str, episode: int) -> str:
    """Return the absolute path of the baseline CSV for *episode*.

    The file name zero-pads the episode number to four digits
    (``episode_0003.csv``). *episode* is passed through ``int`` first.
    """
    directory = os.path.abspath(baseline_dir)
    file_name = f"episode_{int(episode):04d}.csv"
    return os.path.join(directory, file_name)
|
|
|
|
|
|
def write_episode_baseline(
    *,
    baseline_dir: str,
    episode: int,
    rows: Iterable[Mapping[str, object]],
) -> str:
    """Write the full baseline CSV for *episode* and return its path.

    The rows go to a temporary file in *baseline_dir*, which is then moved
    over the target with a single ``os.replace`` call. Readers therefore
    never see a half-written file (``os.replace`` is atomic on POSIX).
    Only the columns in ``BASELINE_COLUMNS`` are written, in that order.
    Missing keys become empty cells and extra keys are ignored. The file is
    encoded as UTF-8 with a BOM (``utf-8-sig``).

    Args:
        baseline_dir: Directory holding the per-episode CSVs. Created if needed.
        episode: Episode index used to name the file.
        rows: One mapping of column name to value per step.

    Returns:
        Absolute path of the written CSV.
    """
    os.makedirs(baseline_dir, exist_ok=True)
    target_path = episode_baseline_path(baseline_dir, episode)
    # Binary-mode descriptor (no ``text=True``). On Windows a text-mode fd makes
    # the C runtime expand every "\n" to "\r\n" beneath the io layer. That would
    # turn csv's "\r\n" row terminators into "\r\r\n". The io wrapper below
    # already handles encoding and newlines.
    fd, temp_path = tempfile.mkstemp(
        prefix=f"episode_{int(episode):04d}_",
        suffix=".tmp",
        dir=baseline_dir,
    )
    try:
        with os.fdopen(fd, "w", newline="", encoding="utf-8-sig") as f:
            writer = csv.DictWriter(f, fieldnames=BASELINE_COLUMNS)
            writer.writeheader()
            for row in rows:
                writer.writerow({column: row.get(column, "") for column in BASELINE_COLUMNS})
        # NOTE(review): on Windows os.replace can raise PermissionError while a
        # reader holds target_path open; consider a short retry if that occurs.
        os.replace(temp_path, target_path)
    finally:
        # The temp file only still exists here if writing or replacing failed.
        if os.path.exists(temp_path):
            os.remove(temp_path)
    return target_path
|
|
|
|
|
|
class EpisodeBaselineWriter:
    """Collect one episode's baseline rows and keep its CSV up to date.

    Each :meth:`append` stores the new row and then rewrites the episode's CSV
    with every row collected so far (via ``write_episode_baseline``). After
    each call, the file therefore holds the full history of the episode.
    Because the whole file is rewritten each time, the total write cost grows
    quadratically with the number of rows.
    """

    def __init__(self, *, baseline_dir: str, episode: int):
        self.baseline_dir = os.path.abspath(baseline_dir)
        self.episode = int(episode)
        # Shallow dict copies of every appended row, in append order.
        self.rows: list[Mapping[str, object]] = []

    def append(self, row: Mapping[str, object]) -> str:
        """Store a copy of *row*, rewrite the episode CSV, and return its path."""
        # Copy so later changes to the caller's mapping do not affect stored rows.
        snapshot = dict(row)
        self.rows.append(snapshot)
        return write_episode_baseline(
            rows=self.rows,
            episode=self.episode,
            baseline_dir=self.baseline_dir,
        )
|
|
|
|
|
|
# Numeric metric columns returned per step by read_episode_baseline, in the
# key order of each per-step dict.
_BASELINE_METRIC_COLUMNS = (
    "mean_speed_kmh",
    "mainline_completed_count",
    "mainline_interval_travel_time_mean_s",
    "mainline_travel_time_cumulative_mean_s",
    "ttc_risk_rate",
)


def read_episode_baseline(path: str) -> Dict[int, Dict[str, float]]:
    """Parse a baseline CSV into ``{step: {metric_name: value}}``.

    The file is read as UTF-8 (a leading BOM is accepted). ``step`` is parsed
    via ``int(float(...))``, so fractional steps are truncated toward zero.
    Rows whose step is missing, non-numeric, NaN or infinite are skipped.
    If a step appears more than once, the last row wins. Metric cells that
    cannot be parsed as floats become NaN.

    Args:
        path: Path of the CSV to read.

    Returns:
        Mapping from step number to the metrics in ``_BASELINE_METRIC_COLUMNS``.
    """
    baseline_by_step: Dict[int, Dict[str, float]] = {}
    with open(path, "r", newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            try:
                step = int(float(row.get("step", "")))
            except (TypeError, ValueError, OverflowError):
                # TypeError/ValueError: missing, blank, non-numeric or NaN step.
                # OverflowError: infinite step (e.g. "inf", "1e999").
                continue
            baseline_by_step[step] = {
                column: _safe_float(row.get(column)) for column in _BASELINE_METRIC_COLUMNS
            }
    return baseline_by_step
|
|
|
|
|
|
def wait_for_episode_baseline(
    *,
    baseline_dir: str,
    episode: int,
    min_step: int,
    timeout_s: float,
    poll_interval_s: float,
) -> Dict[int, Dict[str, float]]:
    """Poll until the episode's baseline CSV reaches at least *min_step*.

    The file at ``episode_baseline_path(baseline_dir, episode)`` is re-read in
    full with ``read_episode_baseline`` on every poll. Polling stops once the
    highest parsed step is ``>= min_step``.

    Args:
        baseline_dir: Directory holding the per-episode CSVs.
        episode: Episode whose CSV to wait for.
        min_step: Smallest acceptable value for the highest step in the file.
        timeout_s: Maximum time to wait. Negative values are treated as 0,
            which means a single check.
        poll_interval_s: Delay between checks, floored at 0.05 s and never
            longer than the time left before the deadline.

    Returns:
        The parsed baseline mapping, ``{step: {metric_name: value}}``.

    Raises:
        TimeoutError: If the requirement is not met before the deadline.
    """
    path = episode_baseline_path(baseline_dir, episode)
    required_step = int(min_step)
    # Floor the interval so a zero or negative value cannot busy-spin.
    interval_s = max(float(poll_interval_s), 0.05)
    deadline = time.monotonic() + max(float(timeout_s), 0.0)
    while True:
        if os.path.isfile(path):
            baseline = read_episode_baseline(path)
            if baseline and max(baseline) >= required_step:
                return baseline
        remaining_s = deadline - time.monotonic()
        if remaining_s <= 0:
            raise TimeoutError(
                f"Timed out waiting for reward baseline episode={episode} step>={min_step}: {path}"
            )
        # Cap the sleep at the remaining time. The last check then happens at
        # the deadline, not up to one full poll interval after it.
        time.sleep(min(interval_s, remaining_s))
|
|
|
|
|
|
def _safe_float(value: object) -> float:
    """Best-effort float conversion: return ``float(value)`` or NaN.

    Values that make ``float`` raise ``TypeError`` (e.g. ``None``) or
    ``ValueError`` (e.g. ``""``, ``"abc"``) map to ``float("nan")``.
    Any other exception propagates.
    """
    try:
        converted = float(value)
    except (TypeError, ValueError):
        converted = float("nan")
    return converted
|