53 lines
1.5 KiB
Python
53 lines
1.5 KiB
Python
"""Utilities for reproducible experiment seeding."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import random
|
|
from typing import Mapping
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
|
|
def resolve_base_seed(training_cfg: Mapping[str, object], default: int = 42) -> int | None:
    """Read ``training_cfg["random_seed"]`` and normalise it to an int seed.

    An explicit ``None`` value is preserved so callers can treat it as
    "unseeded / true random"; a missing key falls back to ``default``.

    Args:
        training_cfg: Mapping of training options that may carry
            ``"random_seed"``.
        default: Seed to use when the ``"random_seed"`` key is absent.

    Returns:
        ``int(default)`` if the key is missing, ``None`` if it is present
        with value ``None``, otherwise ``int(value)``.

    Raises:
        TypeError / ValueError: propagated from ``int()`` when the configured
            value (or ``default``) is not convertible to an integer.
    """
    if "random_seed" in training_cfg:
        # Membership test first, then .get(): keeps lookup semantics identical
        # for mappings whose __getitem__ has side effects (e.g. defaultdict).
        raw_seed = training_cfg.get("random_seed")
        return None if raw_seed is None else int(raw_seed)
    return int(default)
|
|
|
|
|
|
def derive_seed(base_seed: int | None, offset: int = 0) -> int | None:
    """Return a deterministic child seed, ``int(base_seed) + int(offset)``.

    Args:
        base_seed: Parent seed, or ``None`` for an unseeded (random) run.
        offset: Integer shift that distinguishes this child from its siblings.

    Returns:
        ``None`` when ``base_seed`` is ``None`` (``offset`` is then not even
        converted); otherwise the plain integer sum, not reduced to any range.
    """
    return None if base_seed is None else int(base_seed) + int(offset)
|
|
|
|
|
|
def set_global_seed(seed: int | None, *, deterministic_torch: bool = True) -> None:
    """Seed Python, NumPy and PyTorch RNGs for reproducible training.

    Args:
        seed: Seed to apply. ``None`` is a no-op: no RNG, environment variable
            or backend flag is touched (true-random run). Any other value is
            converted with ``int()`` and reduced modulo ``2**32`` so it lies in
            the range NumPy's legacy ``np.random.seed`` and ``PYTHONHASHSEED``
            accept. Seeds already in ``[0, 2**32)`` are used unchanged.
        deterministic_torch: When true, force deterministic cuDNN kernels,
            disable cuDNN autotuning and request deterministic PyTorch
            algorithms in warn-only mode (nondeterministic ops warn rather
            than raise).
    """
    if seed is None:
        return

    # np.random.seed raises ValueError outside [0, 2**32 - 1], and a child
    # interpreter refuses to start with such a PYTHONHASHSEED. Wrap instead of
    # crashing on negative or very large (e.g. offset-derived) seeds.
    seed = int(seed) % (2**32)

    # Changing PYTHONHASHSEED here only affects interpreters launched later
    # (e.g. subprocess workers); this process's str hashing is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    # Workspace setting PyTorch documents for deterministic cuBLAS; setdefault
    # keeps any value the user already exported. Presumably only effective if
    # set before cuBLAS is initialised.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Explicit CUDA seeding for the current device and all devices.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    if deterministic_torch:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        try:
            torch.use_deterministic_algorithms(True, warn_only=True)
        except (AttributeError, TypeError) as exc:
            # Older PyTorch lacks use_deterministic_algorithms (AttributeError)
            # or its warn_only keyword (TypeError). Strict mode without
            # warn_only would make nondeterministic ops raise, so skip it, but
            # report that instead of failing silently.
            import warnings

            warnings.warn(
                f"Could not enable torch.use_deterministic_algorithms: {exc}",
                RuntimeWarning,
                stacklevel=2,
            )
|