180 lines
5.9 KiB
Python
180 lines
5.9 KiB
Python
import argparse
|
|
import os
|
|
from datetime import datetime
|
|
from typing import Optional, Tuple
|
|
|
|
import yaml
|
|
|
|
|
|
# Default parent of per-run output directories (runs/<timestamp>/...).
RUNS_ROOT = "runs"

# Older layout that kept logs and checkpoints in two separate trees; these are
# only consulted when looking up existing runs.
LEGACY_LOGS_ROOT, LEGACY_CHECKPOINTS_ROOT = (
    os.path.join(tree, "multi-model") for tree in ("logs", "checkpoints")
)
|
|
|
|
|
|
def add_run_dir_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register the output-directory command-line options on *parser*.

    Adds ``--log-dir``, ``--checkpoint-dir`` and ``--run-timestamp``. Each takes
    a string and is ``None`` when omitted. The same parser is returned so calls
    can be chained.
    """
    option_specs = (
        ("--log-dir", "Training log output directory."),
        ("--checkpoint-dir", "Model checkpoint output directory."),
        ("--run-timestamp", "Run timestamp tag for default directories."),
    )
    for flag, help_text in option_specs:
        parser.add_argument(flag, type=str, default=None, help=help_text)
    return parser
|
|
|
|
|
|
def default_run_timestamp() -> str:
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
    now = datetime.now()
    return f"{now:%Y%m%d_%H%M%S}"
|
|
|
|
|
|
def resolve_run_root(run_timestamp: Optional[str] = None) -> Tuple[str, str]:
    """Return ``(timestamp, run_root)`` for a run.

    A falsy *run_timestamp* (``None`` or ``""``) is replaced by
    ``default_run_timestamp()``. *run_root* is ``RUNS_ROOT/<timestamp>``.
    """
    if not run_timestamp:
        run_timestamp = default_run_timestamp()
    return run_timestamp, os.path.join(RUNS_ROOT, run_timestamp)
|
|
|
|
|
|
def resolve_run_dirs(
    model_name: str,
    log_dir: Optional[str] = None,
    checkpoint_dir: Optional[str] = None,
    run_timestamp: Optional[str] = None,
) -> Tuple[str, str, str]:
    """Resolve output directories from runtime args, defaulting to runs/<timestamp>/...

    Args:
        model_name: Per-model subdirectory name used when building defaults.
        log_dir: Explicit log directory; returned unchanged when not ``None``.
        checkpoint_dir: Explicit checkpoint directory; returned unchanged when
            not ``None``.
        run_timestamp: Run tag; a falsy value is replaced by a fresh
            ``default_run_timestamp()``.

    Returns:
        ``(timestamp, checkpoint_dir, log_dir)``. Note that checkpoints come
        before logs. Missing directories default to
        ``<run_root>/checkpoints/<model_name>`` and
        ``<run_root>/logs/<model_name>``, where ``<run_root>`` comes from
        ``resolve_run_root(timestamp)``.
    """
    timestamp = run_timestamp or default_run_timestamp()
    if checkpoint_dir is not None and log_dir is not None:
        # Both given explicitly: nothing to derive.
        return timestamp, checkpoint_dir, log_dir

    _, run_root = resolve_run_root(timestamp)
    resolved_checkpoints = (
        checkpoint_dir
        if checkpoint_dir is not None
        else os.path.join(run_root, "checkpoints", model_name)
    )
    resolved_logs = (
        log_dir
        if log_dir is not None
        else os.path.join(run_root, "logs", model_name)
    )
    return timestamp, resolved_checkpoints, resolved_logs
|
|
|
|
|
|
def infer_run_root_from_paths(
    log_dir: Optional[str] = None,
    checkpoint_dir: Optional[str] = None,
    run_timestamp: Optional[str] = None,
) -> Optional[str]:
    """Guess the run root directory that owns *log_dir* / *checkpoint_dir*.

    Each non-empty path is examined in order, ``log_dir`` first. The path is
    made absolute and normalized, then:

    * if its parent directory is named ``logs`` or ``checkpoints``
      (case-insensitively), the directory containing that parent is returned;
    * otherwise, if its grandparent is named that way, the directory
      containing the grandparent is returned.

    If no path matches, ``RUNS_ROOT/<run_timestamp>`` is returned when
    *run_timestamp* is truthy, else ``None``. Results from the path rules are
    absolute, while the timestamp fallback is built from ``RUNS_ROOT`` as-is.

    NOTE(review): the grandparent rule also matches legacy paths such as
    ``logs/multi-model/<ts>`` and then yields the directory that contains
    ``logs``. Confirm that this is the intended run root for such paths.
    """
    markers = {"logs", "checkpoints"}
    for raw_path in filter(None, (log_dir, checkpoint_dir)):
        parent = os.path.dirname(os.path.normpath(os.path.abspath(raw_path)))
        grandparent_dir = os.path.dirname(parent)
        if os.path.basename(parent).lower() in markers:
            return grandparent_dir
        if os.path.basename(grandparent_dir).lower() in markers:
            return os.path.dirname(grandparent_dir)

    return os.path.join(RUNS_ROOT, run_timestamp) if run_timestamp else None
|
|
|
|
|
|
def write_shared_run_config(
    runtime_config: dict,
    *,
    run_root: Optional[str] = None,
    log_dir: Optional[str] = None,
    checkpoint_dir: Optional[str] = None,
    run_timestamp: Optional[str] = None,
) -> Optional[str]:
    """Dump *runtime_config* as YAML to ``<run_root>/config.yaml``.

    A truthy *run_root* is used directly. Otherwise the root is inferred with
    ``infer_run_root_from_paths`` from *log_dir*, *checkpoint_dir* and
    *run_timestamp*. The root directory is created if needed, and an existing
    ``config.yaml`` there is overwritten.

    Returns:
        The path of the written file, or ``None`` (and nothing is written)
        when no run root could be determined.
    """
    target_root = run_root
    if not target_root:
        target_root = infer_run_root_from_paths(
            log_dir=log_dir,
            checkpoint_dir=checkpoint_dir,
            run_timestamp=run_timestamp,
        )
    if not target_root:
        return None

    os.makedirs(target_root, exist_ok=True)
    config_path = os.path.join(target_root, "config.yaml")
    with open(config_path, "w", encoding="utf-8") as stream:
        yaml.dump(runtime_config, stream)
    return config_path
|
|
|
|
|
|
def find_latest_run_root() -> str:
    """Return the most recent run directory.

    Subdirectories of ``RUNS_ROOT`` are preferred. Subdirectories of
    ``LEGACY_CHECKPOINTS_ROOT`` are only scanned when ``RUNS_ROOT`` has none.
    "Most recent" means the lexicographically greatest path. That matches
    chronological order for the default ``YYYYMMDD_HHMMSS`` timestamps, but
    not necessarily for custom ``--run-timestamp`` tags.

    Raises:
        FileNotFoundError: If neither root contains any subdirectory.
    """

    def subdirectories(root: str) -> list:
        # A missing root simply contributes no candidates.
        if not os.path.isdir(root):
            return []
        return [
            full_path
            for full_path in (os.path.join(root, name) for name in os.listdir(root))
            if os.path.isdir(full_path)
        ]

    for search_root in (RUNS_ROOT, LEGACY_CHECKPOINTS_ROOT):
        found = subdirectories(search_root)
        if found:
            return max(found)

    raise FileNotFoundError(
        f"No run directories found under '{RUNS_ROOT}' or legacy '{LEGACY_CHECKPOINTS_ROOT}'."
    )
|
|
|
|
|
|
def find_run_root_by_timestamp(run_timestamp: str) -> str:
    """Locate the directory of the run tagged *run_timestamp*.

    Candidates are checked in this order, and the first existing directory is
    returned:

    1. ``RUNS_ROOT/<ts>``;
    2. ``LEGACY_CHECKPOINTS_ROOT/<ts>``;
    3. ``LEGACY_LOGS_ROOT/<ts>``.

    Raises:
        FileNotFoundError: If none of the candidate directories exists.
    """
    search_roots = (RUNS_ROOT, LEGACY_CHECKPOINTS_ROOT, LEGACY_LOGS_ROOT)
    candidates = (os.path.join(root, run_timestamp) for root in search_roots)
    match = next((path for path in candidates if os.path.isdir(path)), None)
    if match is None:
        raise FileNotFoundError(
            f"Run '{run_timestamp}' not found under '{RUNS_ROOT}' or legacy multi-model directories."
        )
    return match
|
|
|
|
|
|
def resolve_checkpoint_root(input_path: Optional[str]) -> str:
    """Return an absolute directory in which to look for checkpoints.

    *input_path* is made absolute. When *input_path* is falsy,
    ``find_latest_run_root()`` is used instead. If that directory contains a
    ``checkpoints`` subdirectory, the subdirectory is returned; otherwise the
    absolute path itself is returned.
    """
    base = os.path.abspath(input_path if input_path else find_latest_run_root())
    nested = os.path.join(base, "checkpoints")
    return nested if os.path.isdir(nested) else base
|
|
|
|
|
|
def resolve_log_root(input_path: Optional[str]) -> str:
    """Return an absolute directory in which to look for logs.

    *input_path* is made absolute. When *input_path* is falsy,
    ``find_latest_run_root()`` is used instead. If that directory contains a
    ``logs`` subdirectory, the subdirectory is returned; otherwise the absolute
    path itself is returned.

    NOTE(review): with no *input_path*, ``find_latest_run_root()`` may fall
    back to a directory under ``LEGACY_CHECKPOINTS_ROOT``. Unless that
    directory has a ``logs`` subdirectory, the checkpoint directory itself is
    returned as the log root. Confirm callers expect this.
    """
    if not input_path:
        input_path = find_latest_run_root()
    base = os.path.abspath(input_path)

    logs_subdir = os.path.join(base, "logs")
    if os.path.isdir(logs_subdir):
        return logs_subdir
    return base
|
|
|
|
|
|
def find_shared_config_path(
    checkpoint_dir: Optional[str] = None,
    fallback_config_path: Optional[str] = None,
) -> Optional[str]:
    """Locate the ``config.yaml`` describing the run that owns *checkpoint_dir*.

    Search order (the first existing file wins):

    1. ``<checkpoint_dir>/config.yaml``.
    2. ``<run_root>/config.yaml``, where the parent of *checkpoint_dir* is
       named ``checkpoints`` (case-insensitively) and ``<run_root>`` contains
       that parent.
    3. Only when rule 2's parent does not match: the same lookup, but where
       the *grandparent* is named ``checkpoints``. This mirrors the
       checkpoint-path rule in ``infer_run_root_from_paths``, which
       ``write_shared_run_config`` uses to choose where to write the file. A
       config written for ``runs/<ts>/checkpoints/<model>/<sub>`` is
       therefore found again.
    4. *fallback_config_path*, if it names an existing file.

    NOTE(review): like the writer, rule 3 also matches legacy paths such as
    ``checkpoints/multi-model/<ts>`` and then looks for ``config.yaml`` in the
    directory containing ``checkpoints``.

    Args:
        checkpoint_dir: Checkpoint directory of the run; ignored when falsy.
        fallback_config_path: Config file to use when no run config is found.

    Returns:
        Path to the config file, or ``None`` if no candidate exists.
    """
    if checkpoint_dir:
        normalized = os.path.normpath(os.path.abspath(checkpoint_dir))
        local_config = os.path.join(normalized, "config.yaml")
        if os.path.isfile(local_config):
            return local_config

        # Walk up one level, then two, looking for a "checkpoints" directory.
        # The first match decides the run root (its parent), exactly as the
        # writer does. So stop after the first match, whether or not the
        # config exists there.
        ancestor = normalized
        for _ in range(2):
            ancestor = os.path.dirname(ancestor)
            if os.path.basename(ancestor).lower() == "checkpoints":
                shared_config = os.path.join(os.path.dirname(ancestor), "config.yaml")
                if os.path.isfile(shared_config):
                    return shared_config
                break

    if fallback_config_path and os.path.isfile(fallback_config_path):
        return fallback_config_path
    return None
|