修改同时训练时保存的目录

This commit is contained in:
Zihan Ye 2026-04-09 02:46:25 +08:00
parent 09c0dd86e5
commit 27502241ad
8 changed files with 166 additions and 46 deletions

View File

@ -1,24 +1,50 @@
"""一键异步启动全部训练""" """一键异步启动全部训练。"""
import os
import subprocess import subprocess
import sys import sys
from datetime import datetime from datetime import datetime
AGENTS = ["ppo", "appo", "mappo", "dqn", "ddpg", "td3"] AGENTS = ["ppo", "appo", "mappo", "dqn", "ddpg", "td3"]
processes = {}
for agent in AGENTS:
print(f"[{datetime.now().strftime('%H:%M:%S')}] 启动 {agent.upper()} 训练...")
p = subprocess.Popen(
[sys.executable, "-m", f"training.train_{agent}"],
stdout=open(f"logs_{agent}_stdout.txt", "w"),
stderr=subprocess.STDOUT,
)
processes[agent] = p
print(f" PID: {p.pid}")
print(f"\n全部 {len(AGENTS)} 个训练已启动,等待完成...\n") def main():
run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
processes = {}
for agent, p in processes.items(): for agent in AGENTS:
p.wait() print(f"[{datetime.now().strftime('%H:%M:%S')}] 启动 {agent.upper()} 训练...")
status = "✓ 完成" if p.returncode == 0 else f"✗ 失败(code={p.returncode})" log_dir = os.path.join("logs", "multi-model", run_timestamp, agent)
print(f"[{agent.upper()}] {status}") checkpoint_dir = os.path.join("checkpoints", "multi-model", run_timestamp, agent)
os.makedirs(log_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)
stdout_path = os.path.join(log_dir, "stdout.txt")
process = subprocess.Popen(
[
sys.executable,
"-m",
f"training.train_{agent}",
"--log-dir",
log_dir,
"--checkpoint-dir",
checkpoint_dir,
"--run-timestamp",
run_timestamp,
],
stdout=open(stdout_path, "w", encoding="utf-8"),
stderr=subprocess.STDOUT,
)
processes[agent] = process
print(f" PID: {process.pid}")
print(f"\n全部 {len(AGENTS)} 个训练已启动,等待完成...")
print(f"本次多模型时间戳: {run_timestamp}\n")
for agent, process in processes.items():
process.wait()
status = "完成" if process.returncode == 0 else f"失败(code={process.returncode})"
print(f"[{agent.upper()}] {status}")
if __name__ == "__main__":
main()

View File

@ -2,6 +2,7 @@
基于 SUMO+TraCI APPO 训练脚本 基于 SUMO+TraCI APPO 训练脚本
使用微观仿真环境训练 VSL 控制策略 使用微观仿真环境训练 VSL 控制策略
""" """
import argparse
import os import os
import sys import sys
import copy import copy
@ -19,9 +20,10 @@ from agents.appo_agent import APPOAgent
from utils.config import get_agent_config, get_training_config from utils.config import get_agent_config, get_training_config
from utils.logger import TrainingLogger from utils.logger import TrainingLogger
from utils.plot import plot_training_curves from utils.plot import plot_training_curves
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
def train_sumo_appo(): def train_sumo_appo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
"""SUMO 环境下的 APPO 训练主函数""" """SUMO 环境下的 APPO 训练主函数"""
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f: with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
config = yaml.safe_load(f) config = yaml.safe_load(f)
@ -30,9 +32,12 @@ def train_sumo_appo():
train_config = get_training_config(config) train_config = get_training_config(config)
start_episode = 1 start_episode = 1
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") _, checkpoint_dir, log_dir = resolve_run_dirs(
checkpoint_dir = os.path.join("checkpoints", "appo", timestamp) "appo",
log_dir = os.path.join("logs", "appo", timestamp) log_dir=log_dir,
checkpoint_dir=checkpoint_dir,
run_timestamp=run_timestamp,
)
os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
runtime_config = copy.deepcopy(config) runtime_config = copy.deepcopy(config)
@ -246,4 +251,10 @@ def train_sumo_appo():
if __name__ == "__main__": if __name__ == "__main__":
train_sumo_appo() parser = add_run_dir_args(argparse.ArgumentParser())
args = parser.parse_args()
train_sumo_appo(
log_dir=args.log_dir,
checkpoint_dir=args.checkpoint_dir,
run_timestamp=args.run_timestamp,
)

View File

@ -2,6 +2,7 @@
基于 SUMO+TraCI DDPG 训练脚本 基于 SUMO+TraCI DDPG 训练脚本
使用 Stable-Baselines3 DDPG 算法 使用 Stable-Baselines3 DDPG 算法
""" """
import argparse
import os import os
import copy import copy
import yaml import yaml
@ -16,9 +17,10 @@ from agents.ddpg_agent import DDPGAgent
from utils.config import get_agent_config, get_training_config from utils.config import get_agent_config, get_training_config
from utils.logger import TrainingLogger from utils.logger import TrainingLogger
from utils.plot import plot_training_curves from utils.plot import plot_training_curves
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
def train_sumo_ddpg(): def train_sumo_ddpg(log_dir=None, checkpoint_dir=None, run_timestamp=None):
"""SUMO 环境下的 DDPG 训练主函数""" """SUMO 环境下的 DDPG 训练主函数"""
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f: with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
config = yaml.safe_load(f) config = yaml.safe_load(f)
@ -26,9 +28,12 @@ def train_sumo_ddpg():
agent_config = get_agent_config(config, "ddpg") agent_config = get_agent_config(config, "ddpg")
train_config = get_training_config(config) train_config = get_training_config(config)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") _, checkpoint_dir, log_dir = resolve_run_dirs(
checkpoint_dir = os.path.join("checkpoints", "ddpg", timestamp) "ddpg",
log_dir = os.path.join("logs", "ddpg", timestamp) log_dir=log_dir,
checkpoint_dir=checkpoint_dir,
run_timestamp=run_timestamp,
)
os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
runtime_config = copy.deepcopy(config) runtime_config = copy.deepcopy(config)
@ -183,4 +188,10 @@ def train_sumo_ddpg():
if __name__ == "__main__": if __name__ == "__main__":
train_sumo_ddpg() parser = add_run_dir_args(argparse.ArgumentParser())
args = parser.parse_args()
train_sumo_ddpg(
log_dir=args.log_dir,
checkpoint_dir=args.checkpoint_dir,
run_timestamp=args.run_timestamp,
)

View File

@ -1,6 +1,7 @@
""" """
DQN训练脚本 - SUMO VSL环境 DQN训练脚本 - SUMO VSL环境
""" """
import argparse
import os import os
import sys import sys
import copy import copy
@ -17,9 +18,10 @@ from agents.dqn_agent import DQNAgent
from utils.config import get_agent_config, get_training_config from utils.config import get_agent_config, get_training_config
from utils.logger import TrainingLogger from utils.logger import TrainingLogger
from utils.plot import plot_training_curves from utils.plot import plot_training_curves
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
def train_sumo_dqn(): def train_sumo_dqn(log_dir=None, checkpoint_dir=None, run_timestamp=None):
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f: with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
config = yaml.safe_load(f) config = yaml.safe_load(f)
@ -27,9 +29,12 @@ def train_sumo_dqn():
train_config = get_training_config(config) train_config = get_training_config(config)
start_episode = 1 start_episode = 1
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") _, checkpoint_dir, log_dir = resolve_run_dirs(
checkpoint_dir = os.path.join("checkpoints", "dqn", timestamp) "dqn",
log_dir = os.path.join("logs", "dqn", timestamp) log_dir=log_dir,
checkpoint_dir=checkpoint_dir,
run_timestamp=run_timestamp,
)
os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
runtime_config = copy.deepcopy(config) runtime_config = copy.deepcopy(config)
@ -203,4 +208,10 @@ def train_sumo_dqn():
if __name__ == "__main__": if __name__ == "__main__":
train_sumo_dqn() parser = add_run_dir_args(argparse.ArgumentParser())
args = parser.parse_args()
train_sumo_dqn(
log_dir=args.log_dir,
checkpoint_dir=args.checkpoint_dir,
run_timestamp=args.run_timestamp,
)

View File

@ -1,6 +1,7 @@
""" """
MAPPO training script for SUMO + TraCI VSL. MAPPO training script for SUMO + TraCI VSL.
""" """
import argparse
import os import os
import copy import copy
import yaml import yaml
@ -16,18 +17,22 @@ from agents.mappo_agent import MAPPOAgent
from utils.config import get_agent_config, get_training_config from utils.config import get_agent_config, get_training_config
from utils.logger import TrainingLogger from utils.logger import TrainingLogger
from utils.plot import plot_training_curves from utils.plot import plot_training_curves
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
def train_sumo_mappo(): def train_sumo_mappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f: with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
config = yaml.safe_load(f) config = yaml.safe_load(f)
agent_config = get_agent_config(config, "mappo") agent_config = get_agent_config(config, "mappo")
train_config = get_training_config(config) train_config = get_training_config(config)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") _, checkpoint_dir, log_dir = resolve_run_dirs(
checkpoint_dir = os.path.join("checkpoints", "mappo", timestamp) "mappo",
log_dir = os.path.join("logs", "mappo", timestamp) log_dir=log_dir,
checkpoint_dir=checkpoint_dir,
run_timestamp=run_timestamp,
)
os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
runtime_config = copy.deepcopy(config) runtime_config = copy.deepcopy(config)
@ -235,4 +240,10 @@ def train_sumo_mappo():
if __name__ == "__main__": if __name__ == "__main__":
train_sumo_mappo() parser = add_run_dir_args(argparse.ArgumentParser())
args = parser.parse_args()
train_sumo_mappo(
log_dir=args.log_dir,
checkpoint_dir=args.checkpoint_dir,
run_timestamp=args.run_timestamp,
)

View File

@ -2,6 +2,7 @@
基于 SUMO+TraCI PPO 训练脚本 基于 SUMO+TraCI PPO 训练脚本
使用微观仿真环境训练 VSL 控制策略 使用微观仿真环境训练 VSL 控制策略
""" """
import argparse
import os import os
import sys import sys
import copy import copy
@ -19,9 +20,10 @@ from agents.ppo_agent import PPOAgent
from utils.config import get_agent_config, get_training_config from utils.config import get_agent_config, get_training_config
from utils.logger import TrainingLogger from utils.logger import TrainingLogger
from utils.plot import plot_training_curves from utils.plot import plot_training_curves
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
def train_sumo_ppo(): def train_sumo_ppo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
"""SUMO 环境下的 PPO 训练主函数""" """SUMO 环境下的 PPO 训练主函数"""
# 加载配置 # 加载配置
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f: with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
@ -31,9 +33,12 @@ def train_sumo_ppo():
train_config = get_training_config(config) train_config = get_training_config(config)
start_episode = 1 start_episode = 1
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") _, checkpoint_dir, log_dir = resolve_run_dirs(
checkpoint_dir = os.path.join("checkpoints", "ppo", timestamp) "ppo",
log_dir = os.path.join("logs", "ppo", timestamp) log_dir=log_dir,
checkpoint_dir=checkpoint_dir,
run_timestamp=run_timestamp,
)
os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
runtime_config = copy.deepcopy(config) runtime_config = copy.deepcopy(config)
@ -249,4 +254,10 @@ def train_sumo_ppo():
if __name__ == "__main__": if __name__ == "__main__":
train_sumo_ppo() parser = add_run_dir_args(argparse.ArgumentParser())
args = parser.parse_args()
train_sumo_ppo(
log_dir=args.log_dir,
checkpoint_dir=args.checkpoint_dir,
run_timestamp=args.run_timestamp,
)

View File

@ -2,6 +2,7 @@
基于 SUMO+TraCI TD3 训练脚本 基于 SUMO+TraCI TD3 训练脚本
使用 Stable-Baselines3 TD3 算法 使用 Stable-Baselines3 TD3 算法
""" """
import argparse
import os import os
import copy import copy
import yaml import yaml
@ -17,9 +18,10 @@ from agents.td3_agent import TD3Agent
from utils.config import get_agent_config, get_training_config from utils.config import get_agent_config, get_training_config
from utils.logger import TrainingLogger from utils.logger import TrainingLogger
from utils.plot import plot_training_curves from utils.plot import plot_training_curves
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
def train_sumo_td3(): def train_sumo_td3(log_dir=None, checkpoint_dir=None, run_timestamp=None):
"""SUMO 环境下的 TD3 训练主函数""" """SUMO 环境下的 TD3 训练主函数"""
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f: with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
config = yaml.safe_load(f) config = yaml.safe_load(f)
@ -27,9 +29,12 @@ def train_sumo_td3():
agent_config = get_agent_config(config, "td3") agent_config = get_agent_config(config, "td3")
train_config = get_training_config(config) train_config = get_training_config(config)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") _, checkpoint_dir, log_dir = resolve_run_dirs(
checkpoint_dir = os.path.join("checkpoints", "td3", timestamp) "td3",
log_dir = os.path.join("logs", "td3", timestamp) log_dir=log_dir,
checkpoint_dir=checkpoint_dir,
run_timestamp=run_timestamp,
)
os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
runtime_config = copy.deepcopy(config) runtime_config = copy.deepcopy(config)
@ -193,4 +198,10 @@ def train_sumo_td3():
if __name__ == "__main__": if __name__ == "__main__":
train_sumo_td3() parser = add_run_dir_args(argparse.ArgumentParser())
args = parser.parse_args()
train_sumo_td3(
log_dir=args.log_dir,
checkpoint_dir=args.checkpoint_dir,
run_timestamp=args.run_timestamp,
)

28
utils/run_dirs.py Normal file
View File

@ -0,0 +1,28 @@
import argparse
import os
from datetime import datetime
from typing import Optional, Tuple
def add_run_dir_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register the shared output-location options on ``parser``.

    Adds ``--log-dir``, ``--checkpoint-dir`` and ``--run-timestamp``. Each one
    is an optional string option whose default is ``None``.

    Args:
        parser: The argument parser to extend; it is modified in place.

    Returns:
        The same ``parser`` object that was passed in, so calls can be chained.
    """
    # (flag, help text) pairs, registered in this order.
    option_help = (
        ("--log-dir", "Training log output directory."),
        ("--checkpoint-dir", "Model checkpoint output directory."),
        ("--run-timestamp", "Run timestamp tag for default directories."),
    )
    for flag, description in option_help:
        parser.add_argument(flag, type=str, default=None, help=description)
    return parser
def resolve_run_dirs(
    model_name: str,
    log_dir: Optional[str] = None,
    checkpoint_dir: Optional[str] = None,
    run_timestamp: Optional[str] = None,
) -> Tuple[str, str, str]:
    """Resolve output directories from runtime args, falling back to per-model defaults.

    Args:
        model_name: Name used as the sub-directory in the default paths.
        log_dir: Explicit log directory. When ``None`` or empty, defaults to
            ``logs/<model_name>/<timestamp>``.
        checkpoint_dir: Explicit checkpoint directory. When ``None`` or empty,
            defaults to ``checkpoints/<model_name>/<timestamp>``.
        run_timestamp: Timestamp tag to use. When ``None`` or empty, the
            current local time formatted as ``%Y%m%d_%H%M%S`` is used.

    Returns:
        A ``(timestamp, checkpoint_dir, log_dir)`` tuple. Note the order:
        the checkpoint directory comes before the log directory.
    """
    timestamp = run_timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")
    # Treat an empty string like "not provided", matching how run_timestamp is
    # handled above; an empty path would otherwise be returned verbatim and
    # make a later os.makedirs() call fail.
    if not checkpoint_dir:
        checkpoint_dir = os.path.join("checkpoints", model_name, timestamp)
    if not log_dir:
        log_dir = os.path.join("logs", model_name, timestamp)
    return timestamp, checkpoint_dir, log_dir