# ctm-dqn/run_all_training.py
#
# 156 lines
# 4.6 KiB
# Python

"""Unified training launcher for one or more models."""
import argparse
import os
import subprocess
import sys
from datetime import datetime
from training.registry import DEFAULT_MODELS, normalize_model_list
# Default number of seconds between live-plot refreshes; used as the
# default for the --interval command-line option.
PLOT_REFRESH_INTERVAL = 30
def parse_args(argv=None):
    """Parse command-line options for the multi-model training launcher.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``models``, ``run_timestamp``, ``interval``
        and ``no_plotter`` attributes.

    Exits with a usage error (via ``parser.error``) when the live plotter is
    enabled and ``--interval`` is not a positive number of seconds.
    """
    parser = argparse.ArgumentParser(
        description="Launch unified training for one or more models.",
    )
    parser.add_argument(
        "--models",
        nargs="*",
        default=None,
        help=(
            "Model names to train. "
            f"Defaults to: {', '.join(DEFAULT_MODELS)}"
        ),
    )
    parser.add_argument(
        "--run-timestamp",
        type=str,
        default=None,
        help="Run timestamp tag. Defaults to the current time.",
    )
    parser.add_argument(
        "--interval",
        type=int,
        default=PLOT_REFRESH_INTERVAL,
        help="Live plot refresh interval in seconds.",
    )
    parser.add_argument(
        "--no-plotter",
        action="store_true",
        help="Disable the live plotter process.",
    )
    args = parser.parse_args(argv)
    # The interval is forwarded to the plotter's --watch mode; a zero or
    # negative refresh period is meaningless there, so reject it up front
    # instead of launching a plotter with a nonsensical setting.
    if not args.no_plotter and args.interval <= 0:
        parser.error("--interval must be a positive number of seconds")
    return args
def launch_model_process(model_name, run_timestamp, run_log_root, run_ckpt_root):
    """Start a ``training.run_model`` subprocess for a single model.

    Creates ``<run_log_root>/<model_name>`` and ``<run_ckpt_root>/<model_name>``
    when missing, and sends the child's stdout and stderr to ``stdout.txt``
    inside the per-model log directory.

    Args:
        model_name: Value passed to ``--model``; also the per-model
            sub-directory name under both roots.
        run_timestamp: Run tag forwarded as ``--run-timestamp``.
        run_log_root: Parent directory for per-model log directories.
        run_ckpt_root: Parent directory for per-model checkpoint directories.

    Returns:
        ``(process, stdout_handle)``: the ``subprocess.Popen`` object and the
        open log file it writes to. The caller owns the handle and must close it.

    Raises:
        Propagates any error from creating the directories, opening the log
        file, or ``subprocess.Popen``. If ``Popen`` fails, the log file is
        closed before the exception is re-raised.
    """
    log_dir = os.path.join(run_log_root, model_name)
    checkpoint_dir = os.path.join(run_ckpt_root, model_name)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    command = [
        sys.executable,
        "-m",
        "training.run_model",
        "--model",
        model_name,
        "--log-dir",
        log_dir,
        "--checkpoint-dir",
        checkpoint_dir,
        "--run-timestamp",
        run_timestamp,
    ]
    stdout_path = os.path.join(log_dir, "stdout.txt")
    stdout_handle = open(stdout_path, "w", encoding="utf-8")
    try:
        process = subprocess.Popen(
            command,
            stdout=stdout_handle,
            stderr=subprocess.STDOUT,
        )
    except BaseException:
        # The caller never receives the handle when Popen fails, so release
        # it here rather than leaking the file descriptor.
        stdout_handle.close()
        raise
    return process, stdout_handle
def launch_plotter(run_timestamp, run_log_root, interval):
    """Start the live-plot subprocess (``scripts.plot_live_training``) for a run.

    The plotter is launched with ``--all-models --run <run_timestamp> --watch
    --interval <interval>``, and its stdout and stderr go to
    ``<run_log_root>/plotter_stdout.txt``. ``run_log_root`` is created if it
    does not exist yet.

    Args:
        run_timestamp: Run tag passed as ``--run``.
        run_log_root: Directory that receives ``plotter_stdout.txt``.
        interval: Refresh interval in seconds, passed as ``--interval``.

    Returns:
        ``(process, stdout_handle)``: the ``subprocess.Popen`` object and the
        open log file. The caller owns the handle and must close it.

    Raises:
        Propagates any error from creating the directory, opening the log
        file, or ``subprocess.Popen``. If ``Popen`` fails, the log file is
        closed before the exception is re-raised.
    """
    # Match launch_model_process, which creates its own output directories.
    os.makedirs(run_log_root, exist_ok=True)
    command = [
        sys.executable,
        "-m",
        "scripts.plot_live_training",
        "--all-models",
        "--run",
        run_timestamp,
        "--watch",
        "--interval",
        str(interval),
    ]
    stdout_path = os.path.join(run_log_root, "plotter_stdout.txt")
    stdout_handle = open(stdout_path, "w", encoding="utf-8")
    try:
        process = subprocess.Popen(
            command,
            stdout=stdout_handle,
            stderr=subprocess.STDOUT,
        )
    except BaseException:
        # The caller never receives the handle when Popen fails, so release
        # it here rather than leaking the file descriptor.
        stdout_handle.close()
        raise
    return process, stdout_handle
def _now_hms():
    """Return the current local wall-clock time as ``HH:MM:SS`` for log lines."""
    return datetime.now().strftime("%H:%M:%S")


def _stop_process(process, timeout=10):
    """Terminate *process* if still running; kill it if it outlives *timeout* seconds."""
    if process.poll() is not None:
        return
    process.terminate()
    try:
        process.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        process.kill()
        # Reap the killed child so it does not linger as a zombie.
        process.wait()


def main():
    """Launch one training subprocess per model, plus an optional live plotter.

    Logs are written under ``logs/multi-model/<run_timestamp>/`` and
    checkpoints under ``checkpoints/multi-model/<run_timestamp>/``, both
    relative to the current working directory. The function blocks until
    every training process exits and prints a completed/failed line for each.
    It then stops the plotter, if one was started, and closes every
    log-file handle.

    Returns:
        0 if every training process exited with return code 0, otherwise 1.
    """
    args = parse_args()
    models = normalize_model_list(args.models)
    run_timestamp = args.run_timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")
    run_log_root = os.path.join("logs", "multi-model", run_timestamp)
    run_ckpt_root = os.path.join("checkpoints", "multi-model", run_timestamp)
    os.makedirs(run_log_root, exist_ok=True)
    os.makedirs(run_ckpt_root, exist_ok=True)

    processes = {}
    handles = []
    # Bound before the try so cleanup can test it directly instead of probing
    # locals(). It stays None when the plotter is disabled or never started.
    plotter_process = None
    any_failed = False
    try:
        for model_name in models:
            print(f"[{_now_hms()}] Launching {model_name.upper()} training...")
            process, stdout_handle = launch_model_process(
                model_name,
                run_timestamp,
                run_log_root,
                run_ckpt_root,
            )
            processes[model_name] = process
            handles.append(stdout_handle)
            print(f" PID: {process.pid}")

        if not args.no_plotter:
            print(f"[{_now_hms()}] Launching live plotter...")
            plotter_process, plotter_stdout = launch_plotter(
                run_timestamp,
                run_log_root,
                args.interval,
            )
            handles.append(plotter_stdout)
            print(f" PID: {plotter_process.pid}")

        print()
        print(f"Started {len(models)} training job(s).")
        print(f"Run timestamp: {run_timestamp}")
        print(f"Models: {', '.join(models)}")
        print()

        # Jobs are awaited in launch order; each status line is printed once
        # that particular job has exited.
        for model_name, process in processes.items():
            returncode = process.wait()
            if returncode == 0:
                status = "completed"
            else:
                status = f"failed (code={returncode})"
                any_failed = True
            print(f"[{model_name.upper()}] {status}")
    finally:
        # The plotter is started with --watch, so presumably it never exits on
        # its own; stop it explicitly once training is over (or on error).
        # NOTE(review): training processes are NOT stopped here if main exits
        # early (e.g. a later launch raises), so already-started jobs keep
        # running. Confirm this is intended.
        if plotter_process is not None:
            _stop_process(plotter_process)
        for handle in handles:
            handle.close()
    return 1 if any_failed else 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status so shells and CI
    # see training failures; sys.exit(None) exits with status 0.
    sys.exit(main())