"""Unified training launcher for one or more models.""" import argparse import json import os import subprocess import sys import time from datetime import datetime from training.registry import DEFAULT_MODELS, normalize_model_list from utils.run_dirs import resolve_run_root PLOT_REFRESH_INTERVAL = 30 PROCESS_POLL_INTERVAL = 5 def parse_args(): parser = argparse.ArgumentParser( description="Launch unified training for one or more models.", ) parser.add_argument( "--models", nargs="*", default=None, help=( "Model names to train. " f"Defaults to: {', '.join(DEFAULT_MODELS)}" ), ) parser.add_argument( "--run-timestamp", type=str, default=None, help="Run timestamp tag. Defaults to the current time.", ) parser.add_argument( "--interval", type=int, default=PLOT_REFRESH_INTERVAL, help="Live plot refresh interval in seconds.", ) parser.add_argument( "--no-plotter", action="store_true", help="Disable the live plotter process.", ) return parser.parse_args() def launch_model_process(model_name, run_timestamp, run_log_root, run_ckpt_root): log_dir = os.path.join(run_log_root, model_name) checkpoint_dir = os.path.join(run_ckpt_root, model_name) os.makedirs(log_dir, exist_ok=True) os.makedirs(checkpoint_dir, exist_ok=True) stdout_path = os.path.join(log_dir, "stdout.txt") stdout_handle = open(stdout_path, "w", encoding="utf-8") process = subprocess.Popen( [ sys.executable, "-m", "training.run_model", "--model", model_name, "--log-dir", log_dir, "--checkpoint-dir", checkpoint_dir, "--run-timestamp", run_timestamp, ], stdout=stdout_handle, stderr=subprocess.STDOUT, ) return process, stdout_handle, stdout_path def launch_plotter(run_timestamp, run_log_root, interval): stdout_path = os.path.join(run_log_root, "plotter_stdout.txt") stdout_handle = open(stdout_path, "w", encoding="utf-8") process = subprocess.Popen( [ sys.executable, "-m", "scripts.plot_live_training", "--all-models", "--run", run_timestamp, "--watch", "--interval", str(interval), ], stdout=stdout_handle, stderr=subprocess.STDOUT, ) return process, stdout_handle, stdout_path def now_str(): return datetime.now().strftime("%H:%M:%S") def read_log_tail(log_path, max_lines=40): if not log_path or not os.path.isfile(log_path): return [] with open(log_path, "r", encoding="utf-8", errors="replace") as f: lines = f.readlines() return [line.rstrip("\n") for line in lines[-max_lines:]] def write_run_status(run_root, run_timestamp, model_states, plotter_state=None): payload = { "run_timestamp": run_timestamp, "updated_at": datetime.now().isoformat(timespec="seconds"), "models": model_states, } if plotter_state is not None: payload["plotter"] = plotter_state status_path = os.path.join(run_root, "run_status.json") with open(status_path, "w", encoding="utf-8") as f: json.dump(payload, f, ensure_ascii=False, indent=2) def print_failure_excerpt(name, log_path, returncode): print(f"[{now_str()}] {name} failed with code={returncode}") print(f" Log: {log_path}") tail_lines = read_log_tail(log_path) if tail_lines: print(" Last log lines:") for line in tail_lines: encoding = sys.stdout.encoding or "utf-8" safe_line = line.encode(encoding, errors="replace").decode(encoding, errors="replace") print(f" {safe_line}") else: print(" No readable log content found.") def terminate_process(process): if process is None or process.poll() is not None: return process.terminate() try: process.wait(timeout=10) except subprocess.TimeoutExpired: process.kill() def monitor_processes(processes, run_root, run_timestamp, plotter_info=None): pending = set(processes.keys()) model_states = {} for model_name, info in processes.items(): model_states[model_name] = { "status": "running", "pid": info["process"].pid, "log_dir": info["log_dir"], "stdout_path": info["stdout_path"], "checkpoint_dir": info["checkpoint_dir"], "started_at": info["started_at"], } plotter_state = None if plotter_info is not None: plotter_state = { "status": "running", "pid": plotter_info["process"].pid, "stdout_path": plotter_info["stdout_path"], "started_at": plotter_info["started_at"], } write_run_status(run_root, run_timestamp, model_states, plotter_state) while pending: for model_name in list(pending): info = processes[model_name] process = info["process"] returncode = process.poll() if returncode is None: continue pending.remove(model_name) status = "completed" if returncode == 0 else "failed" model_states[model_name]["status"] = status model_states[model_name]["returncode"] = returncode model_states[model_name]["finished_at"] = datetime.now().isoformat(timespec="seconds") write_run_status(run_root, run_timestamp, model_states, plotter_state) if returncode == 0: print(f"[{now_str()}] [{model_name.upper()}] completed") else: print_failure_excerpt(model_name.upper(), info["stdout_path"], returncode) if plotter_info is not None and plotter_state is not None and plotter_state["status"] == "running": plotter_returncode = plotter_info["process"].poll() if plotter_returncode is not None: plotter_state["status"] = "completed" if plotter_returncode == 0 else "failed" plotter_state["returncode"] = plotter_returncode plotter_state["finished_at"] = datetime.now().isoformat(timespec="seconds") write_run_status(run_root, run_timestamp, model_states, plotter_state) if plotter_returncode != 0: print_failure_excerpt("PLOTTER", plotter_info["stdout_path"], plotter_returncode) if pending: time.sleep(PROCESS_POLL_INTERVAL) return model_states, plotter_state def main(): args = parse_args() models = normalize_model_list(args.models) run_timestamp, run_root = resolve_run_root(args.run_timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")) run_log_root = os.path.join(run_root, "logs") run_ckpt_root = os.path.join(run_root, "checkpoints") os.makedirs(run_log_root, exist_ok=True) os.makedirs(run_ckpt_root, exist_ok=True) processes = {} handles = [] model_states = None plotter_state = None try: for model_name in models: print(f"[{datetime.now().strftime('%H:%M:%S')}] Launching {model_name.upper()} training...") process, stdout_handle, stdout_path = launch_model_process( model_name, run_timestamp, run_log_root, run_ckpt_root, ) processes[model_name] = { "process": process, "stdout_path": stdout_path, "stdout_handle": stdout_handle, "log_dir": os.path.join(run_log_root, model_name), "checkpoint_dir": os.path.join(run_ckpt_root, model_name), "started_at": datetime.now().isoformat(timespec="seconds"), } handles.append(stdout_handle) print(f" PID: {process.pid}") plotter_process = None plotter_info = None if not args.no_plotter: plotter_process, plotter_stdout, plotter_stdout_path = launch_plotter( run_timestamp, run_log_root, args.interval, ) handles.append(plotter_stdout) plotter_info = { "process": plotter_process, "stdout_path": plotter_stdout_path, "stdout_handle": plotter_stdout, "started_at": datetime.now().isoformat(timespec="seconds"), } print(f"[{datetime.now().strftime('%H:%M:%S')}] Launching live plotter...") print(f" PID: {plotter_process.pid}") print() print(f"Started {len(models)} training job(s).") print(f"Run timestamp: {run_timestamp}") print(f"Models: {', '.join(models)}") print() model_states, plotter_state = monitor_processes( processes, run_root, run_timestamp, plotter_info=plotter_info, ) failed_models = [ model_name for model_name, state in model_states.items() if state["status"] != "completed" ] if failed_models: print() print(f"[{now_str()}] Failed models: {', '.join(failed_models)}") print(f"[{now_str()}] See run status: {os.path.join(run_root, 'run_status.json')}") sys.exit(1) finally: if "plotter_process" in locals(): terminate_process(plotter_process) if plotter_state is not None and plotter_state.get("status") == "running": plotter_state["status"] = "terminated" plotter_state["finished_at"] = datetime.now().isoformat(timespec="seconds") for info in processes.values(): terminate_process(info["process"]) if model_states is not None: write_run_status(run_root, run_timestamp, model_states, plotter_state) for handle in handles: handle.close() if __name__ == "__main__": main()