"""Unified training launcher for one or more models.""" import argparse import os import subprocess import sys from datetime import datetime from training.registry import DEFAULT_MODELS, normalize_model_list PLOT_REFRESH_INTERVAL = 30 def parse_args(): parser = argparse.ArgumentParser( description="Launch unified training for one or more models.", ) parser.add_argument( "--models", nargs="*", default=None, help=( "Model names to train. " f"Defaults to: {', '.join(DEFAULT_MODELS)}" ), ) parser.add_argument( "--run-timestamp", type=str, default=None, help="Run timestamp tag. Defaults to the current time.", ) parser.add_argument( "--interval", type=int, default=PLOT_REFRESH_INTERVAL, help="Live plot refresh interval in seconds.", ) parser.add_argument( "--no-plotter", action="store_true", help="Disable the live plotter process.", ) return parser.parse_args() def launch_model_process(model_name, run_timestamp, run_log_root, run_ckpt_root): log_dir = os.path.join(run_log_root, model_name) checkpoint_dir = os.path.join(run_ckpt_root, model_name) os.makedirs(log_dir, exist_ok=True) os.makedirs(checkpoint_dir, exist_ok=True) stdout_path = os.path.join(log_dir, "stdout.txt") stdout_handle = open(stdout_path, "w", encoding="utf-8") process = subprocess.Popen( [ sys.executable, "-m", "training.run_model", "--model", model_name, "--log-dir", log_dir, "--checkpoint-dir", checkpoint_dir, "--run-timestamp", run_timestamp, ], stdout=stdout_handle, stderr=subprocess.STDOUT, ) return process, stdout_handle def launch_plotter(run_timestamp, run_log_root, interval): stdout_path = os.path.join(run_log_root, "plotter_stdout.txt") stdout_handle = open(stdout_path, "w", encoding="utf-8") process = subprocess.Popen( [ sys.executable, "-m", "scripts.plot_live_training", "--all-models", "--run", run_timestamp, "--watch", "--interval", str(interval), ], stdout=stdout_handle, stderr=subprocess.STDOUT, ) return process, stdout_handle def main(): args = parse_args() models = normalize_model_list(args.models) run_timestamp = args.run_timestamp or datetime.now().strftime("%Y%m%d_%H%M%S") run_log_root = os.path.join("logs", "multi-model", run_timestamp) run_ckpt_root = os.path.join("checkpoints", "multi-model", run_timestamp) os.makedirs(run_log_root, exist_ok=True) os.makedirs(run_ckpt_root, exist_ok=True) processes = {} handles = [] try: for model_name in models: print(f"[{datetime.now().strftime('%H:%M:%S')}] Launching {model_name.upper()} training...") process, stdout_handle = launch_model_process( model_name, run_timestamp, run_log_root, run_ckpt_root, ) processes[model_name] = process handles.append(stdout_handle) print(f" PID: {process.pid}") plotter_process = None if not args.no_plotter: plotter_process, plotter_stdout = launch_plotter( run_timestamp, run_log_root, args.interval, ) handles.append(plotter_stdout) print(f"[{datetime.now().strftime('%H:%M:%S')}] Launching live plotter...") print(f" PID: {plotter_process.pid}") print() print(f"Started {len(models)} training job(s).") print(f"Run timestamp: {run_timestamp}") print(f"Models: {', '.join(models)}") print() for model_name, process in processes.items(): process.wait() status = "completed" if process.returncode == 0 else f"failed (code={process.returncode})" print(f"[{model_name.upper()}] {status}") finally: if "plotter_process" in locals() and plotter_process is not None and plotter_process.poll() is None: plotter_process.terminate() try: plotter_process.wait(timeout=10) except subprocess.TimeoutExpired: plotter_process.kill() for handle in handles: handle.close() if __name__ == "__main__": main()