84 lines
2.7 KiB
Python
84 lines
2.7 KiB
Python
"""Launch all model training jobs plus a live overview plotter."""
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
from datetime import datetime
|
|
|
|
# Agent identifiers to train. Each name is launched as the module
# ``training.train_<agent>`` in its own subprocess by ``main``.
AGENTS = "ppo appo mappo dqn ddpg td3".split()

# Value forwarded as ``--interval`` to the live overview plotter
# (presumably seconds between refreshes — confirm in scripts.plot_live_training).
PLOT_REFRESH_INTERVAL = 30
|
|
|
|
|
|
def _clock():
    """Return the current local wall-clock time formatted as ``HH:MM:SS``."""
    return datetime.now().strftime("%H:%M:%S")


def _spawn_logged(cmd, stdout_path):
    """Start *cmd* with its stdout and stderr redirected to *stdout_path*.

    The log file is truncated on open. The parent's file object is closed as
    soon as ``Popen`` returns (or raises). The child works on its own
    inherited copy of the descriptor, so closing ours does not cut off its
    output, and the launcher no longer leaks one open file per job.

    Args:
        cmd: Argument list passed to :class:`subprocess.Popen` (no shell).
        stdout_path: Path of the combined stdout/stderr log file.

    Returns:
        The started :class:`subprocess.Popen` object.
    """
    with open(stdout_path, "w", encoding="utf-8") as log_file:
        return subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT)


def _stop_process(process, timeout=10):
    """Stop *process* if it is still running, escalating from terminate to kill.

    Sends ``terminate()`` and waits up to *timeout* seconds. If the process is
    still alive, it is ``kill()``ed and waited on again so it gets reaped
    (on POSIX it would otherwise linger as a zombie until we exit).
    Does nothing if the process has already exited.
    """
    if process.poll() is not None:
        return
    process.terminate()
    try:
        process.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        process.kill()
        process.wait()


def main():
    """Launch one training subprocess per agent in ``AGENTS`` plus a live plotter.

    All jobs share one run timestamp. Logs go under
    ``logs/multi-model/<timestamp>/<agent>`` and checkpoints under
    ``checkpoints/multi-model/<timestamp>/<agent>``. Each job's combined
    stdout/stderr is written to ``stdout.txt`` in its log directory. The
    plotter's output goes to ``plotter_stdout.txt`` in the run log root.

    Blocks until every training process has exited and prints each one's
    status. The plotter is always stopped afterwards, including when waiting
    is interrupted by an exception such as ``KeyboardInterrupt``.
    """
    run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_log_root = os.path.join("logs", "multi-model", run_timestamp)
    run_ckpt_root = os.path.join("checkpoints", "multi-model", run_timestamp)
    os.makedirs(run_log_root, exist_ok=True)
    os.makedirs(run_ckpt_root, exist_ok=True)

    processes = {}
    for agent in AGENTS:
        print(f"[{_clock()}] 启动 {agent.upper()} 训练...")
        log_dir = os.path.join(run_log_root, agent)
        checkpoint_dir = os.path.join(run_ckpt_root, agent)
        os.makedirs(log_dir, exist_ok=True)
        os.makedirs(checkpoint_dir, exist_ok=True)

        process = _spawn_logged(
            [
                sys.executable,
                "-m",
                f"training.train_{agent}",
                "--log-dir",
                log_dir,
                "--checkpoint-dir",
                checkpoint_dir,
                "--run-timestamp",
                run_timestamp,
            ],
            os.path.join(log_dir, "stdout.txt"),
        )
        processes[agent] = process
        print(f" PID: {process.pid}")

    plotter_process = _spawn_logged(
        [
            sys.executable,
            "-m",
            "scripts.plot_live_training",
            "--all-models",
            "--run",
            run_timestamp,
            "--watch",
            "--interval",
            str(PLOT_REFRESH_INTERVAL),
        ],
        os.path.join(run_log_root, "plotter_stdout.txt"),
    )
    print(f"[{_clock()}] 启动联训总览绘图...")
    print(f" PID: {plotter_process.pid}")

    print(f"\n全部 {len(AGENTS)} 个训练已启动,等待完成...")
    print(f"本次多模型时间戳: {run_timestamp}\n")

    try:
        # Jobs are awaited in launch order. A job that finishes early is
        # reported once the loop reaches it.
        for agent, process in processes.items():
            returncode = process.wait()
            status = "完成" if returncode == 0 else f"失败(code={returncode})"
            print(f"[{agent.upper()}] {status}")
    finally:
        # The plotter runs in --watch mode, so it has to be stopped explicitly.
        _stop_process(plotter_process, timeout=10)
        print("[PLOTTER] 已停止自动绘图进程")
|
|
|
|
|
|
# Script entry point: start the launcher only when this file is executed
# directly, never as a side effect of importing it.
if "__main__" == __name__:
    main()
|