From b7ca7cc5fb663ba1d7a54600e5fbc99092890ea2 Mon Sep 17 00:00:00 2001 From: Maple-YZ Date: Fri, 24 Apr 2026 04:35:52 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9B=BA=E5=AE=9A=E9=9A=8F=E6=9C=BA=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agents/d3pg_agent.py | 3 ++ agents/ddpg_agent.py | 3 ++ agents/sac_agent.py | 3 ++ agents/sctd3_agent.py | 3 ++ agents/td3_agent.py | 3 ++ training/train_appo.py | 6 ++-- training/train_dcmappo.py | 8 ++++-- training/train_ddpg.py | 8 ++++-- training/train_gpro.py | 7 +++-- training/train_mappo.py | 8 ++++-- training/train_ppo.py | 6 ++-- training/train_sac.py | 8 ++++-- training/train_tcamappo.py | 8 ++++-- training/train_td3.py | 7 +++-- training/train_value_based.py | 9 ++++-- utils/seeding.py | 52 +++++++++++++++++++++++++++++++++++ 16 files changed, 116 insertions(+), 26 deletions(-) create mode 100644 utils/seeding.py diff --git a/agents/d3pg_agent.py b/agents/d3pg_agent.py index afb63e6..61dca09 100644 --- a/agents/d3pg_agent.py +++ b/agents/d3pg_agent.py @@ -73,6 +73,7 @@ class D3PGAgent: self, state_dim: int, action_dims: list, + seed: int | None = None, learning_rate: float = 3e-4, buffer_size: int = 100000, learning_starts: int = 100, @@ -93,6 +94,7 @@ class D3PGAgent: self.learning_starts = learning_starts self.total_steps = 0 self.exploration_sigma = exploration_sigma + self.seed = seed dummy_env = MultiDiscreteWrapper(state_dim, action_dims) action_noise = NormalActionNoise( @@ -111,6 +113,7 @@ class D3PGAgent: self.model = TD3( "MlpPolicy", env=dummy_env, + seed=seed, learning_rate=learning_rate, buffer_size=buffer_size, learning_starts=learning_starts, diff --git a/agents/ddpg_agent.py b/agents/ddpg_agent.py index a042bf1..c3b2808 100644 --- a/agents/ddpg_agent.py +++ b/agents/ddpg_agent.py @@ -54,6 +54,7 @@ class DDPGAgent: self, state_dim: int, action_dims: list, + seed: int | None = None, learning_rate: float = 3e-4, buffer_size: int = 100000, learning_starts: int = 100, @@ -72,6 +73,7 @@ class DDPGAgent: self.learning_starts = learning_starts self.total_steps = 0 self.exploration_sigma = exploration_sigma + self.seed = seed dummy_env = MultiDiscreteWrapper(state_dim, action_dims) action_noise = NormalActionNoise( @@ -90,6 +92,7 @@ class DDPGAgent: self.model = DDPG( "MlpPolicy", env=dummy_env, + seed=seed, learning_rate=learning_rate, buffer_size=buffer_size, learning_starts=learning_starts, diff --git a/agents/sac_agent.py b/agents/sac_agent.py index c560afb..c3ddb97 100644 --- a/agents/sac_agent.py +++ b/agents/sac_agent.py @@ -60,6 +60,7 @@ class SACAgent: self, state_dim: int, action_dims: list, + seed: int | None = None, learning_rate: float = 3e-4, buffer_size: int = 100000, learning_starts: int = 100, @@ -81,6 +82,7 @@ class SACAgent: self.device = device self.learning_starts = learning_starts self.total_steps = 0 + self.seed = seed dummy_env = MultiDiscreteWrapper(state_dim, action_dims) policy_kwargs = { @@ -96,6 +98,7 @@ class SACAgent: self.model = SAC( "MlpPolicy", env=dummy_env, + seed=seed, learning_rate=learning_rate, buffer_size=buffer_size, learning_starts=learning_starts, diff --git a/agents/sctd3_agent.py b/agents/sctd3_agent.py index 19ca267..57b8966 100644 --- a/agents/sctd3_agent.py +++ b/agents/sctd3_agent.py @@ -232,6 +232,7 @@ class SCTD3Agent: self, state_dim: int, action_dims: list, + seed: int | None = None, learning_rate: float = 3e-4, buffer_size: int = 100000, learning_starts: int = 100, @@ -260,6 +261,7 @@ class SCTD3Agent: self.learning_starts = learning_starts self.total_steps = 0 self.exploration_sigma = exploration_sigma + self.seed = seed dummy_env = MultiDiscreteWrapper(state_dim, action_dims) action_noise = NormalActionNoise( @@ -292,6 +294,7 @@ class SCTD3Agent: self.model = TD3( "MlpPolicy", env=dummy_env, + seed=seed, learning_rate=learning_rate, buffer_size=buffer_size, learning_starts=learning_starts, diff --git a/agents/td3_agent.py b/agents/td3_agent.py index 90fa4b5..501f55e 100644 --- a/agents/td3_agent.py +++ b/agents/td3_agent.py @@ -64,6 +64,7 @@ class TD3Agent: self, state_dim: int, action_dims: list, + seed: int | None = None, learning_rate: float = 3e-4, buffer_size: int = 100000, learning_starts: int = 100, @@ -84,6 +85,7 @@ class TD3Agent: self.learning_starts = learning_starts self.total_steps = 0 self.exploration_sigma = exploration_sigma + self.seed = seed dummy_env = MultiDiscreteWrapper(state_dim, action_dims) action_noise = NormalActionNoise( @@ -103,6 +105,7 @@ class TD3Agent: self.model = TD3( "MlpPolicy", env=dummy_env, + seed=seed, learning_rate=learning_rate, buffer_size=buffer_size, learning_starts=learning_starts, diff --git a/training/train_appo.py b/training/train_appo.py index 53cd372..4208439 100644 --- a/training/train_appo.py +++ b/training/train_appo.py @@ -22,6 +22,7 @@ from utils.episode_artifacts import save_training_episode_artifacts from utils.logger import TrainingLogger from utils.plot import plot_training_curves from utils.run_dirs import resolve_run_dirs, write_shared_run_config +from utils.seeding import derive_seed, resolve_base_seed, set_global_seed def train_sumo_appo(log_dir=None, checkpoint_dir=None, run_timestamp=None): @@ -31,6 +32,8 @@ def train_sumo_appo(log_dir=None, checkpoint_dir=None, run_timestamp=None): agent_config = get_agent_config(config, "appo") train_config = get_training_config(config) + base_seed = resolve_base_seed(train_config) + set_global_seed(base_seed) start_episode = 1 _, checkpoint_dir, log_dir = resolve_run_dirs( @@ -96,7 +99,6 @@ def train_sumo_appo(log_dir=None, checkpoint_dir=None, run_timestamp=None): num_episodes = train_config["num_episodes"] save_freq = train_config.get("save_freq", 50) log_freq = train_config.get("log_freq", 10) - base_seed = train_config.get("random_seed", 42) # 统计变量 episode_rewards = [] @@ -114,7 +116,7 @@ def train_sumo_appo(log_dir=None, checkpoint_dir=None, run_timestamp=None): try: for episode in range(start_episode, num_episodes + 1): # 每个 episode 使用不同 seed 引入随机性 - seed = base_seed + episode + seed = derive_seed(base_seed, episode) state = env.reset(seed=seed) episode_reward = 0 episode_throughput = 0 diff --git a/training/train_dcmappo.py b/training/train_dcmappo.py index 0be29cb..0e71392 100644 --- a/training/train_dcmappo.py +++ b/training/train_dcmappo.py @@ -18,6 +18,7 @@ from utils.episode_artifacts import save_training_episode_artifacts from utils.logger import TrainingLogger from utils.plot import plot_training_curves from utils.run_dirs import resolve_run_dirs, write_shared_run_config +from utils.seeding import derive_seed, resolve_base_seed, set_global_seed def train_sumo_dcmappo(log_dir=None, checkpoint_dir=None, run_timestamp=None): @@ -26,6 +27,8 @@ def train_sumo_dcmappo(log_dir=None, checkpoint_dir=None, run_timestamp=None): agent_config = get_agent_config(config, "dcmappo") train_config = get_training_config(config) + base_seed = resolve_base_seed(train_config) + set_global_seed(base_seed) _, checkpoint_dir, log_dir = resolve_run_dirs( "dcmappo", @@ -60,6 +63,7 @@ def train_sumo_dcmappo(log_dir=None, checkpoint_dir=None, run_timestamp=None): print(f" Corridor blocks: {agent_config.get('num_corridor_blocks', 2)}") print(f" LR: {agent_config.get('learning_rate', 3e-4)}") print(f" Device: {agent_config.get('device', 'cuda')}") + print(f" Global seed: {base_seed if base_seed is not None else 'None (random)'}") print() agent = DCMAPPOAgent( @@ -91,8 +95,6 @@ def train_sumo_dcmappo(log_dir=None, checkpoint_dir=None, run_timestamp=None): num_episodes = train_config["num_episodes"] save_freq = train_config.get("save_freq", 50) log_freq = train_config.get("log_freq", 10) - base_seed = train_config.get("random_seed", 42) - episode_rewards = [] episode_throughputs = [] episode_mean_speeds = [] @@ -107,7 +109,7 @@ def train_sumo_dcmappo(log_dir=None, checkpoint_dir=None, run_timestamp=None): try: for episode in range(1, num_episodes + 1): - seed = base_seed + episode + seed = derive_seed(base_seed, episode) state = env.reset(seed=seed) episode_reward = 0.0 episode_throughput = 0.0 diff --git a/training/train_ddpg.py b/training/train_ddpg.py index 9a384de..c7c3a34 100644 --- a/training/train_ddpg.py +++ b/training/train_ddpg.py @@ -19,6 +19,7 @@ from utils.episode_artifacts import save_training_episode_artifacts from utils.logger import TrainingLogger from utils.plot import plot_training_curves from utils.run_dirs import resolve_run_dirs, write_shared_run_config +from utils.seeding import derive_seed, resolve_base_seed, set_global_seed def train_sumo_ddpg(log_dir=None, checkpoint_dir=None, run_timestamp=None): @@ -28,6 +29,8 @@ def train_sumo_ddpg(log_dir=None, checkpoint_dir=None, run_timestamp=None): agent_config = get_agent_config(config, "ddpg") train_config = get_training_config(config) + base_seed = resolve_base_seed(train_config) + set_global_seed(base_seed) _, checkpoint_dir, log_dir = resolve_run_dirs( "ddpg", @@ -78,13 +81,12 @@ def train_sumo_ddpg(log_dir=None, checkpoint_dir=None, run_timestamp=None): actor_hidden_dims=agent_config.get("actor_hidden_dims"), critic_hidden_dims=agent_config.get("critic_hidden_dims"), activation_fn=agent_config.get("activation_fn", "relu"), + seed=base_seed, ) num_episodes = train_config["num_episodes"] save_freq = train_config.get("save_freq", 50) log_freq = train_config.get("log_freq", 10) - base_seed = train_config.get("random_seed", 42) - episode_rewards = [] episode_throughputs = [] episode_mean_speeds = [] @@ -96,7 +98,7 @@ def train_sumo_ddpg(log_dir=None, checkpoint_dir=None, run_timestamp=None): try: for episode in range(1, num_episodes + 1): - seed = base_seed + episode + seed = derive_seed(base_seed, episode) state = env.reset(seed=seed) episode_reward = 0 episode_throughput = 0 diff --git a/training/train_gpro.py b/training/train_gpro.py index 541af01..ebd24e0 100644 --- a/training/train_gpro.py +++ b/training/train_gpro.py @@ -16,6 +16,7 @@ from utils.episode_artifacts import save_training_episode_artifacts from utils.logger import TrainingLogger from utils.plot import plot_training_curves from utils.run_dirs import resolve_run_dirs, write_shared_run_config +from utils.seeding import derive_seed, resolve_base_seed, set_global_seed def train_sumo_gpro(log_dir=None, checkpoint_dir=None, run_timestamp=None): @@ -25,6 +26,8 @@ def train_sumo_gpro(log_dir=None, checkpoint_dir=None, run_timestamp=None): agent_config = get_agent_config(config, "gpro") train_config = get_training_config(config) + base_seed = resolve_base_seed(train_config) + set_global_seed(base_seed) _, checkpoint_dir, log_dir = resolve_run_dirs( "gpro", @@ -90,8 +93,6 @@ def train_sumo_gpro(log_dir=None, checkpoint_dir=None, run_timestamp=None): num_episodes = train_config["num_episodes"] save_freq = train_config.get("save_freq", 50) log_freq = train_config.get("log_freq", 10) - base_seed = train_config.get("random_seed", 42) - episode_rewards = [] episode_throughputs = [] episode_mean_speeds = [] @@ -107,7 +108,7 @@ def train_sumo_gpro(log_dir=None, checkpoint_dir=None, run_timestamp=None): try: pending_log_rows = [] for group_start in range(1, num_episodes + 1, group_size): - group_seed = base_seed + ((group_start - 1) // group_size) + 1 + group_seed = derive_seed(base_seed, ((group_start - 1) // group_size) + 1) group_end = min(group_start + group_size - 1, num_episodes) pending_log_rows.clear() diff --git a/training/train_mappo.py b/training/train_mappo.py index 1b98511..7e11cba 100644 --- a/training/train_mappo.py +++ b/training/train_mappo.py @@ -19,6 +19,7 @@ from utils.episode_artifacts import save_training_episode_artifacts from utils.logger import TrainingLogger from utils.plot import plot_training_curves from utils.run_dirs import resolve_run_dirs, write_shared_run_config +from utils.seeding import derive_seed, resolve_base_seed, set_global_seed def train_sumo_mappo(log_dir=None, checkpoint_dir=None, run_timestamp=None): @@ -27,6 +28,8 @@ def train_sumo_mappo(log_dir=None, checkpoint_dir=None, run_timestamp=None): agent_config = get_agent_config(config, "mappo") train_config = get_training_config(config) + base_seed = resolve_base_seed(train_config) + set_global_seed(base_seed) _, checkpoint_dir, log_dir = resolve_run_dirs( "mappo", @@ -60,6 +63,7 @@ def train_sumo_mappo(log_dir=None, checkpoint_dir=None, run_timestamp=None): print(f" Hidden dim: {agent_config.get('hidden_dim', 256)}") print(f" LR: {agent_config.get('learning_rate', 3e-4)}") print(f" Device: {agent_config.get('device', 'cuda')}") + print(f" Global seed: {base_seed if base_seed is not None else 'None (random)'}") print() agent = MAPPOAgent( @@ -88,8 +92,6 @@ def train_sumo_mappo(log_dir=None, checkpoint_dir=None, run_timestamp=None): num_episodes = train_config["num_episodes"] save_freq = train_config.get("save_freq", 50) log_freq = train_config.get("log_freq", 10) - base_seed = train_config.get("random_seed", 42) - episode_rewards = [] episode_throughputs = [] episode_mean_speeds = [] @@ -104,7 +106,7 @@ def train_sumo_mappo(log_dir=None, checkpoint_dir=None, run_timestamp=None): try: for episode in range(1, num_episodes + 1): - seed = base_seed + episode + seed = derive_seed(base_seed, episode) state = env.reset(seed=seed) episode_reward = 0.0 episode_throughput = 0.0 diff --git a/training/train_ppo.py b/training/train_ppo.py index 1852abd..51cfa8d 100644 --- a/training/train_ppo.py +++ b/training/train_ppo.py @@ -22,6 +22,7 @@ from utils.episode_artifacts import save_training_episode_artifacts from utils.logger import TrainingLogger from utils.plot import plot_training_curves from utils.run_dirs import resolve_run_dirs, write_shared_run_config +from utils.seeding import derive_seed, resolve_base_seed, set_global_seed def train_sumo_ppo(log_dir=None, checkpoint_dir=None, run_timestamp=None): @@ -32,6 +33,8 @@ def train_sumo_ppo(log_dir=None, checkpoint_dir=None, run_timestamp=None): agent_config = get_agent_config(config, "ppo") train_config = get_training_config(config) + base_seed = resolve_base_seed(train_config) + set_global_seed(base_seed) start_episode = 1 _, checkpoint_dir, log_dir = resolve_run_dirs( @@ -96,7 +99,6 @@ def train_sumo_ppo(log_dir=None, checkpoint_dir=None, run_timestamp=None): num_episodes = train_config["num_episodes"] save_freq = train_config.get("save_freq", 50) log_freq = train_config.get("log_freq", 10) - base_seed = train_config.get("random_seed", 42) # 统计变量 episode_rewards = [] @@ -114,7 +116,7 @@ def train_sumo_ppo(log_dir=None, checkpoint_dir=None, run_timestamp=None): try: for episode in range(start_episode, num_episodes + 1): # 每个 episode 使用不同 seed 引入随机性 - seed = base_seed + episode + seed = derive_seed(base_seed, episode) state = env.reset(seed=seed) episode_reward = 0 episode_throughput = 0 diff --git a/training/train_sac.py b/training/train_sac.py index 23d8e0c..5dc43f5 100644 --- a/training/train_sac.py +++ b/training/train_sac.py @@ -16,6 +16,7 @@ from utils.episode_artifacts import save_training_episode_artifacts from utils.logger import TrainingLogger from utils.plot import plot_training_curves from utils.run_dirs import resolve_run_dirs, write_shared_run_config +from utils.seeding import derive_seed, resolve_base_seed, set_global_seed def train_sumo_sac(log_dir=None, checkpoint_dir=None, run_timestamp=None): @@ -25,6 +26,8 @@ def train_sumo_sac(log_dir=None, checkpoint_dir=None, run_timestamp=None): agent_config = get_agent_config(config, "sac") train_config = get_training_config(config) + base_seed = resolve_base_seed(train_config) + set_global_seed(base_seed) _, checkpoint_dir, log_dir = resolve_run_dirs( "sac", @@ -78,13 +81,12 @@ def train_sumo_sac(log_dir=None, checkpoint_dir=None, run_timestamp=None): actor_hidden_dims=agent_config.get("actor_hidden_dims"), critic_hidden_dims=agent_config.get("critic_hidden_dims"), activation_fn=agent_config.get("activation_fn", "relu"), + seed=base_seed, ) num_episodes = train_config["num_episodes"] save_freq = train_config.get("save_freq", 50) log_freq = train_config.get("log_freq", 10) - base_seed = train_config.get("random_seed", 42) - episode_rewards = [] episode_throughputs = [] episode_mean_speeds = [] @@ -96,7 +98,7 @@ def train_sumo_sac(log_dir=None, checkpoint_dir=None, run_timestamp=None): try: for episode in range(1, num_episodes + 1): - seed = base_seed + episode + seed = derive_seed(base_seed, episode) state = env.reset(seed=seed) episode_reward = 0.0 episode_throughput = 0.0 diff --git a/training/train_tcamappo.py b/training/train_tcamappo.py index d8b5274..ff3ae15 100644 --- a/training/train_tcamappo.py +++ b/training/train_tcamappo.py @@ -19,6 +19,7 @@ from utils.episode_artifacts import save_training_episode_artifacts from utils.logger import TrainingLogger from utils.plot import plot_training_curves from utils.run_dirs import resolve_run_dirs, write_shared_run_config +from utils.seeding import derive_seed, resolve_base_seed, set_global_seed def train_sumo_tcamappo(log_dir=None, checkpoint_dir=None, run_timestamp=None): @@ -27,6 +28,8 @@ def train_sumo_tcamappo(log_dir=None, checkpoint_dir=None, run_timestamp=None): agent_config = get_agent_config(config, "tcamappo") train_config = get_training_config(config) + base_seed = resolve_base_seed(train_config) + set_global_seed(base_seed) _, checkpoint_dir, log_dir = resolve_run_dirs( "tcamappo", @@ -62,6 +65,7 @@ def train_sumo_tcamappo(log_dir=None, checkpoint_dir=None, run_timestamp=None): print(f" Critic heads/layers: {agent_config.get('critic_num_heads', 4)}/{agent_config.get('critic_num_layers', 2)}") print(f" LR: {agent_config.get('learning_rate', 3e-4)}") print(f" Device: {agent_config.get('device', 'cuda')}") + print(f" Global seed: {base_seed if base_seed is not None else 'None (random)'}") print() agent = TCAMAPPOAgent( @@ -94,8 +98,6 @@ def train_sumo_tcamappo(log_dir=None, checkpoint_dir=None, run_timestamp=None): num_episodes = train_config["num_episodes"] save_freq = train_config.get("save_freq", 50) log_freq = train_config.get("log_freq", 10) - base_seed = train_config.get("random_seed", 42) - episode_rewards = [] episode_throughputs = [] episode_mean_speeds = [] @@ -110,7 +112,7 @@ def train_sumo_tcamappo(log_dir=None, checkpoint_dir=None, run_timestamp=None): try: for episode in range(1, num_episodes + 1): - seed = base_seed + episode + seed = derive_seed(base_seed, episode) state = env.reset(seed=seed) agent.reset_episode() diff --git a/training/train_td3.py b/training/train_td3.py index d91ec41..ed62b3d 100644 --- a/training/train_td3.py +++ b/training/train_td3.py @@ -20,6 +20,7 @@ from utils.episode_artifacts import save_training_episode_artifacts from utils.logger import TrainingLogger from utils.plot import plot_training_curves from utils.run_dirs import resolve_run_dirs, write_shared_run_config +from utils.seeding import derive_seed, resolve_base_seed, set_global_seed def train_sumo_td3( @@ -37,6 +38,8 @@ def train_sumo_td3( agent_config = get_agent_config(config, config_key) train_config = get_training_config(config) + base_seed = resolve_base_seed(train_config) + set_global_seed(base_seed) _, checkpoint_dir, log_dir = resolve_run_dirs( model_name, @@ -102,12 +105,12 @@ def train_sumo_td3( extractor_kernel_size=agent_config.get("extractor_kernel_size", 3), ) + common_kwargs["seed"] = base_seed agent = agent_class(**common_kwargs) num_episodes = train_config["num_episodes"] save_freq = train_config.get("save_freq", 50) log_freq = train_config.get("log_freq", 10) - base_seed = train_config.get("random_seed", 42) episode_rewards = [] episode_throughputs = [] @@ -120,7 +123,7 @@ def train_sumo_td3( try: for episode in range(1, num_episodes + 1): - seed = base_seed + episode + seed = derive_seed(base_seed, episode) state = env.reset(seed=seed) episode_reward = 0 episode_throughput = 0 diff --git a/training/train_value_based.py b/training/train_value_based.py index 5855cb3..87cc938 100644 --- a/training/train_value_based.py +++ b/training/train_value_based.py @@ -18,6 +18,7 @@ from utils.episode_artifacts import save_training_episode_artifacts from utils.logger import TrainingLogger from utils.plot import plot_training_curves from utils.run_dirs import resolve_run_dirs, write_shared_run_config +from utils.seeding import derive_seed, resolve_base_seed, set_global_seed matplotlib.use("Agg") @@ -72,6 +73,8 @@ def train_sumo_value_based( agent_config = get_agent_config(config, model_key) train_config = get_training_config(config) + base_seed = resolve_base_seed(train_config) + set_global_seed(base_seed) _, checkpoint_dir, log_dir = resolve_run_dirs( model_key, @@ -112,12 +115,14 @@ def train_sumo_value_based( print(f" Device: {agent_config.get('device', 'cuda')}") print() + print(f" Global seed: {base_seed if base_seed is not None else 'None (random)'}") + print() + agent = _build_value_based_agent(agent_builder, env, agent_config) num_episodes = train_config["num_episodes"] save_freq = train_config.get("save_freq", 50) log_freq = train_config.get("log_freq", 10) - base_seed = train_config.get("random_seed", 42) episode_rewards = [] episode_throughputs = [] @@ -131,7 +136,7 @@ def train_sumo_value_based( try: for episode in range(1, num_episodes + 1): - seed = base_seed + episode + seed = derive_seed(base_seed, episode) state = env.reset(seed=seed) episode_reward = 0.0 episode_throughput = 0.0 diff --git a/utils/seeding.py b/utils/seeding.py new file mode 100644 index 0000000..87d0baf --- /dev/null +++ b/utils/seeding.py @@ -0,0 +1,52 @@ +"""Utilities for reproducible experiment seeding.""" + +from __future__ import annotations + +import os +import random +from typing import Mapping + +import numpy as np +import torch + + +def resolve_base_seed(training_cfg: Mapping[str, object], default: int = 42) -> int | None: + """Return the configured base seed, preserving explicit null as true random.""" + if "random_seed" not in training_cfg: + return int(default) + seed = training_cfg.get("random_seed") + if seed is None: + return None + return int(seed) + + +def derive_seed(base_seed: int | None, offset: int = 0) -> int | None: + """Derive a deterministic child seed from a base seed.""" + if base_seed is None: + return None + return int(base_seed) + int(offset) + + +def set_global_seed(seed: int | None, *, deterministic_torch: bool = True) -> None: + """Seed Python, NumPy and PyTorch RNGs for reproducible training.""" + if seed is None: + return + + seed = int(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8") + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + if deterministic_torch: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + try: + torch.use_deterministic_algorithms(True, warn_only=True) + except Exception: + pass