import argparse
import os
from datetime import datetime
from typing import List, Optional, Tuple

import yaml

# Root directory under which each run gets its own timestamp-named folder.
RUNS_ROOT = "runs"
# Older layout roots that are still searched when looking up existing runs.
LEGACY_LOGS_ROOT = os.path.join("logs", "multi-model")
LEGACY_CHECKPOINTS_ROOT = os.path.join("checkpoints", "multi-model")

# Directory names that mark the level directly below a run root.
_RUN_SUBDIR_NAMES = frozenset({"logs", "checkpoints"})


def add_run_dir_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register ``--log-dir``, ``--checkpoint-dir`` and ``--run-timestamp``.

    All three options are strings defaulting to ``None``. The same parser
    object is returned so the call can be chained.
    """
    parser.add_argument("--log-dir", type=str, default=None, help="Training log output directory.")
    parser.add_argument(
        "--checkpoint-dir",
        type=str,
        default=None,
        help="Model checkpoint output directory.",
    )
    parser.add_argument(
        "--run-timestamp",
        type=str,
        default=None,
        help="Run timestamp tag for default directories.",
    )
    return parser


def default_run_timestamp() -> str:
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
    return datetime.now().strftime("%Y%m%d_%H%M%S")


def resolve_run_root(run_timestamp: Optional[str] = None) -> Tuple[str, str]:
    """Return ``(timestamp, run_root)`` where ``run_root`` is ``runs/<timestamp>``.

    A falsy ``run_timestamp`` (``None`` or ``""``) is replaced by the current
    time via :func:`default_run_timestamp`.
    """
    timestamp = run_timestamp or default_run_timestamp()
    return timestamp, os.path.join(RUNS_ROOT, timestamp)


def resolve_run_dirs(
    model_name: str,
    log_dir: Optional[str] = None,
    checkpoint_dir: Optional[str] = None,
    run_timestamp: Optional[str] = None,
) -> Tuple[str, str, str]:
    """Resolve output directories from runtime args, defaulting to runs/<timestamp>/...

    Explicitly supplied directories are returned untouched. A missing
    ``checkpoint_dir`` becomes ``runs/<timestamp>/checkpoints/<model_name>`` and
    a missing ``log_dir`` becomes ``runs/<timestamp>/logs/<model_name>``.

    Returns:
        ``(timestamp, checkpoint_dir, log_dir)`` -- note that the checkpoint
        directory comes before the log directory in the result.
    """
    timestamp = run_timestamp or default_run_timestamp()
    if checkpoint_dir is None or log_dir is None:
        # Only build the default run root when at least one default is needed.
        _, run_root = resolve_run_root(timestamp)
        if checkpoint_dir is None:
            checkpoint_dir = os.path.join(run_root, "checkpoints", model_name)
        if log_dir is None:
            log_dir = os.path.join(run_root, "logs", model_name)
    return timestamp, checkpoint_dir, log_dir


def infer_run_root_from_paths(
    log_dir: Optional[str] = None,
    checkpoint_dir: Optional[str] = None,
    run_timestamp: Optional[str] = None,
) -> Optional[str]:
    """Guess the run root that contains ``log_dir`` / ``checkpoint_dir``.

    ``log_dir`` is examined first, then ``checkpoint_dir``; falsy values are
    skipped. A path matches when its parent, or its grandparent, is named
    ``logs`` or ``checkpoints`` (compared case-insensitively); the directory
    containing that ``logs``/``checkpoints`` folder is returned as an
    absolute, normalised path.

    If no path matches and ``run_timestamp`` is truthy, the (relative) path
    ``runs/<run_timestamp>`` is returned. Otherwise the result is ``None``.
    """
    for path in (log_dir, checkpoint_dir):
        if not path:
            continue
        normalized = os.path.normpath(os.path.abspath(path))
        parent = os.path.dirname(normalized)
        # Layout <run_root>/<logs|checkpoints>/<leaf>
        if os.path.basename(parent).lower() in _RUN_SUBDIR_NAMES:
            return os.path.dirname(parent)
        # Layout <run_root>/<logs|checkpoints>/<dir>/<leaf>
        grandparent = os.path.dirname(parent)
        if os.path.basename(grandparent).lower() in _RUN_SUBDIR_NAMES:
            return os.path.dirname(grandparent)
    if run_timestamp:
        return os.path.join(RUNS_ROOT, run_timestamp)
    return None


def write_shared_run_config(
    runtime_config: dict,
    *,
    run_root: Optional[str] = None,
    log_dir: Optional[str] = None,
    checkpoint_dir: Optional[str] = None,
    run_timestamp: Optional[str] = None,
) -> Optional[str]:
    """Dump ``runtime_config`` as ``config.yaml`` inside the run root.

    The run root is ``run_root`` when truthy, otherwise it is inferred with
    :func:`infer_run_root_from_paths`. The directory is created if missing and
    an existing ``config.yaml`` is overwritten.

    Returns:
        The path of the written file, or ``None`` (nothing written) when no
        run root could be determined.
    """
    resolved_run_root = run_root or infer_run_root_from_paths(
        log_dir=log_dir,
        checkpoint_dir=checkpoint_dir,
        run_timestamp=run_timestamp,
    )
    if not resolved_run_root:
        return None

    os.makedirs(resolved_run_root, exist_ok=True)
    config_path = os.path.join(resolved_run_root, "config.yaml")
    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(runtime_config, f)
    return config_path


def _list_subdirs(root: str) -> List[str]:
    """Return ``root``-prefixed paths of the directories directly inside ``root``.

    Yields an empty list when ``root`` itself is not an existing directory.
    """
    if not os.path.isdir(root):
        return []
    children = (os.path.join(root, item) for item in os.listdir(root))
    return [child for child in children if os.path.isdir(child)]


def find_latest_run_root() -> str:
    """Return the most recent run directory.

    "Most recent" is the lexicographically greatest sub-directory path, which
    matches chronological order for ``YYYYMMDD_HHMMSS`` names. ``RUNS_ROOT``
    is preferred; ``LEGACY_CHECKPOINTS_ROOT`` is consulted only when
    ``RUNS_ROOT`` has no sub-directories.

    Raises:
        FileNotFoundError: if neither root contains a sub-directory.
    """
    # NOTE(review): LEGACY_LOGS_ROOT is not searched here, unlike in
    # find_run_root_by_timestamp -- confirm this is intentional.
    for root in (RUNS_ROOT, LEGACY_CHECKPOINTS_ROOT):
        candidates = _list_subdirs(root)
        if candidates:
            return max(candidates)

    raise FileNotFoundError(
        f"No run directories found under '{RUNS_ROOT}' or legacy '{LEGACY_CHECKPOINTS_ROOT}'."
    )


def find_run_root_by_timestamp(run_timestamp: str) -> str:
    """Locate the directory of the run named ``run_timestamp``.

    Looks under ``RUNS_ROOT``, then ``LEGACY_CHECKPOINTS_ROOT``, then
    ``LEGACY_LOGS_ROOT`` and returns the first existing directory.

    Raises:
        FileNotFoundError: if no such directory exists under any root.
    """
    for root in (RUNS_ROOT, LEGACY_CHECKPOINTS_ROOT, LEGACY_LOGS_ROOT):
        candidate = os.path.join(root, run_timestamp)
        if os.path.isdir(candidate):
            return candidate

    raise FileNotFoundError(
        f"Run '{run_timestamp}' not found under '{RUNS_ROOT}' or legacy multi-model directories."
    )


def _resolve_subroot(input_path: Optional[str], subdir: str) -> str:
    """Return ``<base>/<subdir>`` if that directory exists, else ``<base>``.

    ``<base>`` is the absolute form of ``input_path`` or, when ``input_path``
    is falsy, of the latest run root (which may raise ``FileNotFoundError``).
    """
    base = os.path.abspath(input_path if input_path else find_latest_run_root())
    nested = os.path.join(base, subdir)
    return nested if os.path.isdir(nested) else base


def resolve_checkpoint_root(input_path: Optional[str]) -> str:
    """Return the absolute checkpoint root for ``input_path`` or the latest run.

    If the base directory contains a ``checkpoints`` sub-directory, that
    sub-directory is returned; otherwise the base directory itself.

    Raises:
        FileNotFoundError: if ``input_path`` is falsy and no run exists.
    """
    return _resolve_subroot(input_path, "checkpoints")


def resolve_log_root(input_path: Optional[str]) -> str:
    """Return the absolute log root for ``input_path`` or the latest run.

    If the base directory contains a ``logs`` sub-directory, that
    sub-directory is returned; otherwise the base directory itself.

    Raises:
        FileNotFoundError: if ``input_path`` is falsy and no run exists.
    """
    return _resolve_subroot(input_path, "logs")


def find_shared_config_path(
    checkpoint_dir: Optional[str] = None,
    fallback_config_path: Optional[str] = None,
) -> Optional[str]:
    """Find the ``config.yaml`` that belongs to ``checkpoint_dir``.

    Search order, first existing file wins:

    1. ``<checkpoint_dir>/config.yaml``
    2. ``<run_root>/config.yaml`` when ``checkpoint_dir`` sits directly inside
       a directory named ``checkpoints`` (case-insensitive)
    3. ``fallback_config_path``

    Returns ``None`` when none of these files exist.
    """
    if checkpoint_dir:
        normalized = os.path.normpath(os.path.abspath(checkpoint_dir))
        local_config = os.path.join(normalized, "config.yaml")
        if os.path.isfile(local_config):
            return local_config

        parent = os.path.dirname(normalized)
        if os.path.basename(parent).lower() == "checkpoints":
            shared_config = os.path.join(os.path.dirname(parent), "config.yaml")
            if os.path.isfile(shared_config):
                return shared_config

    if fallback_config_path and os.path.isfile(fallback_config_path):
        return fallback_config_path
    return None