24 lines
634 B
Python
24 lines
634 B
Python
"""Unified worker entrypoint for launching one model training job."""
|
|
import argparse
|
|
|
|
from training.registry import TRAINERS, normalize_model_name
|
|
from utils.run_dirs import add_run_dir_args
|
|
|
|
|
|
def main(argv=None):
    """Parse command-line arguments and run the trainer registered for ``--model``.

    Args:
        argv: Optional sequence of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, i.e. normal CLI behavior.

    The run-directory options registered by ``add_run_dir_args``
    (``log_dir``, ``checkpoint_dir``, ``run_timestamp``) are forwarded
    unchanged to the selected trainer as keyword arguments.

    Exits via ``parser.error`` (usage message on stderr, exit status 2) when
    the normalized model name has no entry in ``TRAINERS``.
    """
    parser = add_run_dir_args(argparse.ArgumentParser())
    parser.add_argument("--model", required=True, help="Model name to train.")
    args = parser.parse_args(argv)

    model_name = normalize_model_name(args.model)
    try:
        trainer = TRAINERS[model_name]
    except KeyError:
        # Report an unknown model as a usage error listing the valid choices
        # instead of crashing with a bare KeyError traceback.
        # NOTE(review): assumes TRAINERS is a mapping (iteration yields keys).
        known = ", ".join(sorted(str(name) for name in TRAINERS))
        parser.error(
            f"unknown model {args.model!r} (normalized to {model_name!r}); "
            f"choose from: {known}"
        )

    trainer(
        log_dir=args.log_dir,
        checkpoint_dir=args.checkpoint_dir,
        run_timestamp=args.run_timestamp,
    )
|
|
|
|
|
|
# Launch the training job only when this file is executed as a script;
# importing the module must not start training.
if __name__ == "__main__":
    main()
|