60 lines
1.8 KiB
Python
60 lines
1.8 KiB
Python
"""Unified worker entrypoint for launching one model training job."""
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sys
|
|
import traceback
|
|
from datetime import datetime
|
|
|
|
from training.registry import TRAINERS, normalize_model_name
|
|
from utils.run_dirs import add_run_dir_args
|
|
|
|
|
|
def persist_training_error(model_name: str, log_dir: str, exc: Exception) -> None:
    """Write a traceback file and a JSON summary describing a failed training run.

    Two files are written into ``log_dir`` (created if missing):

    * ``error_traceback.txt`` -- the full formatted traceback of ``exc``.
    * ``error_summary.json`` -- model name, failure time, exception type,
      message and the absolute path of the traceback file.

    Args:
        model_name: Name of the model whose training failed.
        log_dir: Directory for the error artifacts. If falsy (e.g. ``""`` or
            ``None``), nothing is written and the function returns immediately.
        exc: The exception raised by the trainer.

    Raises:
        OSError: If the directory or either file cannot be created or written.
    """
    if not log_dir:
        return

    os.makedirs(log_dir, exist_ok=True)

    # Format the traceback from ``exc`` itself instead of ``traceback.format_exc()``:
    # the latter reads the *currently handled* exception and yields
    # "NoneType: None" whenever this helper runs outside an ``except`` block.
    tb_text = "".join(
        traceback.format_exception(type(exc), exc, exc.__traceback__)
    )

    error_txt_path = os.path.join(log_dir, "error_traceback.txt")
    with open(error_txt_path, "w", encoding="utf-8") as f:
        f.write(tb_text)

    payload = {
        "model": model_name,
        # Local wall-clock time with an explicit UTC offset, so the value is
        # unambiguous when logs from different machines are compared.
        "failed_at": datetime.now().astimezone().isoformat(timespec="seconds"),
        "exception_type": type(exc).__name__,
        "message": str(exc),
        "traceback_file": os.path.abspath(error_txt_path),
    }
    error_json_path = os.path.join(log_dir, "error_summary.json")
    with open(error_json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)
|
|
|
|
|
|
def main():
    """Parse CLI arguments and run the trainer registered for ``--model``.

    The run-directory arguments (``log_dir``, ``checkpoint_dir``,
    ``run_timestamp``) come from ``add_run_dir_args`` and are forwarded to the
    trainer as keyword arguments.

    On an unknown model name, exits via ``argparse`` usage error. On a training
    failure, the error artifacts are persisted to ``log_dir`` (best effort), a
    one-line summary is printed to stderr, and the exception is re-raised so
    the process exits non-zero and the interpreter prints the traceback.
    """
    parser = add_run_dir_args(argparse.ArgumentParser())
    parser.add_argument("--model", required=True, help="Model name to train.")
    args = parser.parse_args()

    model_name = normalize_model_name(args.model)
    try:
        trainer = TRAINERS[model_name]
    except KeyError:
        # Report an unknown model as a CLI usage error instead of an
        # unexplained KeyError traceback.
        available = ", ".join(sorted(str(name) for name in TRAINERS))
        parser.error(f"unknown model {args.model!r}; available: {available}")

    try:
        trainer(
            log_dir=args.log_dir,
            checkpoint_dir=args.checkpoint_dir,
            run_timestamp=args.run_timestamp,
        )
    except Exception as exc:
        try:
            persist_training_error(model_name, args.log_dir, exc)
        except OSError as persist_exc:
            # Persisting is best effort: a full disk or bad path must not
            # replace the original training failure.
            print(
                f"Warning: could not write error artifacts to {args.log_dir!r}: {persist_exc}",
                file=sys.stderr,
            )
        timestamp = datetime.now().strftime("%H:%M:%S")
        print(
            f"[{timestamp}] {model_name.upper()} training failed: {type(exc).__name__}: {exc}",
            file=sys.stderr,
        )
        # No explicit traceback.print_exc(): the bare ``raise`` lets the
        # interpreter print this same traceback, so printing it here too
        # duplicated it on stderr.
        raise
|
|
|
|
|
|
if __name__ == "__main__":
    # Launch the training job only when this file is executed as a script;
    # importing the module has no side effects beyond definitions.
    main()
|