156 lines
4.6 KiB
Python
156 lines
4.6 KiB
Python
"""Unified training launcher for one or more models."""
|
|
import argparse
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
from datetime import datetime
|
|
|
|
from training.registry import DEFAULT_MODELS, normalize_model_list
|
|
|
|
|
|
# Default live-plot refresh interval, in seconds. Used as the ``--interval``
# default and forwarded to the plotter subprocess as ``--interval``.
PLOT_REFRESH_INTERVAL = 30
|
|
|
|
|
|
def _positive_int(value):
    """Parse *value* as a strictly positive integer (argparse ``type=`` hook).

    Raises:
        argparse.ArgumentTypeError: If *value* is not an integer or is < 1.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def parse_args(argv=None):
    """Parse the launcher's command-line options.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) parses ``sys.argv[1:]``, so calling with no argument
            behaves as before. Passing a list lets tests and other scripts
            drive the parser directly.

    Returns:
        argparse.Namespace with:
          * ``models``: list of names given after ``--models`` (possibly
            empty), or ``None`` when the flag is absent.
          * ``run_timestamp``: str, or ``None`` when not given.
          * ``interval``: positive int, in seconds.
          * ``no_plotter``: bool.
    """
    parser = argparse.ArgumentParser(
        description="Launch unified training for one or more models.",
    )
    parser.add_argument(
        "--models",
        nargs="*",
        default=None,
        help=(
            "Model names to train. "
            f"Defaults to: {', '.join(DEFAULT_MODELS)}"
        ),
    )
    parser.add_argument(
        "--run-timestamp",
        type=str,
        default=None,
        help="Run timestamp tag. Defaults to the current time.",
    )
    parser.add_argument(
        "--interval",
        # Reject 0 / negative values up front instead of forwarding them to
        # the plotter's --interval option.
        type=_positive_int,
        default=PLOT_REFRESH_INTERVAL,
        help="Live plot refresh interval in seconds.",
    )
    parser.add_argument(
        "--no-plotter",
        action="store_true",
        help="Disable the live plotter process.",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
def launch_model_process(model_name, run_timestamp, run_log_root, run_ckpt_root):
    """Start one ``training.run_model`` subprocess for *model_name*.

    Creates ``<run_log_root>/<model_name>`` and ``<run_ckpt_root>/<model_name>``
    if missing. The child's stdout and stderr are both redirected into
    ``stdout.txt`` inside the per-model log directory.

    Args:
        model_name: Model name, forwarded as ``--model``.
        run_timestamp: Run tag, forwarded as ``--run-timestamp``.
        run_log_root: Parent directory for per-model log directories.
        run_ckpt_root: Parent directory for per-model checkpoint directories.

    Returns:
        ``(process, stdout_handle)``. The caller owns ``stdout_handle`` and
        must close it.

    Raises:
        OSError: If a directory or the log file cannot be created, or the
            child cannot be spawned. In the spawn case the log file is closed
            before the exception propagates.
    """
    log_dir = os.path.join(run_log_root, model_name)
    checkpoint_dir = os.path.join(run_ckpt_root, model_name)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    stdout_path = os.path.join(log_dir, "stdout.txt")
    stdout_handle = open(stdout_path, "w", encoding="utf-8")
    try:
        process = subprocess.Popen(
            [
                sys.executable,
                "-m",
                "training.run_model",
                "--model",
                model_name,
                "--log-dir",
                log_dir,
                "--checkpoint-dir",
                checkpoint_dir,
                "--run-timestamp",
                run_timestamp,
            ],
            stdout=stdout_handle,
            stderr=subprocess.STDOUT,
        )
    except BaseException:
        # The caller never receives the handle if Popen fails, so close it
        # here rather than leaking the file descriptor.
        stdout_handle.close()
        raise
    return process, stdout_handle
|
|
|
|
|
|
def launch_plotter(run_timestamp, run_log_root, interval):
    """Start the live-plot subprocess (``scripts.plot_live_training``).

    The plotter is invoked as::

        --all-models --run <run_timestamp> --watch --interval <interval>

    Its stdout and stderr both go to ``<run_log_root>/plotter_stdout.txt``.

    Args:
        run_timestamp: Run tag, passed as ``--run``.
        run_log_root: Directory (must already exist) that receives the
            plotter log file.
        interval: Refresh period in seconds, passed as ``--interval``.

    Returns:
        ``(process, stdout_handle)``. The caller owns ``stdout_handle`` and
        must close it.

    Raises:
        OSError: If the log file cannot be opened or the child cannot be
            spawned. In the spawn case the log file is closed before the
            exception propagates.
    """
    stdout_path = os.path.join(run_log_root, "plotter_stdout.txt")
    stdout_handle = open(stdout_path, "w", encoding="utf-8")
    try:
        process = subprocess.Popen(
            [
                sys.executable,
                "-m",
                "scripts.plot_live_training",
                "--all-models",
                "--run",
                run_timestamp,
                "--watch",
                "--interval",
                str(interval),
            ],
            stdout=stdout_handle,
            stderr=subprocess.STDOUT,
        )
    except BaseException:
        # The caller never receives the handle if Popen fails, so close it
        # here rather than leaking the file descriptor.
        stdout_handle.close()
        raise
    return process, stdout_handle
|
|
|
|
|
|
def _stop_process(process, timeout=10):
    """Terminate *process* if it is still running, escalating to kill().

    Does nothing for ``None`` or an already-exited process. Otherwise it
    sends terminate() and waits up to *timeout* seconds. If the process is
    still alive, it sends kill() and reaps the child.
    """
    if process is None or process.poll() is not None:
        return
    process.terminate()
    try:
        process.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        process.kill()
        process.wait()  # reap the killed child so it does not linger


def main():
    """Launch the requested training jobs and, optionally, the live plotter.

    Each model runs in its own subprocess (see ``launch_model_process``) under:

    * ``logs/multi-model/<run_timestamp>``
    * ``checkpoints/multi-model/<run_timestamp>``

    The launcher waits for the jobs in launch order, printing each one's
    outcome, then stops the plotter and closes all log handles.

    Returns:
        0 if every training job exited with code 0, otherwise 1.
    """
    args = parse_args()
    models = normalize_model_list(args.models)
    run_timestamp = args.run_timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")

    run_log_root = os.path.join("logs", "multi-model", run_timestamp)
    run_ckpt_root = os.path.join("checkpoints", "multi-model", run_timestamp)
    os.makedirs(run_log_root, exist_ok=True)
    os.makedirs(run_ckpt_root, exist_ok=True)

    processes = {}
    handles = []
    # Bound before the try so the finally clause can always inspect it.
    # This replaces the fragile ``"plotter_process" in locals()`` check.
    plotter_process = None
    all_succeeded = True

    try:
        for model_name in models:
            print(f"[{datetime.now().strftime('%H:%M:%S')}] Launching {model_name.upper()} training...")
            process, stdout_handle = launch_model_process(
                model_name,
                run_timestamp,
                run_log_root,
                run_ckpt_root,
            )
            processes[model_name] = process
            handles.append(stdout_handle)
            print(f" PID: {process.pid}")

        if not args.no_plotter:
            print(f"[{datetime.now().strftime('%H:%M:%S')}] Launching live plotter...")
            plotter_process, plotter_stdout = launch_plotter(
                run_timestamp,
                run_log_root,
                args.interval,
            )
            handles.append(plotter_stdout)
            print(f" PID: {plotter_process.pid}")

        print()
        print(f"Started {len(models)} training job(s).")
        print(f"Run timestamp: {run_timestamp}")
        print(f"Models: {', '.join(models)}")
        print()

        for model_name, process in processes.items():
            process.wait()
            if process.returncode == 0:
                status = "completed"
            else:
                status = f"failed (code={process.returncode})"
                all_succeeded = False
            print(f"[{model_name.upper()}] {status}")
    finally:
        # NOTE: training jobs are deliberately not terminated here. Only the
        # plotter, which watches forever, is stopped.
        _stop_process(plotter_process)
        for handle in handles:
            handle.close()

    return 0 if all_succeeded else 1
|
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status (None maps to 0),
    # so shell scripts and CI can detect failed training jobs.
    sys.exit(main())
|