# Source file: ctm-dqn/run_all_training.py (Python, 300 lines, 9.8 KiB)

"""Unified training launcher for one or more models."""
import argparse
import json
import os
import subprocess
import sys
import time
from collections import deque
from datetime import datetime

from training.registry import DEFAULT_MODELS, normalize_model_list
from utils.run_dirs import resolve_run_root
# Default value of the --interval option: seconds between live plot refreshes.
PLOT_REFRESH_INTERVAL = 30
# Seconds to sleep between polls of the child processes in monitor_processes().
PROCESS_POLL_INTERVAL = 5
def parse_args():
    """Parse the launcher's command-line options.

    Returns:
        argparse.Namespace with the attributes ``models`` (list of names, or
        ``None`` when the flag is absent), ``run_timestamp`` (str or ``None``),
        ``interval`` (int, defaults to ``PLOT_REFRESH_INTERVAL``) and
        ``no_plotter`` (bool).
    """
    default_models_text = ", ".join(DEFAULT_MODELS)

    # One (flag, add_argument keyword arguments) pair per supported option.
    option_specs = (
        (
            "--models",
            {
                "nargs": "*",
                "default": None,
                "help": f"Model names to train. Defaults to: {default_models_text}",
            },
        ),
        (
            "--run-timestamp",
            {
                "type": str,
                "default": None,
                "help": "Run timestamp tag. Defaults to the current time.",
            },
        ),
        (
            "--interval",
            {
                "type": int,
                "default": PLOT_REFRESH_INTERVAL,
                "help": "Live plot refresh interval in seconds.",
            },
        ),
        (
            "--no-plotter",
            {
                "action": "store_true",
                "help": "Disable the live plotter process.",
            },
        ),
    )

    cli = argparse.ArgumentParser(
        description="Launch unified training for one or more models.",
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def launch_model_process(model_name, run_timestamp, run_log_root, run_ckpt_root):
    """Start the training of one model as a child Python process.

    Creates ``<run_log_root>/<model_name>`` and ``<run_ckpt_root>/<model_name>``
    if needed, then runs ``python -m training.run_model`` with the current
    interpreter. The child's stdout and stderr are both written to
    ``stdout.txt`` inside the model's log directory.

    Args:
        model_name: Name passed to ``--model``; also the sub-directory name.
        run_timestamp: Value passed to ``--run-timestamp``.
        run_log_root: Parent directory for per-model log directories.
        run_ckpt_root: Parent directory for per-model checkpoint directories.

    Returns:
        Tuple ``(process, stdout_handle, stdout_path)``. The caller owns
        ``stdout_handle`` and must close it once the process is done.

    Raises:
        OSError: If the directories, the log file or the process cannot be
            created. The log file handle is closed before re-raising.
    """
    log_dir = os.path.join(run_log_root, model_name)
    checkpoint_dir = os.path.join(run_ckpt_root, model_name)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    stdout_path = os.path.join(log_dir, "stdout.txt")
    command = [
        sys.executable,
        "-m",
        "training.run_model",
        "--model",
        model_name,
        "--log-dir",
        log_dir,
        "--checkpoint-dir",
        checkpoint_dir,
        "--run-timestamp",
        run_timestamp,
    ]
    stdout_handle = open(stdout_path, "w", encoding="utf-8")
    try:
        process = subprocess.Popen(
            command,
            stdout=stdout_handle,
            stderr=subprocess.STDOUT,
        )
    except BaseException:
        # The handle is only handed to the caller on success, so nobody else
        # could close it if the spawn fails.
        stdout_handle.close()
        raise
    return process, stdout_handle, stdout_path
def launch_plotter(run_timestamp, run_log_root, interval):
    """Start the live plotter as a child Python process.

    Runs ``python -m scripts.plot_live_training --all-models --run
    <run_timestamp> --watch --interval <interval>`` with the current
    interpreter. The child's stdout and stderr are both written to
    ``plotter_stdout.txt`` directly inside ``run_log_root``, which must
    already exist.

    Args:
        run_timestamp: Value passed to ``--run``.
        run_log_root: Existing directory that receives the plotter log file.
        interval: Value passed to ``--interval`` (converted with ``str``).

    Returns:
        Tuple ``(process, stdout_handle, stdout_path)``. The caller owns
        ``stdout_handle`` and must close it once the process is done.

    Raises:
        OSError: If the log file or the process cannot be created. The log
            file handle is closed before re-raising.
    """
    stdout_path = os.path.join(run_log_root, "plotter_stdout.txt")
    command = [
        sys.executable,
        "-m",
        "scripts.plot_live_training",
        "--all-models",
        "--run",
        run_timestamp,
        "--watch",
        "--interval",
        str(interval),
    ]
    stdout_handle = open(stdout_path, "w", encoding="utf-8")
    try:
        process = subprocess.Popen(
            command,
            stdout=stdout_handle,
            stderr=subprocess.STDOUT,
        )
    except BaseException:
        # The handle is only handed to the caller on success, so nobody else
        # could close it if the spawn fails.
        stdout_handle.close()
        raise
    return process, stdout_handle, stdout_path
def now_str():
    """Return the current time from ``datetime.now()`` formatted as ``HH:MM:SS``."""
    return f"{datetime.now():%H:%M:%S}"
def read_log_tail(log_path, max_lines=40):
    """Return the last ``max_lines`` lines of a text log file.

    Args:
        log_path: Path of the log file; may be ``None`` or empty.
        max_lines: Maximum number of trailing lines to return.

    Returns:
        List of lines without their trailing newline, oldest first. The list
        is empty when ``max_lines`` is not positive, when ``log_path`` is
        falsy, or when it does not name an existing regular file.
    """
    # A non-positive limit means "no lines". Slicing with [-0:] would return
    # the whole file instead.
    if max_lines <= 0:
        return []
    if not log_path or not os.path.isfile(log_path):
        return []
    # errors="replace" keeps the excerpt readable even if the child wrote
    # bytes that are not valid UTF-8.
    with open(log_path, "r", encoding="utf-8", errors="replace") as f:
        # A bounded deque keeps only the tail while streaming the file, so a
        # large log is never held in memory in full.
        tail = deque(f, maxlen=max_lines)
    return [line.rstrip("\n") for line in tail]
def write_run_status(run_root, run_timestamp, model_states, plotter_state=None):
    """Write the current run state to ``<run_root>/run_status.json``.

    The JSON document contains ``run_timestamp``, ``updated_at`` (current
    time, ISO format with second precision), ``models`` and, only when
    ``plotter_state`` is not ``None``, ``plotter``.

    The file is written to a temporary sibling first and then moved into
    place with ``os.replace``. An interrupt or serialization error during the
    write therefore leaves the previous status file intact rather than
    truncated.

    Args:
        run_root: Existing directory that receives ``run_status.json``.
        run_timestamp: Run identifier stored in the document.
        model_states: JSON-serializable mapping of model name to state.
        plotter_state: Optional JSON-serializable plotter state.

    Raises:
        OSError: If the file cannot be written or replaced.
        TypeError: If a state value is not JSON-serializable.
    """
    payload = {
        "run_timestamp": run_timestamp,
        "updated_at": datetime.now().isoformat(timespec="seconds"),
        "models": model_states,
    }
    if plotter_state is not None:
        payload["plotter"] = plotter_state
    status_path = os.path.join(run_root, "run_status.json")
    tmp_path = f"{status_path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, status_path)
    finally:
        # After a successful replace the temp file is gone; it only still
        # exists here when the write or the replace failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
def print_failure_excerpt(name, log_path, returncode):
    """Print a failure notice for a child process followed by its log tail.

    Args:
        name: Label shown in the failure line.
        log_path: Path of the child's log file, passed to ``read_log_tail``.
        returncode: Exit code shown in the failure line.
    """
    print(f"[{now_str()}] {name} failed with code={returncode}")
    print(f" Log: {log_path}")

    excerpt = read_log_tail(log_path)
    if not excerpt:
        print(" No readable log content found.")
        return

    print(" Last log lines:")
    print(*(f" {entry}" for entry in excerpt), sep="\n")
def terminate_process(process, timeout=10):
    """Stop a child process, escalating from terminate to kill.

    Does nothing when ``process`` is ``None`` or has already exited.
    Otherwise it asks the process to terminate and waits up to ``timeout``
    seconds. If it is still alive after that, it is killed and then waited
    on, so the child is reaped and ``process.returncode`` is set when this
    function returns.

    Args:
        process: ``subprocess.Popen`` instance or ``None``.
        timeout: Seconds to wait after ``terminate()`` before using ``kill()``.
    """
    if process is None or process.poll() is not None:
        return
    process.terminate()
    try:
        process.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        process.kill()
        # kill() only sends the signal. Wait so the child is actually gone
        # before the caller closes its log handle.
        process.wait()
def monitor_processes(processes, run_root, run_timestamp, plotter_info=None):
    """Poll the training processes until all have exited, recording progress.

    The status file is written once at the start and again every time a
    model process or the plotter exits. Finished models are reported on
    stdout; failures additionally print a log excerpt. The plotter is only
    observed here, never waited for: the loop ends when no model is pending.

    Args:
        processes: Mapping of model name to a dict with the keys ``process``,
            ``log_dir``, ``stdout_path``, ``checkpoint_dir`` and
            ``started_at``.
        run_root: Directory passed to ``write_run_status``.
        run_timestamp: Run identifier passed to ``write_run_status``.
        plotter_info: Optional dict with the keys ``process``,
            ``stdout_path`` and ``started_at``.

    Returns:
        Tuple ``(model_states, plotter_state)``; ``plotter_state`` is ``None``
        when no ``plotter_info`` was given.
    """

    def timestamp():
        return datetime.now().isoformat(timespec="seconds")

    model_states = {
        name: {
            "status": "running",
            "pid": entry["process"].pid,
            "log_dir": entry["log_dir"],
            "stdout_path": entry["stdout_path"],
            "checkpoint_dir": entry["checkpoint_dir"],
            "started_at": entry["started_at"],
        }
        for name, entry in processes.items()
    }

    if plotter_info is None:
        plotter_state = None
    else:
        plotter_state = {
            "status": "running",
            "pid": plotter_info["process"].pid,
            "stdout_path": plotter_info["stdout_path"],
            "started_at": plotter_info["started_at"],
        }

    def persist():
        write_run_status(run_root, run_timestamp, model_states, plotter_state)

    persist()

    unfinished = set(processes.keys())
    while unfinished:
        # Iterate over a snapshot because finished models are removed from
        # the set inside the loop.
        for name in tuple(unfinished):
            entry = processes[name]
            exit_code = entry["process"].poll()
            if exit_code is None:
                continue

            unfinished.remove(name)
            succeeded = exit_code == 0
            model_states[name].update(
                status="completed" if succeeded else "failed",
                returncode=exit_code,
                finished_at=timestamp(),
            )
            persist()

            if succeeded:
                print(f"[{now_str()}] [{name.upper()}] completed")
            else:
                print_failure_excerpt(name.upper(), entry["stdout_path"], exit_code)

        # plotter_state is None exactly when no plotter_info was given.
        if plotter_state is not None and plotter_state["status"] == "running":
            plotter_code = plotter_info["process"].poll()
            if plotter_code is not None:
                plotter_state.update(
                    status="completed" if plotter_code == 0 else "failed",
                    returncode=plotter_code,
                    finished_at=timestamp(),
                )
                persist()
                if plotter_code != 0:
                    print_failure_excerpt("PLOTTER", plotter_info["stdout_path"], plotter_code)

        if unfinished:
            time.sleep(PROCESS_POLL_INTERVAL)

    return model_states, plotter_state
def main():
    """Launch all requested training jobs, monitor them and clean up.

    Steps: parse the command line, resolve the run directory, start one
    child process per model (plus the live plotter unless ``--no-plotter``
    is given), then block in ``monitor_processes`` until every model has
    exited. Exits with status 1 when any model did not complete.

    On every exit path the plotter and any still-running model processes are
    terminated and all log file handles are closed.
    """
    args = parse_args()
    models = normalize_model_list(args.models)
    run_timestamp, run_root = resolve_run_root(
        args.run_timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")
    )
    run_log_root = os.path.join(run_root, "logs")
    run_ckpt_root = os.path.join(run_root, "checkpoints")
    os.makedirs(run_log_root, exist_ok=True)
    os.makedirs(run_ckpt_root, exist_ok=True)

    processes = {}
    handles = []
    model_states = None
    plotter_state = None
    # Bound before the try block so the cleanup code can always refer to
    # them, even when launching a model fails before the plotter is started.
    plotter_process = None
    plotter_info = None
    try:
        for model_name in models:
            print(f"[{now_str()}] Launching {model_name.upper()} training...")
            process, stdout_handle, stdout_path = launch_model_process(
                model_name,
                run_timestamp,
                run_log_root,
                run_ckpt_root,
            )
            handles.append(stdout_handle)
            processes[model_name] = {
                "process": process,
                "stdout_path": stdout_path,
                "stdout_handle": stdout_handle,
                "log_dir": os.path.join(run_log_root, model_name),
                "checkpoint_dir": os.path.join(run_ckpt_root, model_name),
                "started_at": datetime.now().isoformat(timespec="seconds"),
            }
            print(f" PID: {process.pid}")

        if not args.no_plotter:
            plotter_process, plotter_stdout, plotter_stdout_path = launch_plotter(
                run_timestamp,
                run_log_root,
                args.interval,
            )
            handles.append(plotter_stdout)
            plotter_info = {
                "process": plotter_process,
                "stdout_path": plotter_stdout_path,
                "stdout_handle": plotter_stdout,
                "started_at": datetime.now().isoformat(timespec="seconds"),
            }
            print(f"[{now_str()}] Launching live plotter...")
            print(f" PID: {plotter_process.pid}")

        print()
        print(f"Started {len(models)} training job(s).")
        print(f"Run timestamp: {run_timestamp}")
        print(f"Models: {', '.join(models)}")
        print()

        model_states, plotter_state = monitor_processes(
            processes,
            run_root,
            run_timestamp,
            plotter_info=plotter_info,
        )

        failed_models = [
            model_name
            for model_name, state in model_states.items()
            if state["status"] != "completed"
        ]
        if failed_models:
            print()
            print(f"[{now_str()}] Failed models: {', '.join(failed_models)}")
            print(f"[{now_str()}] See run status: {os.path.join(run_root, 'run_status.json')}")
            sys.exit(1)
    finally:
        try:
            # terminate_process ignores None and processes that already exited.
            terminate_process(plotter_process)
            if plotter_state is not None and plotter_state.get("status") == "running":
                plotter_state["status"] = "terminated"
                plotter_state["finished_at"] = datetime.now().isoformat(timespec="seconds")
            for info in processes.values():
                terminate_process(info["process"])
            # NOTE(review): model_states is only set when monitor_processes
            # returned normally. After an exception or Ctrl-C it is still
            # None, so the status file is not refreshed here.
            if model_states is not None:
                write_run_status(run_root, run_timestamp, model_states, plotter_state)
        finally:
            # Close the log files even if terminating or status writing failed.
            for handle in handles:
                handle.close()


# Run the launcher only when executed as a script, not when imported.
if __name__ == "__main__":
    main()