302 lines
10 KiB
Python
302 lines
10 KiB
Python
"""Unified training launcher for one or more models."""
|
|
import argparse
import json
import os
import subprocess
import sys
import time
from collections import deque
from datetime import datetime

from training.registry import DEFAULT_MODELS, normalize_model_list
from utils.run_dirs import resolve_run_root
|
|
|
|
|
|
# Default for the --interval option: seconds between live-plot refreshes.
PLOT_REFRESH_INTERVAL = 30
# Seconds to sleep between polls of the child processes while any model is still running.
PROCESS_POLL_INTERVAL = 5
|
|
|
|
|
|
def parse_args():
    """Parse the launcher's command-line options.

    Returns:
        argparse.Namespace with ``models`` (list of names, or ``None`` when the
        flag is omitted), ``run_timestamp`` (str or ``None``), ``interval``
        (int, seconds) and ``no_plotter`` (bool).
    """
    default_models_text = ", ".join(DEFAULT_MODELS)

    # Each entry is (flag, keyword arguments for add_argument).
    option_specs = (
        (
            "--models",
            {
                "nargs": "*",
                "default": None,
                "help": f"Model names to train. Defaults to: {default_models_text}",
            },
        ),
        (
            "--run-timestamp",
            {
                "type": str,
                "default": None,
                "help": "Run timestamp tag. Defaults to the current time.",
            },
        ),
        (
            "--interval",
            {
                "type": int,
                "default": PLOT_REFRESH_INTERVAL,
                "help": "Live plot refresh interval in seconds.",
            },
        ),
        (
            "--no-plotter",
            {
                "action": "store_true",
                "help": "Disable the live plotter process.",
            },
        ),
    )

    cli = argparse.ArgumentParser(
        description="Launch unified training for one or more models.",
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
|
|
|
|
|
def launch_model_process(model_name, run_timestamp, run_log_root, run_ckpt_root):
    """Start ``training.run_model`` for one model in a child process.

    Creates ``<run_log_root>/<model_name>`` and ``<run_ckpt_root>/<model_name>``
    when missing, and redirects the child's stdout and stderr into
    ``stdout.txt`` inside the model's log directory.

    Args:
        model_name: Model name, passed to the child as ``--model`` and used as
            the sub-directory name.
        run_timestamp: Run tag, passed to the child as ``--run-timestamp``.
        run_log_root: Directory under which the per-model log directory lives.
        run_ckpt_root: Directory under which the per-model checkpoint
            directory lives.

    Returns:
        Tuple ``(process, stdout_handle, stdout_path)``. The caller owns
        ``stdout_handle`` and must close it when the process is finished.

    Raises:
        OSError: If a directory, the log file, or the child process cannot be
            created.
    """
    log_dir = os.path.join(run_log_root, model_name)
    checkpoint_dir = os.path.join(run_ckpt_root, model_name)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    command = [
        sys.executable,
        "-m",
        "training.run_model",
        "--model",
        model_name,
        "--log-dir",
        log_dir,
        "--checkpoint-dir",
        checkpoint_dir,
        "--run-timestamp",
        run_timestamp,
    ]

    stdout_path = os.path.join(log_dir, "stdout.txt")
    stdout_handle = open(stdout_path, "w", encoding="utf-8")
    try:
        process = subprocess.Popen(
            command,
            stdout=stdout_handle,
            stderr=subprocess.STDOUT,
        )
    except BaseException:
        # The spawn failed, so no caller will ever receive the handle:
        # close it here instead of leaking the file descriptor.
        stdout_handle.close()
        raise
    return process, stdout_handle, stdout_path
|
|
|
|
|
|
def launch_plotter(run_timestamp, run_log_root, interval):
    """Start ``scripts.plot_live_training`` in watch mode as a child process.

    The child's stdout and stderr are redirected into
    ``<run_log_root>/plotter_stdout.txt``.

    Args:
        run_timestamp: Run tag, passed to the plotter as ``--run``.
        run_log_root: Directory that receives ``plotter_stdout.txt``; created
            when missing.
        interval: Refresh interval, passed to the plotter as ``--interval``
            after conversion with ``str()``.

    Returns:
        Tuple ``(process, stdout_handle, stdout_path)``. The caller owns
        ``stdout_handle`` and must close it when the process is finished.

    Raises:
        OSError: If the directory, the log file, or the child process cannot
            be created.
    """
    # Mirror launch_model_process: do not rely on the caller having created
    # the log directory already.
    os.makedirs(run_log_root, exist_ok=True)

    command = [
        sys.executable,
        "-m",
        "scripts.plot_live_training",
        "--all-models",
        "--run",
        run_timestamp,
        "--watch",
        "--interval",
        str(interval),
    ]

    stdout_path = os.path.join(run_log_root, "plotter_stdout.txt")
    stdout_handle = open(stdout_path, "w", encoding="utf-8")
    try:
        process = subprocess.Popen(
            command,
            stdout=stdout_handle,
            stderr=subprocess.STDOUT,
        )
    except BaseException:
        # The spawn failed, so no caller will ever receive the handle:
        # close it here instead of leaking the file descriptor.
        stdout_handle.close()
        raise
    return process, stdout_handle, stdout_path
|
|
|
|
|
|
def now_str():
    """Return the current local time formatted as ``HH:MM:SS``."""
    return f"{datetime.now():%H:%M:%S}"
|
|
|
|
|
|
def read_log_tail(log_path, max_lines=40):
    """Return up to ``max_lines`` trailing lines of a text log.

    Args:
        log_path: Path of the log file. ``None`` or an empty string is
            accepted and yields an empty result.
        max_lines: Maximum number of trailing lines to return. Values of zero
            or below yield an empty result.

    Returns:
        List of lines without their trailing newline, oldest first. Empty when
        the path is falsy, is not a regular file, or cannot be read.
    """
    # An explicit guard is needed: a ``[-0:]`` slice would select every line,
    # not none.
    if not log_path or max_lines <= 0 or not os.path.isfile(log_path):
        return []
    try:
        with open(log_path, "r", encoding="utf-8", errors="replace") as f:
            # A bounded deque keeps only the tail while streaming, so a large
            # log is never held in memory in full.
            tail = deque(f, maxlen=max_lines)
    except OSError:
        # Best effort: this feeds failure reporting, which must not crash
        # because the log vanished or became unreadable after the check above.
        return []
    return [line.rstrip("\n") for line in tail]
|
|
|
|
|
|
def write_run_status(run_root, run_timestamp, model_states, plotter_state=None):
    """Write the current run state to ``<run_root>/run_status.json``.

    The JSON object holds ``run_timestamp``, ``updated_at`` (local time,
    ISO format with second precision), ``models`` and, when ``plotter_state``
    is given, ``plotter``.

    Args:
        run_root: Existing directory that receives ``run_status.json``.
        run_timestamp: Run tag stored under ``run_timestamp``.
        model_states: JSON-serializable value stored under ``models``.
        plotter_state: Optional JSON-serializable value stored under
            ``plotter``; the key is omitted when this is ``None``.

    Raises:
        TypeError: If the payload is not JSON-serializable. An existing
            status file is left untouched in that case.
        OSError: If the status file cannot be written.
    """
    payload = {
        "run_timestamp": run_timestamp,
        "updated_at": datetime.now().isoformat(timespec="seconds"),
        "models": model_states,
    }
    if plotter_state is not None:
        payload["plotter"] = plotter_state

    # Serialize before touching the disk so a serialization error cannot
    # leave a truncated status file behind.
    text = json.dumps(payload, ensure_ascii=False, indent=2)

    status_path = os.path.join(run_root, "run_status.json")
    tmp_path = f"{status_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.write(text)
    try:
        # Swap the finished file into place so a concurrent reader never
        # observes a half-written document.
        os.replace(tmp_path, status_path)
    except OSError:
        # The swap can be refused (for example when another process holds the
        # target open on Windows); fall back to rewriting the file in place.
        try:
            with open(status_path, "w", encoding="utf-8") as f:
                f.write(text)
        finally:
            os.remove(tmp_path)
|
|
|
|
|
|
def print_failure_excerpt(name, log_path, returncode):
    """Print a failure notice for a child process followed by its log tail.

    Args:
        name: Label shown in the failure line.
        log_path: Path of the process log; its tail is read via
            ``read_log_tail``.
        returncode: Exit code shown in the failure line.
    """
    print(f"[{now_str()}] {name} failed with code={returncode}")
    print(f" Log: {log_path}")

    excerpt = read_log_tail(log_path)
    if not excerpt:
        print(" No readable log content found.")
        return

    print(" Last log lines:")
    console_encoding = sys.stdout.encoding or "utf-8"
    for raw_line in excerpt:
        # Round-trip through the console encoding so characters it cannot
        # represent are replaced instead of making print() raise.
        encoded = raw_line.encode(console_encoding, errors="replace")
        printable = encoded.decode(console_encoding, errors="replace")
        print(f" {printable}")
|
|
|
|
|
|
def terminate_process(process):
    """Stop a child process, escalating from ``terminate()`` to ``kill()``.

    Does nothing when ``process`` is ``None`` or has already exited.
    Otherwise requests termination, waits up to 10 seconds, and kills the
    process if it is still alive after that.

    Args:
        process: A ``subprocess.Popen`` instance, or ``None``.
    """
    if process is None or process.poll() is not None:
        return
    process.terminate()
    try:
        process.wait(timeout=10)
    except subprocess.TimeoutExpired:
        process.kill()
        # Reap the killed child so it does not linger as a zombie and so its
        # returncode is populated for later poll() calls.
        process.wait()
|
|
|
|
|
|
def monitor_processes(processes, run_root, run_timestamp, plotter_info=None):
    """Poll the model processes until all have exited, keeping the status file current.

    The status file is written once up front and again whenever a model or
    the plotter is seen to have exited. Failed processes get their log tail
    printed. The plotter is only polled while at least one model was pending
    at the start of the polling round.

    Args:
        processes: Mapping of model name to a dict with the keys ``process``,
            ``log_dir``, ``stdout_path``, ``checkpoint_dir`` and
            ``started_at``.
        run_root: Run directory handed to ``write_run_status``.
        run_timestamp: Run tag handed to ``write_run_status``.
        plotter_info: Optional dict with the keys ``process``,
            ``stdout_path`` and ``started_at``.

    Returns:
        Tuple ``(model_states, plotter_state)``; ``plotter_state`` is ``None``
        when no ``plotter_info`` was given.
    """

    def _timestamp():
        return datetime.now().isoformat(timespec="seconds")

    def _publish():
        write_run_status(run_root, run_timestamp, model_states, plotter_state)

    model_states = {
        name: {
            "status": "running",
            "pid": entry["process"].pid,
            "log_dir": entry["log_dir"],
            "stdout_path": entry["stdout_path"],
            "checkpoint_dir": entry["checkpoint_dir"],
            "started_at": entry["started_at"],
        }
        for name, entry in processes.items()
    }

    if plotter_info is None:
        plotter_state = None
    else:
        plotter_state = {
            "status": "running",
            "pid": plotter_info["process"].pid,
            "stdout_path": plotter_info["stdout_path"],
            "started_at": plotter_info["started_at"],
        }

    _publish()

    pending = set(processes.keys())
    while pending:
        # Iterate over a snapshot because finished models are removed from
        # the set inside the loop.
        for name in list(pending):
            entry = processes[name]
            exit_code = entry["process"].poll()
            if exit_code is None:
                continue

            pending.remove(name)
            succeeded = exit_code == 0
            state = model_states[name]
            state["status"] = "completed" if succeeded else "failed"
            state["returncode"] = exit_code
            state["finished_at"] = _timestamp()
            _publish()

            if succeeded:
                print(f"[{now_str()}] [{name.upper()}] completed")
            else:
                print_failure_excerpt(name.upper(), entry["stdout_path"], exit_code)

        plotter_running = (
            plotter_info is not None
            and plotter_state is not None
            and plotter_state["status"] == "running"
        )
        if plotter_running:
            plotter_code = plotter_info["process"].poll()
            if plotter_code is not None:
                plotter_state["status"] = "completed" if plotter_code == 0 else "failed"
                plotter_state["returncode"] = plotter_code
                plotter_state["finished_at"] = _timestamp()
                _publish()
                if plotter_code != 0:
                    print_failure_excerpt("PLOTTER", plotter_info["stdout_path"], plotter_code)

        if pending:
            time.sleep(PROCESS_POLL_INTERVAL)

    return model_states, plotter_state
|
|
|
|
|
|
def main():
    """Launch one training process per model, monitor them, and clean up.

    Steps:
        1. Parse the command line and resolve the run directory.
        2. Start one child process per model and, unless ``--no-plotter`` is
           given, the live plotter.
        3. Poll until every model has exited.
        4. Exit with status 1 when any model did not complete.

    On every exit path, including exceptions, the plotter and any model
    process still running are terminated, the status file is rewritten if
    monitoring produced states, and all log handles are closed.

    Raises:
        SystemExit: With code 1 when at least one model did not complete.
    """
    args = parse_args()
    models = normalize_model_list(args.models)
    run_timestamp, run_root = resolve_run_root(
        args.run_timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")
    )
    run_log_root = os.path.join(run_root, "logs")
    run_ckpt_root = os.path.join(run_root, "checkpoints")
    os.makedirs(run_log_root, exist_ok=True)
    os.makedirs(run_ckpt_root, exist_ok=True)

    processes = {}
    handles = []
    model_states = None
    plotter_state = None
    # Bound before the try block so the cleanup below can refer to them even
    # when a launch fails early.
    plotter_process = None
    plotter_info = None

    try:
        for model_name in models:
            print(f"[{now_str()}] Launching {model_name.upper()} training...")
            process, stdout_handle, stdout_path = launch_model_process(
                model_name,
                run_timestamp,
                run_log_root,
                run_ckpt_root,
            )
            processes[model_name] = {
                "process": process,
                "stdout_path": stdout_path,
                "stdout_handle": stdout_handle,
                "log_dir": os.path.join(run_log_root, model_name),
                "checkpoint_dir": os.path.join(run_ckpt_root, model_name),
                "started_at": datetime.now().isoformat(timespec="seconds"),
            }
            handles.append(stdout_handle)
            print(f" PID: {process.pid}")

        if not args.no_plotter:
            plotter_process, plotter_stdout, plotter_stdout_path = launch_plotter(
                run_timestamp,
                run_log_root,
                args.interval,
            )
            handles.append(plotter_stdout)
            plotter_info = {
                "process": plotter_process,
                "stdout_path": plotter_stdout_path,
                "stdout_handle": plotter_stdout,
                "started_at": datetime.now().isoformat(timespec="seconds"),
            }
            print(f"[{now_str()}] Launching live plotter...")
            print(f" PID: {plotter_process.pid}")

        print()
        print(f"Started {len(models)} training job(s).")
        print(f"Run timestamp: {run_timestamp}")
        print(f"Models: {', '.join(models)}")
        print()

        model_states, plotter_state = monitor_processes(
            processes,
            run_root,
            run_timestamp,
            plotter_info=plotter_info,
        )
        failed_models = [
            model_name
            for model_name, state in model_states.items()
            if state["status"] != "completed"
        ]
        if failed_models:
            print()
            print(f"[{now_str()}] Failed models: {', '.join(failed_models)}")
            print(f"[{now_str()}] See run status: {os.path.join(run_root, 'run_status.json')}")
            sys.exit(1)
    finally:
        try:
            # The plotter runs in watch mode, so it is stopped explicitly;
            # terminate_process ignores None and already-exited processes.
            terminate_process(plotter_process)
            if plotter_state is not None and plotter_state.get("status") == "running":
                plotter_state["status"] = "terminated"
                plotter_state["finished_at"] = datetime.now().isoformat(timespec="seconds")

            for info in processes.values():
                terminate_process(info["process"])

            if model_states is not None:
                write_run_status(run_root, run_timestamp, model_states, plotter_state)
        finally:
            # Close the log handles even when the cleanup above raises.
            for handle in handles:
                handle.close()


if __name__ == "__main__":
    main()
|