"""Launch every training job asynchronously with a single command."""
import os
import subprocess
import sys
from datetime import datetime

AGENTS = ["ppo", "appo", "mappo", "dqn", "ddpg", "td3"]


def main(agents=None):
    """Start one training subprocess per agent, then wait for all of them.

    Every child is run as ``<python> -m training.train_<agent>`` and is given
    ``--log-dir``, ``--checkpoint-dir`` and ``--run-timestamp`` so that all
    agents of one launch share a single timestamped directory tree:
    ``logs/multi-model/<timestamp>/<agent>`` and
    ``checkpoints/multi-model/<timestamp>/<agent>``. The child's stdout and
    stderr are both written to ``stdout.txt`` inside its log directory.

    Args:
        agents: Optional iterable of agent names to launch. When ``None``
            (the default) every entry of the module-level ``AGENTS`` list is
            launched.

    Returns:
        A dict mapping each launched agent name to the exit code of its
        training process, in launch order.
    """
    selected = AGENTS if agents is None else list(agents)
    run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    processes = {}

    for agent in selected:
        print(f"[{datetime.now().strftime('%H:%M:%S')}] 启动 {agent.upper()} 训练...")
        log_dir = os.path.join("logs", "multi-model", run_timestamp, agent)
        checkpoint_dir = os.path.join("checkpoints", "multi-model", run_timestamp, agent)
        os.makedirs(log_dir, exist_ok=True)
        os.makedirs(checkpoint_dir, exist_ok=True)
        stdout_path = os.path.join(log_dir, "stdout.txt")

        command = [
            sys.executable,
            "-m",
            f"training.train_{agent}",
            "--log-dir",
            log_dir,
            "--checkpoint-dir",
            checkpoint_dir,
            "--run-timestamp",
            run_timestamp,
        ]
        # The child receives its own copy of the log handle when it is
        # spawned, so the parent's copy can be closed as soon as Popen
        # returns instead of being leaked for the lifetime of the launcher.
        with open(stdout_path, "w", encoding="utf-8") as log_file:
            process = subprocess.Popen(
                command,
                stdout=log_file,
                stderr=subprocess.STDOUT,
            )
        processes[agent] = process
        print(f" PID: {process.pid}")

    print(f"\n全部 {len(selected)} 个训练已启动,等待完成...")
    print(f"本次多模型时间戳: {run_timestamp}\n")

    exit_codes = {}
    for agent, process in processes.items():
        process.wait()
        status = "完成" if process.returncode == 0 else f"失败(code={process.returncode})"
        print(f"[{agent.upper()}] {status}")
        exit_codes[agent] = process.returncode

    return exit_codes


if __name__ == "__main__":
    main()
a/training/train_appo.py b/training/train_appo.py index 288e701..d4264b5 100644 --- a/training/train_appo.py +++ b/training/train_appo.py @@ -2,6 +2,7 @@ 基于 SUMO+TraCI 的 APPO 训练脚本 使用微观仿真环境训练 VSL 控制策略 """ +import argparse import os import sys import copy @@ -19,9 +20,10 @@ from agents.appo_agent import APPOAgent from utils.config import get_agent_config, get_training_config from utils.logger import TrainingLogger from utils.plot import plot_training_curves +from utils.run_dirs import add_run_dir_args, resolve_run_dirs -def train_sumo_appo(): +def train_sumo_appo(log_dir=None, checkpoint_dir=None, run_timestamp=None): """SUMO 环境下的 APPO 训练主函数""" with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f: config = yaml.safe_load(f) @@ -30,9 +32,12 @@ def train_sumo_appo(): train_config = get_training_config(config) start_episode = 1 - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - checkpoint_dir = os.path.join("checkpoints", "appo", timestamp) - log_dir = os.path.join("logs", "appo", timestamp) + _, checkpoint_dir, log_dir = resolve_run_dirs( + "appo", + log_dir=log_dir, + checkpoint_dir=checkpoint_dir, + run_timestamp=run_timestamp, + ) os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True) runtime_config = copy.deepcopy(config) @@ -246,4 +251,10 @@ def train_sumo_appo(): if __name__ == "__main__": - train_sumo_appo() + parser = add_run_dir_args(argparse.ArgumentParser()) + args = parser.parse_args() + train_sumo_appo( + log_dir=args.log_dir, + checkpoint_dir=args.checkpoint_dir, + run_timestamp=args.run_timestamp, + ) diff --git a/training/train_ddpg.py b/training/train_ddpg.py index 258c5e0..6fc822e 100644 --- a/training/train_ddpg.py +++ b/training/train_ddpg.py @@ -2,6 +2,7 @@ 基于 SUMO+TraCI 的 DDPG 训练脚本 使用 Stable-Baselines3 的 DDPG 算法 """ +import argparse import os import copy import yaml @@ -16,9 +17,10 @@ from agents.ddpg_agent import DDPGAgent from utils.config import get_agent_config, get_training_config from utils.logger 
import TrainingLogger from utils.plot import plot_training_curves +from utils.run_dirs import add_run_dir_args, resolve_run_dirs -def train_sumo_ddpg(): +def train_sumo_ddpg(log_dir=None, checkpoint_dir=None, run_timestamp=None): """SUMO 环境下的 DDPG 训练主函数""" with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f: config = yaml.safe_load(f) @@ -26,9 +28,12 @@ def train_sumo_ddpg(): agent_config = get_agent_config(config, "ddpg") train_config = get_training_config(config) - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - checkpoint_dir = os.path.join("checkpoints", "ddpg", timestamp) - log_dir = os.path.join("logs", "ddpg", timestamp) + _, checkpoint_dir, log_dir = resolve_run_dirs( + "ddpg", + log_dir=log_dir, + checkpoint_dir=checkpoint_dir, + run_timestamp=run_timestamp, + ) os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True) runtime_config = copy.deepcopy(config) @@ -183,4 +188,10 @@ def train_sumo_ddpg(): if __name__ == "__main__": - train_sumo_ddpg() + parser = add_run_dir_args(argparse.ArgumentParser()) + args = parser.parse_args() + train_sumo_ddpg( + log_dir=args.log_dir, + checkpoint_dir=args.checkpoint_dir, + run_timestamp=args.run_timestamp, + ) diff --git a/training/train_dqn.py b/training/train_dqn.py index ff07d31..7d5d014 100644 --- a/training/train_dqn.py +++ b/training/train_dqn.py @@ -1,6 +1,7 @@ """ DQN训练脚本 - SUMO VSL环境 """ +import argparse import os import sys import copy @@ -17,9 +18,10 @@ from agents.dqn_agent import DQNAgent from utils.config import get_agent_config, get_training_config from utils.logger import TrainingLogger from utils.plot import plot_training_curves +from utils.run_dirs import add_run_dir_args, resolve_run_dirs -def train_sumo_dqn(): +def train_sumo_dqn(log_dir=None, checkpoint_dir=None, run_timestamp=None): with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f: config = yaml.safe_load(f) @@ -27,9 +29,12 @@ def train_sumo_dqn(): train_config = get_training_config(config) 
start_episode = 1 - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - checkpoint_dir = os.path.join("checkpoints", "dqn", timestamp) - log_dir = os.path.join("logs", "dqn", timestamp) + _, checkpoint_dir, log_dir = resolve_run_dirs( + "dqn", + log_dir=log_dir, + checkpoint_dir=checkpoint_dir, + run_timestamp=run_timestamp, + ) os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True) runtime_config = copy.deepcopy(config) @@ -203,4 +208,10 @@ def train_sumo_dqn(): if __name__ == "__main__": - train_sumo_dqn() + parser = add_run_dir_args(argparse.ArgumentParser()) + args = parser.parse_args() + train_sumo_dqn( + log_dir=args.log_dir, + checkpoint_dir=args.checkpoint_dir, + run_timestamp=args.run_timestamp, + ) diff --git a/training/train_mappo.py b/training/train_mappo.py index 8e48d60..3f47d2d 100644 --- a/training/train_mappo.py +++ b/training/train_mappo.py @@ -1,6 +1,7 @@ """ MAPPO training script for SUMO + TraCI VSL. """ +import argparse import os import copy import yaml @@ -16,18 +17,22 @@ from agents.mappo_agent import MAPPOAgent from utils.config import get_agent_config, get_training_config from utils.logger import TrainingLogger from utils.plot import plot_training_curves +from utils.run_dirs import add_run_dir_args, resolve_run_dirs -def train_sumo_mappo(): +def train_sumo_mappo(log_dir=None, checkpoint_dir=None, run_timestamp=None): with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f: config = yaml.safe_load(f) agent_config = get_agent_config(config, "mappo") train_config = get_training_config(config) - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - checkpoint_dir = os.path.join("checkpoints", "mappo", timestamp) - log_dir = os.path.join("logs", "mappo", timestamp) + _, checkpoint_dir, log_dir = resolve_run_dirs( + "mappo", + log_dir=log_dir, + checkpoint_dir=checkpoint_dir, + run_timestamp=run_timestamp, + ) os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True) runtime_config = 
copy.deepcopy(config) @@ -235,4 +240,10 @@ def train_sumo_mappo(): if __name__ == "__main__": - train_sumo_mappo() + parser = add_run_dir_args(argparse.ArgumentParser()) + args = parser.parse_args() + train_sumo_mappo( + log_dir=args.log_dir, + checkpoint_dir=args.checkpoint_dir, + run_timestamp=args.run_timestamp, + ) diff --git a/training/train_ppo.py b/training/train_ppo.py index 4e86907..d04f156 100644 --- a/training/train_ppo.py +++ b/training/train_ppo.py @@ -2,6 +2,7 @@ 基于 SUMO+TraCI 的 PPO 训练脚本 使用微观仿真环境训练 VSL 控制策略 """ +import argparse import os import sys import copy @@ -19,9 +20,10 @@ from agents.ppo_agent import PPOAgent from utils.config import get_agent_config, get_training_config from utils.logger import TrainingLogger from utils.plot import plot_training_curves +from utils.run_dirs import add_run_dir_args, resolve_run_dirs -def train_sumo_ppo(): +def train_sumo_ppo(log_dir=None, checkpoint_dir=None, run_timestamp=None): """SUMO 环境下的 PPO 训练主函数""" # 加载配置 with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f: @@ -31,9 +33,12 @@ def train_sumo_ppo(): train_config = get_training_config(config) start_episode = 1 - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - checkpoint_dir = os.path.join("checkpoints", "ppo", timestamp) - log_dir = os.path.join("logs", "ppo", timestamp) + _, checkpoint_dir, log_dir = resolve_run_dirs( + "ppo", + log_dir=log_dir, + checkpoint_dir=checkpoint_dir, + run_timestamp=run_timestamp, + ) os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True) runtime_config = copy.deepcopy(config) @@ -249,4 +254,10 @@ def train_sumo_ppo(): if __name__ == "__main__": - train_sumo_ppo() + parser = add_run_dir_args(argparse.ArgumentParser()) + args = parser.parse_args() + train_sumo_ppo( + log_dir=args.log_dir, + checkpoint_dir=args.checkpoint_dir, + run_timestamp=args.run_timestamp, + ) diff --git a/training/train_td3.py b/training/train_td3.py index 3e429ee..7502cf0 100644 --- a/training/train_td3.py +++ 
import argparse
import os
from datetime import datetime
from typing import Optional, Tuple


def add_run_dir_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register the shared output-location options on ``parser``.

    Adds ``--log-dir``, ``--checkpoint-dir`` and ``--run-timestamp``; each is
    a string option that defaults to ``None`` when not supplied.

    Args:
        parser: The argument parser to extend in place.

    Returns:
        The same ``parser`` object, so the call can be chained.
    """
    parser.add_argument("--log-dir", type=str, default=None, help="Training log output directory.")
    parser.add_argument("--checkpoint-dir", type=str, default=None, help="Model checkpoint output directory.")
    parser.add_argument("--run-timestamp", type=str, default=None, help="Run timestamp tag for default directories.")
    return parser


def resolve_run_dirs(
    model_name: str,
    log_dir: Optional[str] = None,
    checkpoint_dir: Optional[str] = None,
    run_timestamp: Optional[str] = None,
) -> Tuple[str, str, str]:
    """Resolve output directories from runtime args, falling back to per-model defaults.

    Args:
        model_name: Name used as the middle path component of the default
            directories.
        log_dir: Explicit log directory. ``None`` or an empty string selects
            the default ``logs/<model_name>/<timestamp>``.
        checkpoint_dir: Explicit checkpoint directory. ``None`` or an empty
            string selects the default ``checkpoints/<model_name>/<timestamp>``.
        run_timestamp: Timestamp tag to use. ``None`` or an empty string
            selects the current local time formatted as ``%Y%m%d_%H%M%S``.

    Returns:
        A tuple ``(timestamp, checkpoint_dir, log_dir)``. Note that the
        checkpoint directory comes before the log directory, which is the
        reverse of the parameter order.
    """
    timestamp = run_timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")

    # Treat empty strings like "not provided", matching how run_timestamp is
    # handled above; an empty path would otherwise reach os.makedirs("") in
    # the callers and fail there.
    if not checkpoint_dir:
        checkpoint_dir = os.path.join("checkpoints", model_name, timestamp)
    if not log_dir:
        log_dir = os.path.join("logs", model_name, timestamp)

    return timestamp, checkpoint_dir, log_dir