修改同时训练时保存的目录

This commit is contained in:
Zihan Ye 2026-04-09 02:46:25 +08:00
parent 09c0dd86e5
commit 27502241ad
8 changed files with 166 additions and 46 deletions

View File

@ -1,24 +1,50 @@
"""一键异步启动全部训练"""
"""一键异步启动全部训练。"""
import os
import subprocess
import sys
from datetime import datetime
AGENTS = ["ppo", "appo", "mappo", "dqn", "ddpg", "td3"]
def main():
run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
processes = {}
for agent in AGENTS:
print(f"[{datetime.now().strftime('%H:%M:%S')}] 启动 {agent.upper()} 训练...")
p = subprocess.Popen(
[sys.executable, "-m", f"training.train_{agent}"],
stdout=open(f"logs_{agent}_stdout.txt", "w"),
log_dir = os.path.join("logs", "multi-model", run_timestamp, agent)
checkpoint_dir = os.path.join("checkpoints", "multi-model", run_timestamp, agent)
os.makedirs(log_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)
stdout_path = os.path.join(log_dir, "stdout.txt")
process = subprocess.Popen(
[
sys.executable,
"-m",
f"training.train_{agent}",
"--log-dir",
log_dir,
"--checkpoint-dir",
checkpoint_dir,
"--run-timestamp",
run_timestamp,
],
stdout=open(stdout_path, "w", encoding="utf-8"),
stderr=subprocess.STDOUT,
)
processes[agent] = p
print(f" PID: {p.pid}")
processes[agent] = process
print(f" PID: {process.pid}")
print(f"\n全部 {len(AGENTS)} 个训练已启动,等待完成...\n")
print(f"\n全部 {len(AGENTS)} 个训练已启动,等待完成...")
print(f"本次多模型时间戳: {run_timestamp}\n")
for agent, p in processes.items():
p.wait()
status = "✓ 完成" if p.returncode == 0 else f"✗ 失败(code={p.returncode})"
for agent, process in processes.items():
process.wait()
status = "完成" if process.returncode == 0 else f"失败(code={process.returncode})"
print(f"[{agent.upper()}] {status}")
if __name__ == "__main__":
main()

View File

@ -2,6 +2,7 @@
基于 SUMO+TraCI APPO 训练脚本
使用微观仿真环境训练 VSL 控制策略
"""
import argparse
import os
import sys
import copy
@ -19,9 +20,10 @@ from agents.appo_agent import APPOAgent
from utils.config import get_agent_config, get_training_config
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
def train_sumo_appo():
def train_sumo_appo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
"""SUMO 环境下的 APPO 训练主函数"""
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
config = yaml.safe_load(f)
@ -30,9 +32,12 @@ def train_sumo_appo():
train_config = get_training_config(config)
start_episode = 1
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_dir = os.path.join("checkpoints", "appo", timestamp)
log_dir = os.path.join("logs", "appo", timestamp)
_, checkpoint_dir, log_dir = resolve_run_dirs(
"appo",
log_dir=log_dir,
checkpoint_dir=checkpoint_dir,
run_timestamp=run_timestamp,
)
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)
runtime_config = copy.deepcopy(config)
@ -246,4 +251,10 @@ def train_sumo_appo():
if __name__ == "__main__":
train_sumo_appo()
parser = add_run_dir_args(argparse.ArgumentParser())
args = parser.parse_args()
train_sumo_appo(
log_dir=args.log_dir,
checkpoint_dir=args.checkpoint_dir,
run_timestamp=args.run_timestamp,
)

View File

@ -2,6 +2,7 @@
基于 SUMO+TraCI DDPG 训练脚本
使用 Stable-Baselines3 DDPG 算法
"""
import argparse
import os
import copy
import yaml
@ -16,9 +17,10 @@ from agents.ddpg_agent import DDPGAgent
from utils.config import get_agent_config, get_training_config
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
def train_sumo_ddpg():
def train_sumo_ddpg(log_dir=None, checkpoint_dir=None, run_timestamp=None):
"""SUMO 环境下的 DDPG 训练主函数"""
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
config = yaml.safe_load(f)
@ -26,9 +28,12 @@ def train_sumo_ddpg():
agent_config = get_agent_config(config, "ddpg")
train_config = get_training_config(config)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_dir = os.path.join("checkpoints", "ddpg", timestamp)
log_dir = os.path.join("logs", "ddpg", timestamp)
_, checkpoint_dir, log_dir = resolve_run_dirs(
"ddpg",
log_dir=log_dir,
checkpoint_dir=checkpoint_dir,
run_timestamp=run_timestamp,
)
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)
runtime_config = copy.deepcopy(config)
@ -183,4 +188,10 @@ def train_sumo_ddpg():
if __name__ == "__main__":
train_sumo_ddpg()
parser = add_run_dir_args(argparse.ArgumentParser())
args = parser.parse_args()
train_sumo_ddpg(
log_dir=args.log_dir,
checkpoint_dir=args.checkpoint_dir,
run_timestamp=args.run_timestamp,
)

View File

@ -1,6 +1,7 @@
"""
DQN训练脚本 - SUMO VSL环境
"""
import argparse
import os
import sys
import copy
@ -17,9 +18,10 @@ from agents.dqn_agent import DQNAgent
from utils.config import get_agent_config, get_training_config
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
def train_sumo_dqn():
def train_sumo_dqn(log_dir=None, checkpoint_dir=None, run_timestamp=None):
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
config = yaml.safe_load(f)
@ -27,9 +29,12 @@ def train_sumo_dqn():
train_config = get_training_config(config)
start_episode = 1
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_dir = os.path.join("checkpoints", "dqn", timestamp)
log_dir = os.path.join("logs", "dqn", timestamp)
_, checkpoint_dir, log_dir = resolve_run_dirs(
"dqn",
log_dir=log_dir,
checkpoint_dir=checkpoint_dir,
run_timestamp=run_timestamp,
)
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)
runtime_config = copy.deepcopy(config)
@ -203,4 +208,10 @@ def train_sumo_dqn():
if __name__ == "__main__":
train_sumo_dqn()
parser = add_run_dir_args(argparse.ArgumentParser())
args = parser.parse_args()
train_sumo_dqn(
log_dir=args.log_dir,
checkpoint_dir=args.checkpoint_dir,
run_timestamp=args.run_timestamp,
)

View File

@ -1,6 +1,7 @@
"""
MAPPO training script for SUMO + TraCI VSL.
"""
import argparse
import os
import copy
import yaml
@ -16,18 +17,22 @@ from agents.mappo_agent import MAPPOAgent
from utils.config import get_agent_config, get_training_config
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
def train_sumo_mappo():
def train_sumo_mappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
config = yaml.safe_load(f)
agent_config = get_agent_config(config, "mappo")
train_config = get_training_config(config)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_dir = os.path.join("checkpoints", "mappo", timestamp)
log_dir = os.path.join("logs", "mappo", timestamp)
_, checkpoint_dir, log_dir = resolve_run_dirs(
"mappo",
log_dir=log_dir,
checkpoint_dir=checkpoint_dir,
run_timestamp=run_timestamp,
)
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)
runtime_config = copy.deepcopy(config)
@ -235,4 +240,10 @@ def train_sumo_mappo():
if __name__ == "__main__":
train_sumo_mappo()
parser = add_run_dir_args(argparse.ArgumentParser())
args = parser.parse_args()
train_sumo_mappo(
log_dir=args.log_dir,
checkpoint_dir=args.checkpoint_dir,
run_timestamp=args.run_timestamp,
)

View File

@ -2,6 +2,7 @@
基于 SUMO+TraCI PPO 训练脚本
使用微观仿真环境训练 VSL 控制策略
"""
import argparse
import os
import sys
import copy
@ -19,9 +20,10 @@ from agents.ppo_agent import PPOAgent
from utils.config import get_agent_config, get_training_config
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
def train_sumo_ppo():
def train_sumo_ppo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
"""SUMO 环境下的 PPO 训练主函数"""
# 加载配置
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
@ -31,9 +33,12 @@ def train_sumo_ppo():
train_config = get_training_config(config)
start_episode = 1
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_dir = os.path.join("checkpoints", "ppo", timestamp)
log_dir = os.path.join("logs", "ppo", timestamp)
_, checkpoint_dir, log_dir = resolve_run_dirs(
"ppo",
log_dir=log_dir,
checkpoint_dir=checkpoint_dir,
run_timestamp=run_timestamp,
)
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)
runtime_config = copy.deepcopy(config)
@ -249,4 +254,10 @@ def train_sumo_ppo():
if __name__ == "__main__":
train_sumo_ppo()
parser = add_run_dir_args(argparse.ArgumentParser())
args = parser.parse_args()
train_sumo_ppo(
log_dir=args.log_dir,
checkpoint_dir=args.checkpoint_dir,
run_timestamp=args.run_timestamp,
)

View File

@ -2,6 +2,7 @@
基于 SUMO+TraCI TD3 训练脚本
使用 Stable-Baselines3 TD3 算法
"""
import argparse
import os
import copy
import yaml
@ -17,9 +18,10 @@ from agents.td3_agent import TD3Agent
from utils.config import get_agent_config, get_training_config
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
def train_sumo_td3():
def train_sumo_td3(log_dir=None, checkpoint_dir=None, run_timestamp=None):
"""SUMO 环境下的 TD3 训练主函数"""
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
config = yaml.safe_load(f)
@ -27,9 +29,12 @@ def train_sumo_td3():
agent_config = get_agent_config(config, "td3")
train_config = get_training_config(config)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_dir = os.path.join("checkpoints", "td3", timestamp)
log_dir = os.path.join("logs", "td3", timestamp)
_, checkpoint_dir, log_dir = resolve_run_dirs(
"td3",
log_dir=log_dir,
checkpoint_dir=checkpoint_dir,
run_timestamp=run_timestamp,
)
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)
runtime_config = copy.deepcopy(config)
@ -193,4 +198,10 @@ def train_sumo_td3():
if __name__ == "__main__":
train_sumo_td3()
parser = add_run_dir_args(argparse.ArgumentParser())
args = parser.parse_args()
train_sumo_td3(
log_dir=args.log_dir,
checkpoint_dir=args.checkpoint_dir,
run_timestamp=args.run_timestamp,
)

28
utils/run_dirs.py Normal file
View File

@ -0,0 +1,28 @@
import argparse
import os
from datetime import datetime
from typing import Optional, Tuple
def add_run_dir_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument("--log-dir", type=str, default=None, help="Training log output directory.")
parser.add_argument("--checkpoint-dir", type=str, default=None, help="Model checkpoint output directory.")
parser.add_argument("--run-timestamp", type=str, default=None, help="Run timestamp tag for default directories.")
return parser
def resolve_run_dirs(
model_name: str,
log_dir: Optional[str] = None,
checkpoint_dir: Optional[str] = None,
run_timestamp: Optional[str] = None,
) -> Tuple[str, str, str]:
"""Resolve output directories from runtime args, falling back to per-model defaults."""
timestamp = run_timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")
if checkpoint_dir is None:
checkpoint_dir = os.path.join("checkpoints", model_name, timestamp)
if log_dir is None:
log_dir = os.path.join("logs", model_name, timestamp)
return timestamp, checkpoint_dir, log_dir