修改同时训练时保存的目录
This commit is contained in:
parent
09c0dd86e5
commit
27502241ad
|
|
@ -1,24 +1,50 @@
|
|||
"""一键异步启动全部训练"""
|
||||
"""一键异步启动全部训练。"""
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
AGENTS = ["ppo", "appo", "mappo", "dqn", "ddpg", "td3"]
|
||||
|
||||
|
||||
def main():
|
||||
run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
processes = {}
|
||||
|
||||
for agent in AGENTS:
|
||||
print(f"[{datetime.now().strftime('%H:%M:%S')}] 启动 {agent.upper()} 训练...")
|
||||
p = subprocess.Popen(
|
||||
[sys.executable, "-m", f"training.train_{agent}"],
|
||||
stdout=open(f"logs_{agent}_stdout.txt", "w"),
|
||||
log_dir = os.path.join("logs", "multi-model", run_timestamp, agent)
|
||||
checkpoint_dir = os.path.join("checkpoints", "multi-model", run_timestamp, agent)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
stdout_path = os.path.join(log_dir, "stdout.txt")
|
||||
|
||||
process = subprocess.Popen(
|
||||
[
|
||||
sys.executable,
|
||||
"-m",
|
||||
f"training.train_{agent}",
|
||||
"--log-dir",
|
||||
log_dir,
|
||||
"--checkpoint-dir",
|
||||
checkpoint_dir,
|
||||
"--run-timestamp",
|
||||
run_timestamp,
|
||||
],
|
||||
stdout=open(stdout_path, "w", encoding="utf-8"),
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
processes[agent] = p
|
||||
print(f" PID: {p.pid}")
|
||||
processes[agent] = process
|
||||
print(f" PID: {process.pid}")
|
||||
|
||||
print(f"\n全部 {len(AGENTS)} 个训练已启动,等待完成...\n")
|
||||
print(f"\n全部 {len(AGENTS)} 个训练已启动,等待完成...")
|
||||
print(f"本次多模型时间戳: {run_timestamp}\n")
|
||||
|
||||
for agent, p in processes.items():
|
||||
p.wait()
|
||||
status = "✓ 完成" if p.returncode == 0 else f"✗ 失败(code={p.returncode})"
|
||||
for agent, process in processes.items():
|
||||
process.wait()
|
||||
status = "完成" if process.returncode == 0 else f"失败(code={process.returncode})"
|
||||
print(f"[{agent.upper()}] {status}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
基于 SUMO+TraCI 的 APPO 训练脚本
|
||||
使用微观仿真环境训练 VSL 控制策略
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import copy
|
||||
|
|
@ -19,9 +20,10 @@ from agents.appo_agent import APPOAgent
|
|||
from utils.config import get_agent_config, get_training_config
|
||||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
|
||||
|
||||
|
||||
def train_sumo_appo():
|
||||
def train_sumo_appo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||
"""SUMO 环境下的 APPO 训练主函数"""
|
||||
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
|
@ -30,9 +32,12 @@ def train_sumo_appo():
|
|||
train_config = get_training_config(config)
|
||||
|
||||
start_episode = 1
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
checkpoint_dir = os.path.join("checkpoints", "appo", timestamp)
|
||||
log_dir = os.path.join("logs", "appo", timestamp)
|
||||
_, checkpoint_dir, log_dir = resolve_run_dirs(
|
||||
"appo",
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
run_timestamp=run_timestamp,
|
||||
)
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
runtime_config = copy.deepcopy(config)
|
||||
|
|
@ -246,4 +251,10 @@ def train_sumo_appo():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_sumo_appo()
|
||||
parser = add_run_dir_args(argparse.ArgumentParser())
|
||||
args = parser.parse_args()
|
||||
train_sumo_appo(
|
||||
log_dir=args.log_dir,
|
||||
checkpoint_dir=args.checkpoint_dir,
|
||||
run_timestamp=args.run_timestamp,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
基于 SUMO+TraCI 的 DDPG 训练脚本
|
||||
使用 Stable-Baselines3 的 DDPG 算法
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import copy
|
||||
import yaml
|
||||
|
|
@ -16,9 +17,10 @@ from agents.ddpg_agent import DDPGAgent
|
|||
from utils.config import get_agent_config, get_training_config
|
||||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
|
||||
|
||||
|
||||
def train_sumo_ddpg():
|
||||
def train_sumo_ddpg(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||
"""SUMO 环境下的 DDPG 训练主函数"""
|
||||
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
|
@ -26,9 +28,12 @@ def train_sumo_ddpg():
|
|||
agent_config = get_agent_config(config, "ddpg")
|
||||
train_config = get_training_config(config)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
checkpoint_dir = os.path.join("checkpoints", "ddpg", timestamp)
|
||||
log_dir = os.path.join("logs", "ddpg", timestamp)
|
||||
_, checkpoint_dir, log_dir = resolve_run_dirs(
|
||||
"ddpg",
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
run_timestamp=run_timestamp,
|
||||
)
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
runtime_config = copy.deepcopy(config)
|
||||
|
|
@ -183,4 +188,10 @@ def train_sumo_ddpg():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_sumo_ddpg()
|
||||
parser = add_run_dir_args(argparse.ArgumentParser())
|
||||
args = parser.parse_args()
|
||||
train_sumo_ddpg(
|
||||
log_dir=args.log_dir,
|
||||
checkpoint_dir=args.checkpoint_dir,
|
||||
run_timestamp=args.run_timestamp,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
DQN训练脚本 - SUMO VSL环境
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import copy
|
||||
|
|
@ -17,9 +18,10 @@ from agents.dqn_agent import DQNAgent
|
|||
from utils.config import get_agent_config, get_training_config
|
||||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
|
||||
|
||||
|
||||
def train_sumo_dqn():
|
||||
def train_sumo_dqn(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
|
|
@ -27,9 +29,12 @@ def train_sumo_dqn():
|
|||
train_config = get_training_config(config)
|
||||
|
||||
start_episode = 1
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
checkpoint_dir = os.path.join("checkpoints", "dqn", timestamp)
|
||||
log_dir = os.path.join("logs", "dqn", timestamp)
|
||||
_, checkpoint_dir, log_dir = resolve_run_dirs(
|
||||
"dqn",
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
run_timestamp=run_timestamp,
|
||||
)
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
runtime_config = copy.deepcopy(config)
|
||||
|
|
@ -203,4 +208,10 @@ def train_sumo_dqn():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_sumo_dqn()
|
||||
parser = add_run_dir_args(argparse.ArgumentParser())
|
||||
args = parser.parse_args()
|
||||
train_sumo_dqn(
|
||||
log_dir=args.log_dir,
|
||||
checkpoint_dir=args.checkpoint_dir,
|
||||
run_timestamp=args.run_timestamp,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
MAPPO training script for SUMO + TraCI VSL.
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import copy
|
||||
import yaml
|
||||
|
|
@ -16,18 +17,22 @@ from agents.mappo_agent import MAPPOAgent
|
|||
from utils.config import get_agent_config, get_training_config
|
||||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
|
||||
|
||||
|
||||
def train_sumo_mappo():
|
||||
def train_sumo_mappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
agent_config = get_agent_config(config, "mappo")
|
||||
train_config = get_training_config(config)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
checkpoint_dir = os.path.join("checkpoints", "mappo", timestamp)
|
||||
log_dir = os.path.join("logs", "mappo", timestamp)
|
||||
_, checkpoint_dir, log_dir = resolve_run_dirs(
|
||||
"mappo",
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
run_timestamp=run_timestamp,
|
||||
)
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
runtime_config = copy.deepcopy(config)
|
||||
|
|
@ -235,4 +240,10 @@ def train_sumo_mappo():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_sumo_mappo()
|
||||
parser = add_run_dir_args(argparse.ArgumentParser())
|
||||
args = parser.parse_args()
|
||||
train_sumo_mappo(
|
||||
log_dir=args.log_dir,
|
||||
checkpoint_dir=args.checkpoint_dir,
|
||||
run_timestamp=args.run_timestamp,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
基于 SUMO+TraCI 的 PPO 训练脚本
|
||||
使用微观仿真环境训练 VSL 控制策略
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import copy
|
||||
|
|
@ -19,9 +20,10 @@ from agents.ppo_agent import PPOAgent
|
|||
from utils.config import get_agent_config, get_training_config
|
||||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
|
||||
|
||||
|
||||
def train_sumo_ppo():
|
||||
def train_sumo_ppo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||
"""SUMO 环境下的 PPO 训练主函数"""
|
||||
# 加载配置
|
||||
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
|
||||
|
|
@ -31,9 +33,12 @@ def train_sumo_ppo():
|
|||
train_config = get_training_config(config)
|
||||
|
||||
start_episode = 1
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
checkpoint_dir = os.path.join("checkpoints", "ppo", timestamp)
|
||||
log_dir = os.path.join("logs", "ppo", timestamp)
|
||||
_, checkpoint_dir, log_dir = resolve_run_dirs(
|
||||
"ppo",
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
run_timestamp=run_timestamp,
|
||||
)
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
runtime_config = copy.deepcopy(config)
|
||||
|
|
@ -249,4 +254,10 @@ def train_sumo_ppo():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_sumo_ppo()
|
||||
parser = add_run_dir_args(argparse.ArgumentParser())
|
||||
args = parser.parse_args()
|
||||
train_sumo_ppo(
|
||||
log_dir=args.log_dir,
|
||||
checkpoint_dir=args.checkpoint_dir,
|
||||
run_timestamp=args.run_timestamp,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
基于 SUMO+TraCI 的 TD3 训练脚本
|
||||
使用 Stable-Baselines3 的 TD3 算法
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import copy
|
||||
import yaml
|
||||
|
|
@ -17,9 +18,10 @@ from agents.td3_agent import TD3Agent
|
|||
from utils.config import get_agent_config, get_training_config
|
||||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
|
||||
|
||||
|
||||
def train_sumo_td3():
|
||||
def train_sumo_td3(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||
"""SUMO 环境下的 TD3 训练主函数"""
|
||||
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
|
@ -27,9 +29,12 @@ def train_sumo_td3():
|
|||
agent_config = get_agent_config(config, "td3")
|
||||
train_config = get_training_config(config)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
checkpoint_dir = os.path.join("checkpoints", "td3", timestamp)
|
||||
log_dir = os.path.join("logs", "td3", timestamp)
|
||||
_, checkpoint_dir, log_dir = resolve_run_dirs(
|
||||
"td3",
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
run_timestamp=run_timestamp,
|
||||
)
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
runtime_config = copy.deepcopy(config)
|
||||
|
|
@ -193,4 +198,10 @@ def train_sumo_td3():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_sumo_td3()
|
||||
parser = add_run_dir_args(argparse.ArgumentParser())
|
||||
args = parser.parse_args()
|
||||
train_sumo_td3(
|
||||
log_dir=args.log_dir,
|
||||
checkpoint_dir=args.checkpoint_dir,
|
||||
run_timestamp=args.run_timestamp,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,28 @@
|
|||
import argparse
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
def add_run_dir_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
parser.add_argument("--log-dir", type=str, default=None, help="Training log output directory.")
|
||||
parser.add_argument("--checkpoint-dir", type=str, default=None, help="Model checkpoint output directory.")
|
||||
parser.add_argument("--run-timestamp", type=str, default=None, help="Run timestamp tag for default directories.")
|
||||
return parser
|
||||
|
||||
|
||||
def resolve_run_dirs(
|
||||
model_name: str,
|
||||
log_dir: Optional[str] = None,
|
||||
checkpoint_dir: Optional[str] = None,
|
||||
run_timestamp: Optional[str] = None,
|
||||
) -> Tuple[str, str, str]:
|
||||
"""Resolve output directories from runtime args, falling back to per-model defaults."""
|
||||
timestamp = run_timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
if checkpoint_dir is None:
|
||||
checkpoint_dir = os.path.join("checkpoints", model_name, timestamp)
|
||||
if log_dir is None:
|
||||
log_dir = os.path.join("logs", model_name, timestamp)
|
||||
|
||||
return timestamp, checkpoint_dir, log_dir
|
||||
Loading…
Reference in New Issue