29 lines
1.1 KiB
Python
29 lines
1.1 KiB
Python
import argparse
|
|
import os
|
|
from datetime import datetime
|
|
from typing import Optional, Tuple
|
|
|
|
|
|
def add_run_dir_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register the run-directory CLI options on *parser* and return it.

    Adds ``--log-dir``, ``--checkpoint-dir`` and ``--run-timestamp`` (in that
    order). Each is an optional ``str`` option defaulting to ``None``, so the
    caller decides what an omitted value means (``resolve_run_dirs`` in this
    module substitutes per-model defaults for ``None``).

    The same parser object is returned to allow call chaining.
    """
    # (flag, help text) pairs; all options share type=str and default=None.
    option_specs = (
        ("--log-dir", "Training log output directory."),
        ("--checkpoint-dir", "Model checkpoint output directory."),
        ("--run-timestamp", "Run timestamp tag for default directories."),
    )
    for flag, help_text in option_specs:
        parser.add_argument(flag, type=str, default=None, help=help_text)
    return parser
|
|
|
|
|
|
def resolve_run_dirs(
    model_name: str,
    log_dir: Optional[str] = None,
    checkpoint_dir: Optional[str] = None,
    run_timestamp: Optional[str] = None,
    *,
    checkpoint_root: str = "checkpoints",
    log_root: str = "logs",
) -> Tuple[str, str, str]:
    """Resolve output directories from runtime args, falling back to per-model defaults.

    Args:
        model_name: Model identifier used as the middle path component of the
            default directories.
        log_dir: Explicit log directory. ``None`` or ``""`` selects the default
            ``<log_root>/<model_name>/<timestamp>``.
        checkpoint_dir: Explicit checkpoint directory. ``None`` or ``""``
            selects the default ``<checkpoint_root>/<model_name>/<timestamp>``.
        run_timestamp: Tag shared by both default directories. ``None`` or
            ``""`` generates one from the current local (naive) time, formatted
            as ``YYYYmmdd_HHMMSS``.
        checkpoint_root: Root under which the default checkpoint directory is
            built (keyword-only; defaults to ``"checkpoints"``).
        log_root: Root under which the default log directory is built
            (keyword-only; defaults to ``"logs"``).

    Returns:
        ``(timestamp, checkpoint_dir, log_dir)``. Note that this order differs
        from the parameter order (``log_dir`` before ``checkpoint_dir``).
    """
    # Compute the tag once so both default directories share the same value.
    timestamp = run_timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")

    # Treat "" like None, consistent with run_timestamp above: an empty path
    # is not a usable directory (os.makedirs("") raises), and an explicit
    # `--log-dir ""` on the CLI should not silently produce one.
    if not checkpoint_dir:
        checkpoint_dir = os.path.join(checkpoint_root, model_name, timestamp)
    if not log_dir:
        log_dir = os.path.join(log_root, model_name, timestamp)

    return timestamp, checkpoint_dir, log_dir
|