ctm-dqn/run_all_training.py

84 lines
2.7 KiB
Python

"""Launch all model training jobs plus a live overview plotter."""
import os
import subprocess
import sys
from datetime import datetime
# Agent identifiers; main() launches each one as the module ``training.train_<agent>``.
AGENTS = [
    "ppo",
    "appo",
    "mappo",
    "dqn",
    "ddpg",
    "td3",
]
# Forwarded to the live plotter as ``--interval`` (presumably seconds between
# refreshes — confirm against scripts.plot_live_training).
PLOT_REFRESH_INTERVAL = 30
def _timestamp() -> str:
    """Return the current local wall-clock time as ``HH:MM:SS`` for console prefixes."""
    return datetime.now().strftime("%H:%M:%S")


def _launch(cmd, stdout_path: str) -> subprocess.Popen:
    """Start *cmd* with both stdout and stderr redirected into *stdout_path*.

    Args:
        cmd: Argument list passed to ``subprocess.Popen`` (no shell involved).
        stdout_path: File that is created/truncated to receive the child's output.

    Returns:
        The started ``subprocess.Popen`` object.
    """
    # The child receives its own duplicate of the descriptor when Popen spawns
    # it, so the parent can (and should) close its handle right away; otherwise
    # every launched job leaks one open file object in this process.
    with open(stdout_path, "w", encoding="utf-8") as stdout_file:
        return subprocess.Popen(cmd, stdout=stdout_file, stderr=subprocess.STDOUT)


def main() -> int:
    """Launch one training subprocess per entry in ``AGENTS`` plus a live plotter.

    All jobs share a run timestamp.
    - Logs go under ``logs/multi-model/<timestamp>/<agent>``.
    - Checkpoints go under ``checkpoints/multi-model/<timestamp>/<agent>``.
    - Each job's combined stdout/stderr goes to ``stdout.txt`` in its log directory.

    The plotter (``scripts.plot_live_training``) writes its output to
    ``plotter_stdout.txt`` in the run log root. The function waits for every
    training job, then stops the plotter. The plotter is stopped even if
    waiting is interrupted.

    Returns:
        0 if every training process exited with return code 0, otherwise 1.
    """
    run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_log_root = os.path.join("logs", "multi-model", run_timestamp)
    run_ckpt_root = os.path.join("checkpoints", "multi-model", run_timestamp)
    os.makedirs(run_log_root, exist_ok=True)
    os.makedirs(run_ckpt_root, exist_ok=True)

    processes = {}
    for agent in AGENTS:
        print(f"[{_timestamp()}] 启动 {agent.upper()} 训练...")
        log_dir = os.path.join(run_log_root, agent)
        checkpoint_dir = os.path.join(run_ckpt_root, agent)
        os.makedirs(log_dir, exist_ok=True)
        os.makedirs(checkpoint_dir, exist_ok=True)
        process = _launch(
            [
                sys.executable,
                "-m",
                f"training.train_{agent}",
                "--log-dir",
                log_dir,
                "--checkpoint-dir",
                checkpoint_dir,
                "--run-timestamp",
                run_timestamp,
            ],
            os.path.join(log_dir, "stdout.txt"),
        )
        processes[agent] = process
        print(f" PID: {process.pid}")

    # Announce before launching, matching the per-agent messages above.
    print(f"[{_timestamp()}] 启动联训总览绘图...")
    plotter_process = _launch(
        [
            sys.executable,
            "-m",
            "scripts.plot_live_training",
            "--all-models",
            "--run",
            run_timestamp,
            "--watch",
            "--interval",
            str(PLOT_REFRESH_INTERVAL),
        ],
        os.path.join(run_log_root, "plotter_stdout.txt"),
    )
    print(f" PID: {plotter_process.pid}")
    print(f"\n全部 {len(AGENTS)} 个训练已启动,等待完成...")
    print(f"本次多模型时间戳: {run_timestamp}\n")

    any_failed = False
    try:
        # Waits in AGENTS order; a slow early job delays reporting of later ones.
        for agent, process in processes.items():
            returncode = process.wait()
            if returncode != 0:
                any_failed = True
            status = "完成" if returncode == 0 else f"失败(code={returncode})"
            print(f"[{agent.upper()}] {status}")
    finally:
        # The plotter runs in --watch mode, so stop it once training is over
        # (or waiting was interrupted): SIGTERM first, SIGKILL after 10 s.
        if plotter_process.poll() is None:
            plotter_process.terminate()
            try:
                plotter_process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                plotter_process.kill()
                plotter_process.wait()  # reap the killed child
            print("[PLOTTER] 已停止自动绘图进程")

    return 1 if any_failed else 0


if __name__ == "__main__":
    # Propagate training failures to the shell / CI as a non-zero exit status.
    sys.exit(main())