# ctm-dqn/training/run_model.py

"""Unified worker entrypoint for launching one model training job."""
import argparse
import json
import os
import sys
import traceback
from datetime import datetime
from training.registry import TRAINERS, normalize_model_name
from utils.run_dirs import add_run_dir_args
def persist_training_error(model_name: str, log_dir: str, exc: Exception) -> None:
    """Write a traceback file and a JSON summary describing a failed training run.

    Two files are written into ``log_dir``, which is created if missing:

    * ``error_traceback.txt`` -- the fully formatted traceback of ``exc``
      (including any chained ``__cause__`` / ``__context__``).
    * ``error_summary.json`` -- the model name, failure time, exception type,
      exception message and absolute path of the traceback file.

    Args:
        model_name: Name of the model whose training failed.
        log_dir: Directory for the error artifacts. When falsy (e.g. ``""`` or
            ``None``) the function returns without writing anything.
        exc: The exception raised by the trainer. Its own ``__traceback__`` is
            used, so this works whether or not it is called inside an active
            ``except`` block.

    Raises:
        OSError: If the directory or either file cannot be created or written.
    """
    if not log_dir:
        return
    os.makedirs(log_dir, exist_ok=True)

    # Format from the exception object itself. traceback.format_exc() reads
    # sys.exc_info() instead, and yields "NoneType: None" when this helper is
    # called outside an ``except`` block, silently losing the real traceback.
    tb_text = "".join(
        traceback.format_exception(type(exc), exc, exc.__traceback__)
    )
    error_txt_path = os.path.join(log_dir, "error_traceback.txt")
    with open(error_txt_path, "w", encoding="utf-8") as f:
        f.write(tb_text)

    error_json_path = os.path.join(log_dir, "error_summary.json")
    payload = {
        "model": model_name,
        # Naive local wall-clock time, second precision.
        "failed_at": datetime.now().isoformat(timespec="seconds"),
        "exception_type": type(exc).__name__,
        "message": str(exc),
        "traceback_file": os.path.abspath(error_txt_path),
    }
    with open(error_json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)
def main():
    """Parse CLI arguments and run the trainer registered for ``--model``.

    The run-directory options are added by ``add_run_dir_args`` and read back
    here as ``args.log_dir``, ``args.checkpoint_dir`` and ``args.run_timestamp``;
    they are forwarded to the trainer as keyword arguments.

    If the trainer raises, an error report is written into ``args.log_dir``
    (best-effort), a one-line summary and the traceback are printed to stderr,
    and the original exception is re-raised so the process exits non-zero.

    Raises:
        SystemExit: From argparse on invalid arguments, including a model name
            that has no registered trainer.
        Exception: Whatever the trainer raised, re-raised unchanged.
    """
    parser = add_run_dir_args(argparse.ArgumentParser())
    parser.add_argument("--model", required=True, help="Model name to train.")
    args = parser.parse_args()

    model_name = normalize_model_name(args.model)
    # Report unknown names as a usage error instead of letting the registry
    # lookup below fail with a bare KeyError traceback.
    if model_name not in TRAINERS:
        parser.error(
            f"unknown model {args.model!r} (normalized to {model_name!r}); "
            f"available: {', '.join(map(str, TRAINERS))}"
        )
    trainer = TRAINERS[model_name]

    try:
        trainer(
            log_dir=args.log_dir,
            checkpoint_dir=args.checkpoint_dir,
            run_timestamp=args.run_timestamp,
        )
    except Exception as exc:
        # Writing the report is best-effort: a failure here (e.g. an unwritable
        # log dir) must not replace the training exception or skip the summary
        # and re-raise below.
        try:
            persist_training_error(model_name, args.log_dir, exc)
        except Exception as report_exc:
            print(
                f"[{datetime.now().strftime('%H:%M:%S')}] could not write error report to {args.log_dir!r}: {type(report_exc).__name__}: {report_exc}",
                file=sys.stderr,
            )
        print(
            f"[{datetime.now().strftime('%H:%M:%S')}] {model_name.upper()} training failed: {type(exc).__name__}: {exc}",
            file=sys.stderr,
        )
        # The inner handler has exited, so the active exception is `exc` again.
        traceback.print_exc()
        raise
# Run the worker only when executed as a script, not when imported.
if __name__ == "__main__":
    main()