# ctm-dqn/training/run_model.py
# (source-viewer metadata: 24 lines, 634 B, Python)

"""Unified worker entrypoint for launching one model training job."""
import argparse
from training.registry import TRAINERS, normalize_model_name
from utils.run_dirs import add_run_dir_args
def main():
    """Parse command-line arguments and run the trainer selected by ``--model``.

    The parser is extended by ``add_run_dir_args``. The parsed namespace is
    expected to carry ``log_dir``, ``checkpoint_dir`` and ``run_timestamp``
    attributes, which are forwarded to the trainer as keyword arguments.

    The ``--model`` value is passed through ``normalize_model_name`` and then
    looked up in the ``TRAINERS`` registry. An unregistered name is reported
    as an argparse usage error (exit status 2, listing the available models)
    instead of escaping as a bare ``KeyError`` traceback.
    """
    parser = add_run_dir_args(argparse.ArgumentParser())
    parser.add_argument("--model", required=True, help="Model name to train.")
    args = parser.parse_args()

    model_name = normalize_model_name(args.model)
    try:
        trainer = TRAINERS[model_name]
    except KeyError:
        # Turn a registry miss into a readable CLI error. parser.error()
        # prints usage plus the message to stderr and exits with status 2.
        available = ", ".join(sorted(str(name) for name in TRAINERS))
        parser.error(
            f"unknown model {args.model!r} (normalized: {model_name!r}); "
            f"available models: {available}"
        )

    trainer(
        log_dir=args.log_dir,
        checkpoint_dir=args.checkpoint_dir,
        run_timestamp=args.run_timestamp,
    )
if __name__ == "__main__":
    # Launch training only when run as a script; importing this module
    # (e.g. from a launcher or a test) must not start a job.
    main()