固定随机数
This commit is contained in:
parent
a6a2e5a626
commit
b7ca7cc5fb
|
|
@ -73,6 +73,7 @@ class D3PGAgent:
|
|||
self,
|
||||
state_dim: int,
|
||||
action_dims: list,
|
||||
seed: int | None = None,
|
||||
learning_rate: float = 3e-4,
|
||||
buffer_size: int = 100000,
|
||||
learning_starts: int = 100,
|
||||
|
|
@ -93,6 +94,7 @@ class D3PGAgent:
|
|||
self.learning_starts = learning_starts
|
||||
self.total_steps = 0
|
||||
self.exploration_sigma = exploration_sigma
|
||||
self.seed = seed
|
||||
|
||||
dummy_env = MultiDiscreteWrapper(state_dim, action_dims)
|
||||
action_noise = NormalActionNoise(
|
||||
|
|
@ -111,6 +113,7 @@ class D3PGAgent:
|
|||
self.model = TD3(
|
||||
"MlpPolicy",
|
||||
env=dummy_env,
|
||||
seed=seed,
|
||||
learning_rate=learning_rate,
|
||||
buffer_size=buffer_size,
|
||||
learning_starts=learning_starts,
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ class DDPGAgent:
|
|||
self,
|
||||
state_dim: int,
|
||||
action_dims: list,
|
||||
seed: int | None = None,
|
||||
learning_rate: float = 3e-4,
|
||||
buffer_size: int = 100000,
|
||||
learning_starts: int = 100,
|
||||
|
|
@ -72,6 +73,7 @@ class DDPGAgent:
|
|||
self.learning_starts = learning_starts
|
||||
self.total_steps = 0
|
||||
self.exploration_sigma = exploration_sigma
|
||||
self.seed = seed
|
||||
|
||||
dummy_env = MultiDiscreteWrapper(state_dim, action_dims)
|
||||
action_noise = NormalActionNoise(
|
||||
|
|
@ -90,6 +92,7 @@ class DDPGAgent:
|
|||
self.model = DDPG(
|
||||
"MlpPolicy",
|
||||
env=dummy_env,
|
||||
seed=seed,
|
||||
learning_rate=learning_rate,
|
||||
buffer_size=buffer_size,
|
||||
learning_starts=learning_starts,
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ class SACAgent:
|
|||
self,
|
||||
state_dim: int,
|
||||
action_dims: list,
|
||||
seed: int | None = None,
|
||||
learning_rate: float = 3e-4,
|
||||
buffer_size: int = 100000,
|
||||
learning_starts: int = 100,
|
||||
|
|
@ -81,6 +82,7 @@ class SACAgent:
|
|||
self.device = device
|
||||
self.learning_starts = learning_starts
|
||||
self.total_steps = 0
|
||||
self.seed = seed
|
||||
|
||||
dummy_env = MultiDiscreteWrapper(state_dim, action_dims)
|
||||
policy_kwargs = {
|
||||
|
|
@ -96,6 +98,7 @@ class SACAgent:
|
|||
self.model = SAC(
|
||||
"MlpPolicy",
|
||||
env=dummy_env,
|
||||
seed=seed,
|
||||
learning_rate=learning_rate,
|
||||
buffer_size=buffer_size,
|
||||
learning_starts=learning_starts,
|
||||
|
|
|
|||
|
|
@ -232,6 +232,7 @@ class SCTD3Agent:
|
|||
self,
|
||||
state_dim: int,
|
||||
action_dims: list,
|
||||
seed: int | None = None,
|
||||
learning_rate: float = 3e-4,
|
||||
buffer_size: int = 100000,
|
||||
learning_starts: int = 100,
|
||||
|
|
@ -260,6 +261,7 @@ class SCTD3Agent:
|
|||
self.learning_starts = learning_starts
|
||||
self.total_steps = 0
|
||||
self.exploration_sigma = exploration_sigma
|
||||
self.seed = seed
|
||||
|
||||
dummy_env = MultiDiscreteWrapper(state_dim, action_dims)
|
||||
action_noise = NormalActionNoise(
|
||||
|
|
@ -292,6 +294,7 @@ class SCTD3Agent:
|
|||
self.model = TD3(
|
||||
"MlpPolicy",
|
||||
env=dummy_env,
|
||||
seed=seed,
|
||||
learning_rate=learning_rate,
|
||||
buffer_size=buffer_size,
|
||||
learning_starts=learning_starts,
|
||||
|
|
|
|||
|
|
@ -64,6 +64,7 @@ class TD3Agent:
|
|||
self,
|
||||
state_dim: int,
|
||||
action_dims: list,
|
||||
seed: int | None = None,
|
||||
learning_rate: float = 3e-4,
|
||||
buffer_size: int = 100000,
|
||||
learning_starts: int = 100,
|
||||
|
|
@ -84,6 +85,7 @@ class TD3Agent:
|
|||
self.learning_starts = learning_starts
|
||||
self.total_steps = 0
|
||||
self.exploration_sigma = exploration_sigma
|
||||
self.seed = seed
|
||||
|
||||
dummy_env = MultiDiscreteWrapper(state_dim, action_dims)
|
||||
action_noise = NormalActionNoise(
|
||||
|
|
@ -103,6 +105,7 @@ class TD3Agent:
|
|||
self.model = TD3(
|
||||
"MlpPolicy",
|
||||
env=dummy_env,
|
||||
seed=seed,
|
||||
learning_rate=learning_rate,
|
||||
buffer_size=buffer_size,
|
||||
learning_starts=learning_starts,
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from utils.episode_artifacts import save_training_episode_artifacts
|
|||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
|
||||
from utils.seeding import derive_seed, resolve_base_seed, set_global_seed
|
||||
|
||||
|
||||
def train_sumo_appo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||
|
|
@ -31,6 +32,8 @@ def train_sumo_appo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
|
||||
agent_config = get_agent_config(config, "appo")
|
||||
train_config = get_training_config(config)
|
||||
base_seed = resolve_base_seed(train_config)
|
||||
set_global_seed(base_seed)
|
||||
|
||||
start_episode = 1
|
||||
_, checkpoint_dir, log_dir = resolve_run_dirs(
|
||||
|
|
@ -96,7 +99,6 @@ def train_sumo_appo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
num_episodes = train_config["num_episodes"]
|
||||
save_freq = train_config.get("save_freq", 50)
|
||||
log_freq = train_config.get("log_freq", 10)
|
||||
base_seed = train_config.get("random_seed", 42)
|
||||
|
||||
# 统计变量
|
||||
episode_rewards = []
|
||||
|
|
@ -114,7 +116,7 @@ def train_sumo_appo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
try:
|
||||
for episode in range(start_episode, num_episodes + 1):
|
||||
# 每个 episode 使用不同 seed 引入随机性
|
||||
seed = base_seed + episode
|
||||
seed = derive_seed(base_seed, episode)
|
||||
state = env.reset(seed=seed)
|
||||
episode_reward = 0
|
||||
episode_throughput = 0
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from utils.episode_artifacts import save_training_episode_artifacts
|
|||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
|
||||
from utils.seeding import derive_seed, resolve_base_seed, set_global_seed
|
||||
|
||||
|
||||
def train_sumo_dcmappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||
|
|
@ -26,6 +27,8 @@ def train_sumo_dcmappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
|
||||
agent_config = get_agent_config(config, "dcmappo")
|
||||
train_config = get_training_config(config)
|
||||
base_seed = resolve_base_seed(train_config)
|
||||
set_global_seed(base_seed)
|
||||
|
||||
_, checkpoint_dir, log_dir = resolve_run_dirs(
|
||||
"dcmappo",
|
||||
|
|
@ -60,6 +63,7 @@ def train_sumo_dcmappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
print(f" Corridor blocks: {agent_config.get('num_corridor_blocks', 2)}")
|
||||
print(f" LR: {agent_config.get('learning_rate', 3e-4)}")
|
||||
print(f" Device: {agent_config.get('device', 'cuda')}")
|
||||
print(f" Global seed: {base_seed if base_seed is not None else 'None (random)'}")
|
||||
print()
|
||||
|
||||
agent = DCMAPPOAgent(
|
||||
|
|
@ -91,8 +95,6 @@ def train_sumo_dcmappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
num_episodes = train_config["num_episodes"]
|
||||
save_freq = train_config.get("save_freq", 50)
|
||||
log_freq = train_config.get("log_freq", 10)
|
||||
base_seed = train_config.get("random_seed", 42)
|
||||
|
||||
episode_rewards = []
|
||||
episode_throughputs = []
|
||||
episode_mean_speeds = []
|
||||
|
|
@ -107,7 +109,7 @@ def train_sumo_dcmappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
|
||||
try:
|
||||
for episode in range(1, num_episodes + 1):
|
||||
seed = base_seed + episode
|
||||
seed = derive_seed(base_seed, episode)
|
||||
state = env.reset(seed=seed)
|
||||
episode_reward = 0.0
|
||||
episode_throughput = 0.0
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from utils.episode_artifacts import save_training_episode_artifacts
|
|||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
|
||||
from utils.seeding import derive_seed, resolve_base_seed, set_global_seed
|
||||
|
||||
|
||||
def train_sumo_ddpg(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||
|
|
@ -28,6 +29,8 @@ def train_sumo_ddpg(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
|
||||
agent_config = get_agent_config(config, "ddpg")
|
||||
train_config = get_training_config(config)
|
||||
base_seed = resolve_base_seed(train_config)
|
||||
set_global_seed(base_seed)
|
||||
|
||||
_, checkpoint_dir, log_dir = resolve_run_dirs(
|
||||
"ddpg",
|
||||
|
|
@ -78,13 +81,12 @@ def train_sumo_ddpg(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
actor_hidden_dims=agent_config.get("actor_hidden_dims"),
|
||||
critic_hidden_dims=agent_config.get("critic_hidden_dims"),
|
||||
activation_fn=agent_config.get("activation_fn", "relu"),
|
||||
seed=base_seed,
|
||||
)
|
||||
|
||||
num_episodes = train_config["num_episodes"]
|
||||
save_freq = train_config.get("save_freq", 50)
|
||||
log_freq = train_config.get("log_freq", 10)
|
||||
base_seed = train_config.get("random_seed", 42)
|
||||
|
||||
episode_rewards = []
|
||||
episode_throughputs = []
|
||||
episode_mean_speeds = []
|
||||
|
|
@ -96,7 +98,7 @@ def train_sumo_ddpg(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
|
||||
try:
|
||||
for episode in range(1, num_episodes + 1):
|
||||
seed = base_seed + episode
|
||||
seed = derive_seed(base_seed, episode)
|
||||
state = env.reset(seed=seed)
|
||||
episode_reward = 0
|
||||
episode_throughput = 0
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from utils.episode_artifacts import save_training_episode_artifacts
|
|||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
|
||||
from utils.seeding import derive_seed, resolve_base_seed, set_global_seed
|
||||
|
||||
|
||||
def train_sumo_gpro(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||
|
|
@ -25,6 +26,8 @@ def train_sumo_gpro(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
|
||||
agent_config = get_agent_config(config, "gpro")
|
||||
train_config = get_training_config(config)
|
||||
base_seed = resolve_base_seed(train_config)
|
||||
set_global_seed(base_seed)
|
||||
|
||||
_, checkpoint_dir, log_dir = resolve_run_dirs(
|
||||
"gpro",
|
||||
|
|
@ -90,8 +93,6 @@ def train_sumo_gpro(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
num_episodes = train_config["num_episodes"]
|
||||
save_freq = train_config.get("save_freq", 50)
|
||||
log_freq = train_config.get("log_freq", 10)
|
||||
base_seed = train_config.get("random_seed", 42)
|
||||
|
||||
episode_rewards = []
|
||||
episode_throughputs = []
|
||||
episode_mean_speeds = []
|
||||
|
|
@ -107,7 +108,7 @@ def train_sumo_gpro(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
try:
|
||||
pending_log_rows = []
|
||||
for group_start in range(1, num_episodes + 1, group_size):
|
||||
group_seed = base_seed + ((group_start - 1) // group_size) + 1
|
||||
group_seed = derive_seed(base_seed, ((group_start - 1) // group_size) + 1)
|
||||
group_end = min(group_start + group_size - 1, num_episodes)
|
||||
pending_log_rows.clear()
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from utils.episode_artifacts import save_training_episode_artifacts
|
|||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
|
||||
from utils.seeding import derive_seed, resolve_base_seed, set_global_seed
|
||||
|
||||
|
||||
def train_sumo_mappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||
|
|
@ -27,6 +28,8 @@ def train_sumo_mappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
|
||||
agent_config = get_agent_config(config, "mappo")
|
||||
train_config = get_training_config(config)
|
||||
base_seed = resolve_base_seed(train_config)
|
||||
set_global_seed(base_seed)
|
||||
|
||||
_, checkpoint_dir, log_dir = resolve_run_dirs(
|
||||
"mappo",
|
||||
|
|
@ -60,6 +63,7 @@ def train_sumo_mappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
print(f" Hidden dim: {agent_config.get('hidden_dim', 256)}")
|
||||
print(f" LR: {agent_config.get('learning_rate', 3e-4)}")
|
||||
print(f" Device: {agent_config.get('device', 'cuda')}")
|
||||
print(f" Global seed: {base_seed if base_seed is not None else 'None (random)'}")
|
||||
print()
|
||||
|
||||
agent = MAPPOAgent(
|
||||
|
|
@ -88,8 +92,6 @@ def train_sumo_mappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
num_episodes = train_config["num_episodes"]
|
||||
save_freq = train_config.get("save_freq", 50)
|
||||
log_freq = train_config.get("log_freq", 10)
|
||||
base_seed = train_config.get("random_seed", 42)
|
||||
|
||||
episode_rewards = []
|
||||
episode_throughputs = []
|
||||
episode_mean_speeds = []
|
||||
|
|
@ -104,7 +106,7 @@ def train_sumo_mappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
|
||||
try:
|
||||
for episode in range(1, num_episodes + 1):
|
||||
seed = base_seed + episode
|
||||
seed = derive_seed(base_seed, episode)
|
||||
state = env.reset(seed=seed)
|
||||
episode_reward = 0.0
|
||||
episode_throughput = 0.0
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from utils.episode_artifacts import save_training_episode_artifacts
|
|||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
|
||||
from utils.seeding import derive_seed, resolve_base_seed, set_global_seed
|
||||
|
||||
|
||||
def train_sumo_ppo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||
|
|
@ -32,6 +33,8 @@ def train_sumo_ppo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
|
||||
agent_config = get_agent_config(config, "ppo")
|
||||
train_config = get_training_config(config)
|
||||
base_seed = resolve_base_seed(train_config)
|
||||
set_global_seed(base_seed)
|
||||
|
||||
start_episode = 1
|
||||
_, checkpoint_dir, log_dir = resolve_run_dirs(
|
||||
|
|
@ -96,7 +99,6 @@ def train_sumo_ppo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
num_episodes = train_config["num_episodes"]
|
||||
save_freq = train_config.get("save_freq", 50)
|
||||
log_freq = train_config.get("log_freq", 10)
|
||||
base_seed = train_config.get("random_seed", 42)
|
||||
|
||||
# 统计变量
|
||||
episode_rewards = []
|
||||
|
|
@ -114,7 +116,7 @@ def train_sumo_ppo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
try:
|
||||
for episode in range(start_episode, num_episodes + 1):
|
||||
# 每个 episode 使用不同 seed 引入随机性
|
||||
seed = base_seed + episode
|
||||
seed = derive_seed(base_seed, episode)
|
||||
state = env.reset(seed=seed)
|
||||
episode_reward = 0
|
||||
episode_throughput = 0
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from utils.episode_artifacts import save_training_episode_artifacts
|
|||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
|
||||
from utils.seeding import derive_seed, resolve_base_seed, set_global_seed
|
||||
|
||||
|
||||
def train_sumo_sac(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||
|
|
@ -25,6 +26,8 @@ def train_sumo_sac(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
|
||||
agent_config = get_agent_config(config, "sac")
|
||||
train_config = get_training_config(config)
|
||||
base_seed = resolve_base_seed(train_config)
|
||||
set_global_seed(base_seed)
|
||||
|
||||
_, checkpoint_dir, log_dir = resolve_run_dirs(
|
||||
"sac",
|
||||
|
|
@ -78,13 +81,12 @@ def train_sumo_sac(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
actor_hidden_dims=agent_config.get("actor_hidden_dims"),
|
||||
critic_hidden_dims=agent_config.get("critic_hidden_dims"),
|
||||
activation_fn=agent_config.get("activation_fn", "relu"),
|
||||
seed=base_seed,
|
||||
)
|
||||
|
||||
num_episodes = train_config["num_episodes"]
|
||||
save_freq = train_config.get("save_freq", 50)
|
||||
log_freq = train_config.get("log_freq", 10)
|
||||
base_seed = train_config.get("random_seed", 42)
|
||||
|
||||
episode_rewards = []
|
||||
episode_throughputs = []
|
||||
episode_mean_speeds = []
|
||||
|
|
@ -96,7 +98,7 @@ def train_sumo_sac(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
|
||||
try:
|
||||
for episode in range(1, num_episodes + 1):
|
||||
seed = base_seed + episode
|
||||
seed = derive_seed(base_seed, episode)
|
||||
state = env.reset(seed=seed)
|
||||
episode_reward = 0.0
|
||||
episode_throughput = 0.0
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from utils.episode_artifacts import save_training_episode_artifacts
|
|||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
|
||||
from utils.seeding import derive_seed, resolve_base_seed, set_global_seed
|
||||
|
||||
|
||||
def train_sumo_tcamappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||
|
|
@ -27,6 +28,8 @@ def train_sumo_tcamappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
|
||||
agent_config = get_agent_config(config, "tcamappo")
|
||||
train_config = get_training_config(config)
|
||||
base_seed = resolve_base_seed(train_config)
|
||||
set_global_seed(base_seed)
|
||||
|
||||
_, checkpoint_dir, log_dir = resolve_run_dirs(
|
||||
"tcamappo",
|
||||
|
|
@ -62,6 +65,7 @@ def train_sumo_tcamappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
print(f" Critic heads/layers: {agent_config.get('critic_num_heads', 4)}/{agent_config.get('critic_num_layers', 2)}")
|
||||
print(f" LR: {agent_config.get('learning_rate', 3e-4)}")
|
||||
print(f" Device: {agent_config.get('device', 'cuda')}")
|
||||
print(f" Global seed: {base_seed if base_seed is not None else 'None (random)'}")
|
||||
print()
|
||||
|
||||
agent = TCAMAPPOAgent(
|
||||
|
|
@ -94,8 +98,6 @@ def train_sumo_tcamappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
num_episodes = train_config["num_episodes"]
|
||||
save_freq = train_config.get("save_freq", 50)
|
||||
log_freq = train_config.get("log_freq", 10)
|
||||
base_seed = train_config.get("random_seed", 42)
|
||||
|
||||
episode_rewards = []
|
||||
episode_throughputs = []
|
||||
episode_mean_speeds = []
|
||||
|
|
@ -110,7 +112,7 @@ def train_sumo_tcamappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
|
||||
try:
|
||||
for episode in range(1, num_episodes + 1):
|
||||
seed = base_seed + episode
|
||||
seed = derive_seed(base_seed, episode)
|
||||
state = env.reset(seed=seed)
|
||||
agent.reset_episode()
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from utils.episode_artifacts import save_training_episode_artifacts
|
|||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
|
||||
from utils.seeding import derive_seed, resolve_base_seed, set_global_seed
|
||||
|
||||
|
||||
def train_sumo_td3(
|
||||
|
|
@ -37,6 +38,8 @@ def train_sumo_td3(
|
|||
|
||||
agent_config = get_agent_config(config, config_key)
|
||||
train_config = get_training_config(config)
|
||||
base_seed = resolve_base_seed(train_config)
|
||||
set_global_seed(base_seed)
|
||||
|
||||
_, checkpoint_dir, log_dir = resolve_run_dirs(
|
||||
model_name,
|
||||
|
|
@ -102,12 +105,12 @@ def train_sumo_td3(
|
|||
extractor_kernel_size=agent_config.get("extractor_kernel_size", 3),
|
||||
)
|
||||
|
||||
common_kwargs["seed"] = base_seed
|
||||
agent = agent_class(**common_kwargs)
|
||||
|
||||
num_episodes = train_config["num_episodes"]
|
||||
save_freq = train_config.get("save_freq", 50)
|
||||
log_freq = train_config.get("log_freq", 10)
|
||||
base_seed = train_config.get("random_seed", 42)
|
||||
|
||||
episode_rewards = []
|
||||
episode_throughputs = []
|
||||
|
|
@ -120,7 +123,7 @@ def train_sumo_td3(
|
|||
|
||||
try:
|
||||
for episode in range(1, num_episodes + 1):
|
||||
seed = base_seed + episode
|
||||
seed = derive_seed(base_seed, episode)
|
||||
state = env.reset(seed=seed)
|
||||
episode_reward = 0
|
||||
episode_throughput = 0
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from utils.episode_artifacts import save_training_episode_artifacts
|
|||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
|
||||
from utils.seeding import derive_seed, resolve_base_seed, set_global_seed
|
||||
|
||||
matplotlib.use("Agg")
|
||||
|
||||
|
|
@ -72,6 +73,8 @@ def train_sumo_value_based(
|
|||
|
||||
agent_config = get_agent_config(config, model_key)
|
||||
train_config = get_training_config(config)
|
||||
base_seed = resolve_base_seed(train_config)
|
||||
set_global_seed(base_seed)
|
||||
|
||||
_, checkpoint_dir, log_dir = resolve_run_dirs(
|
||||
model_key,
|
||||
|
|
@ -112,12 +115,14 @@ def train_sumo_value_based(
|
|||
print(f" Device: {agent_config.get('device', 'cuda')}")
|
||||
print()
|
||||
|
||||
print(f" Global seed: {base_seed if base_seed is not None else 'None (random)'}")
|
||||
print()
|
||||
|
||||
agent = _build_value_based_agent(agent_builder, env, agent_config)
|
||||
|
||||
num_episodes = train_config["num_episodes"]
|
||||
save_freq = train_config.get("save_freq", 50)
|
||||
log_freq = train_config.get("log_freq", 10)
|
||||
base_seed = train_config.get("random_seed", 42)
|
||||
|
||||
episode_rewards = []
|
||||
episode_throughputs = []
|
||||
|
|
@ -131,7 +136,7 @@ def train_sumo_value_based(
|
|||
|
||||
try:
|
||||
for episode in range(1, num_episodes + 1):
|
||||
seed = base_seed + episode
|
||||
seed = derive_seed(base_seed, episode)
|
||||
state = env.reset(seed=seed)
|
||||
episode_reward = 0.0
|
||||
episode_throughput = 0.0
|
||||
|
|
|
|||
|
|
@ -0,0 +1,52 @@
|
|||
"""Utilities for reproducible experiment seeding."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import random
|
||||
from typing import Mapping
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def resolve_base_seed(training_cfg: Mapping[str, object], default: int = 42) -> int | None:
    """Look up ``random_seed`` in the training config and normalise it.

    Args:
        training_cfg: Mapping holding the training settings.
        default: Seed used when the ``random_seed`` key is absent.

    Returns:
        ``int(default)`` when the key is missing, ``None`` when the key is
        present but explicitly set to ``None`` (the caller asked for
        unseeded, truly random behaviour), otherwise the configured value
        converted with ``int``.
    """
    # A private sentinel distinguishes "key absent" from "key set to None".
    absent = object()
    configured = training_cfg.get("random_seed", absent)
    if configured is absent:
        return int(default)
    return None if configured is None else int(configured)
|
||||
|
||||
|
||||
def derive_seed(base_seed: int | None, offset: int = 0) -> int | None:
    """Compute a child seed as ``base_seed + offset``.

    Args:
        base_seed: Parent seed, or ``None`` when seeding is disabled.
        offset: Integer added to the parent seed.

    Returns:
        ``None`` when ``base_seed`` is ``None`` (the unseeded state is
        propagated unchanged), otherwise the sum of both values after
        each has been converted with ``int``.
    """
    if base_seed is not None:
        parent = int(base_seed)
        return parent + int(offset)
    return None
|
||||
|
||||
|
||||
def set_global_seed(seed: int | None, *, deterministic_torch: bool = True) -> None:
    """Seed Python, NumPy and PyTorch RNGs for reproducible training.

    Args:
        seed: Base seed. ``None`` leaves every RNG untouched so the run
            stays unseeded. Any other value is converted with ``int`` and
            wrapped into ``[0, 2**32 - 1]``, so negative or oversized seeds
            are accepted; seeds already inside that range are used
            unchanged.
        deterministic_torch: When true, also switch cuDNN to deterministic
            mode, disable cuDNN benchmarking and request deterministic
            PyTorch algorithms in warn-only mode.

    Side effects:
        Sets ``PYTHONHASHSEED`` and, if not already present,
        ``CUBLAS_WORKSPACE_CONFIG`` in ``os.environ``; reseeds ``random``,
        ``numpy.random`` and ``torch`` (plus every CUDA device when CUDA is
        available); may change global torch backend flags.
    """
    if seed is None:
        return

    import warnings

    # np.random.seed raises ValueError outside [0, 2**32 - 1], and
    # PYTHONHASHSEED is only valid inside the same range. Wrapping up front
    # keeps a bad seed from failing half-way, after some RNGs were already
    # reseeded.
    seed = int(seed) % (2**32)

    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so this affects
    # child processes only, not str hashing in the current process.
    os.environ["PYTHONHASHSEED"] = str(seed)
    # Respect a value the user already exported.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all covers every device, including the current one.
        torch.cuda.manual_seed_all(seed)

    if not deterministic_torch:
        return

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    try:
        torch.use_deterministic_algorithms(True, warn_only=True)
    except Exception as exc:  # best effort: never abort training over this
        # Older torch releases lack this function or its warn_only flag;
        # report the failure instead of hiding it silently.
        warnings.warn(
            f"Could not enable deterministic torch algorithms: {exc}",
            RuntimeWarning,
            stacklevel=2,
        )
|
||||
Loading…
Reference in New Issue