ctm-dqn/utils/seeding.py

53 lines
1.5 KiB
Python

"""Utilities for reproducible experiment seeding."""
from __future__ import annotations
import os
import random
from typing import Mapping
import numpy as np
import torch
def resolve_base_seed(training_cfg: Mapping[str, object], default: int = 42) -> int | None:
    """Pick the base seed for a run from the training config.

    Args:
        training_cfg: Training section of the experiment config; only the
            ``"random_seed"`` key is consulted.
        default: Seed used when ``"random_seed"`` is absent from the config.

    Returns:
        ``None`` when the config explicitly sets ``random_seed`` to null
        (meaning: do not seed, run truly random); otherwise the configured
        value, or ``default`` when the key is missing, converted with ``int()``.
    """
    if "random_seed" in training_cfg:
        configured = training_cfg.get("random_seed")
        # An explicit null is meaningful: it opts out of seeding entirely.
        return None if configured is None else int(configured)
    # Key absent: fall back to the caller-supplied default.
    return int(default)
def derive_seed(base_seed: int | None, offset: int = 0) -> int | None:
    """Compute a child seed as ``base_seed + offset``.

    Args:
        base_seed: Parent seed, or ``None`` for an unseeded (random) run.
        offset: Per-component shift (e.g. per worker or per environment);
            converted with ``int()``.

    Returns:
        ``None`` if ``base_seed`` is ``None`` (unseeded runs stay unseeded),
        otherwise ``int(base_seed) + int(offset)``.
    """
    return None if base_seed is None else int(base_seed) + int(offset)
def set_global_seed(seed: int | None, *, deterministic_torch: bool = True) -> None:
    """Seed Python, NumPy and PyTorch RNGs for reproducible training.

    Args:
        seed: Seed to apply. ``None`` is a no-op: no RNG and no environment
            variable is touched, so the run stays truly random.
        deterministic_torch: When true, also force cuDNN into deterministic,
            non-benchmarking mode and ask PyTorch to use deterministic
            algorithms (warn-only for ops that have no deterministic kernel).

    Notes:
        NumPy's legacy global RNG only accepts seeds in ``[0, 2**32 - 1]``, so
        its seed is folded into that range with ``seed % 2**32``. Seeds already
        in range are passed through unchanged.
    """
    if seed is None:
        return
    seed = int(seed)
    # Setting PYTHONHASHSEED at runtime does not change hash randomisation of
    # the interpreter that is already running; it only reaches child
    # processes started afterwards (e.g. subprocess/multiprocessing workers).
    os.environ["PYTHONHASHSEED"] = str(seed)
    # cuBLAS reads this to choose a deterministic workspace configuration; it
    # is only honoured if set before cuBLAS initialises. setdefault keeps any
    # value the user exported explicitly.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    random.seed(seed)
    # np.random.seed raises ValueError outside [0, 2**32 - 1]; negative seeds
    # or derived (base + offset) seeds must not crash seeding. In-range seeds
    # are unchanged by the modulo.
    np.random.seed(seed % 2**32)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    if deterministic_torch:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        try:
            torch.use_deterministic_algorithms(True, warn_only=True)
        except Exception as exc:
            # Best effort: older torch builds lack `warn_only` (or the API
            # itself). Keep going, but make the weaker guarantee visible
            # instead of silently swallowing the failure.
            import warnings

            warnings.warn(
                f"Could not enable torch deterministic algorithms: {exc!r}",
                RuntimeWarning,
                stacklevel=2,
            )