ctm-dqn/utils/run_dirs.py

180 lines
5.9 KiB
Python

import argparse
import os
from datetime import datetime
from typing import Optional, Tuple
import yaml
# Root of the current layout; each run is placed under runs/<timestamp>/.
RUNS_ROOT = "runs"
# Roots of the older "multi-model" layout. They are consulted only as fallbacks
# when looking up existing runs, after RUNS_ROOT has been checked.
LEGACY_LOGS_ROOT = os.path.join("logs", "multi-model")
LEGACY_CHECKPOINTS_ROOT = os.path.join("checkpoints", "multi-model")
def add_run_dir_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register the run-directory command-line options on *parser*.

    Adds ``--log-dir``, ``--checkpoint-dir`` and ``--run-timestamp``; all are
    optional strings that default to ``None``.

    Args:
        parser: The argument parser to extend (modified in place).

    Returns:
        The same parser instance, to allow call chaining.
    """
    option_help = (
        ("--log-dir", "Training log output directory."),
        ("--checkpoint-dir", "Model checkpoint output directory."),
        ("--run-timestamp", "Run timestamp tag for default directories."),
    )
    for flag, description in option_help:
        parser.add_argument(flag, type=str, default=None, help=description)
    return parser
def default_run_timestamp() -> str:
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
    now = datetime.now()
    return f"{now:%Y%m%d_%H%M%S}"
def resolve_run_root(run_timestamp: Optional[str] = None) -> Tuple[str, str]:
    """Return the run tag and its directory under ``RUNS_ROOT``.

    Args:
        run_timestamp: Tag to use; when empty or ``None`` a fresh timestamp
            from ``default_run_timestamp()`` is generated.

    Returns:
        A ``(timestamp, run_root)`` tuple where ``run_root`` is
        ``RUNS_ROOT/<timestamp>``.
    """
    tag = run_timestamp if run_timestamp else default_run_timestamp()
    run_root = os.path.join(RUNS_ROOT, tag)
    return tag, run_root
def resolve_run_dirs(
    model_name: str,
    log_dir: Optional[str] = None,
    checkpoint_dir: Optional[str] = None,
    run_timestamp: Optional[str] = None,
) -> Tuple[str, str, str]:
    """Resolve the checkpoint and log directories for one model.

    Explicitly supplied directories are returned untouched. Any directory left
    as ``None`` defaults to ``runs/<timestamp>/checkpoints/<model_name>`` or
    ``runs/<timestamp>/logs/<model_name>`` respectively.

    Args:
        model_name: Sub-directory name used for the default locations.
        log_dir: Explicit log directory, or ``None`` for the default.
        checkpoint_dir: Explicit checkpoint directory, or ``None`` for the default.
        run_timestamp: Run tag; a fresh timestamp is generated when empty.

    Returns:
        ``(timestamp, checkpoint_dir, log_dir)`` -- note that the checkpoint
        directory precedes the log directory.
    """
    timestamp = run_timestamp or default_run_timestamp()

    # Nothing to derive when the caller supplied both locations.
    if log_dir is not None and checkpoint_dir is not None:
        return timestamp, checkpoint_dir, log_dir

    run_root = resolve_run_root(timestamp)[1]
    resolved_checkpoints = (
        checkpoint_dir
        if checkpoint_dir is not None
        else os.path.join(run_root, "checkpoints", model_name)
    )
    resolved_logs = (
        log_dir
        if log_dir is not None
        else os.path.join(run_root, "logs", model_name)
    )
    return timestamp, resolved_checkpoints, resolved_logs
def infer_run_root_from_paths(
    log_dir: Optional[str] = None,
    checkpoint_dir: Optional[str] = None,
    run_timestamp: Optional[str] = None,
) -> Optional[str]:
    """Guess the run root directory from a log and/or checkpoint path.

    Each non-empty path is examined in turn (``log_dir`` first). The run root
    is taken to be the directory that contains a ``logs`` or ``checkpoints``
    folder sitting one or two levels above the given path. The folder name is
    compared case-insensitively.

    Args:
        log_dir: Log directory to inspect, if any.
        checkpoint_dir: Checkpoint directory to inspect, if any.
        run_timestamp: Fallback tag used to build ``RUNS_ROOT/<run_timestamp>``
            when no path yields a match.

    Returns:
        An absolute run root inferred from a path, otherwise the
        timestamp-based path relative to ``RUNS_ROOT``, otherwise ``None``.
    """
    markers = ("logs", "checkpoints")
    for raw_path in (log_dir, checkpoint_dir):
        if not raw_path:
            continue
        parent = os.path.dirname(os.path.normpath(os.path.abspath(raw_path)))
        grandparent = os.path.dirname(parent)
        # <root>/<marker>/<leaf>
        if os.path.basename(parent).lower() in markers:
            return grandparent
        # <root>/<marker>/<dir>/<leaf>
        if os.path.basename(grandparent).lower() in markers:
            return os.path.dirname(grandparent)

    return os.path.join(RUNS_ROOT, run_timestamp) if run_timestamp else None
def write_shared_run_config(
    runtime_config: dict,
    *,
    run_root: Optional[str] = None,
    log_dir: Optional[str] = None,
    checkpoint_dir: Optional[str] = None,
    run_timestamp: Optional[str] = None,
) -> Optional[str]:
    """Dump *runtime_config* to ``config.yaml`` in the run root directory.

    The run root is ``run_root`` when given; otherwise it is inferred via
    ``infer_run_root_from_paths`` from the remaining keyword arguments. The
    directory is created if missing and an existing ``config.yaml`` is
    overwritten.

    Args:
        runtime_config: Mapping serialized with ``yaml.dump``.
        run_root: Explicit run root directory.
        log_dir: Log directory used to infer the run root.
        checkpoint_dir: Checkpoint directory used to infer the run root.
        run_timestamp: Run tag used as a last-resort fallback for inference.

    Returns:
        The path of the written file, or ``None`` when no run root could be
        determined (nothing is written in that case).
    """
    target_root = run_root
    if not target_root:
        target_root = infer_run_root_from_paths(
            log_dir=log_dir,
            checkpoint_dir=checkpoint_dir,
            run_timestamp=run_timestamp,
        )
    if not target_root:
        return None

    os.makedirs(target_root, exist_ok=True)
    config_path = os.path.join(target_root, "config.yaml")
    with open(config_path, "w", encoding="utf-8") as handle:
        yaml.dump(runtime_config, handle)
    return config_path
def find_latest_run_root() -> str:
    """Return the most recent run directory.

    ``RUNS_ROOT`` is searched first; the legacy checkpoints root is used only
    when ``RUNS_ROOT`` is missing or holds no sub-directories. "Most recent"
    means the lexicographically greatest sub-directory path, which matches
    chronological order for ``YYYYMMDD_HHMMSS`` names.

    Returns:
        Path of the selected run directory (prefixed by the searched root).

    Raises:
        FileNotFoundError: If neither root contains any sub-directory.
    """
    for root in (RUNS_ROOT, LEGACY_CHECKPOINTS_ROOT):
        if not os.path.isdir(root):
            continue
        run_dirs = [
            entry_path
            for entry_path in (os.path.join(root, name) for name in os.listdir(root))
            if os.path.isdir(entry_path)
        ]
        if run_dirs:
            return max(run_dirs)

    raise FileNotFoundError(
        f"No run directories found under '{RUNS_ROOT}' or legacy '{LEGACY_CHECKPOINTS_ROOT}'."
    )
def find_run_root_by_timestamp(run_timestamp: str) -> str:
    """Locate the directory of the run tagged *run_timestamp*.

    The lookup order is ``RUNS_ROOT``, then the legacy checkpoints root, then
    the legacy logs root; the first existing directory wins.

    Args:
        run_timestamp: Run tag (the directory name of the run).

    Returns:
        Path of the first matching directory.

    Raises:
        FileNotFoundError: If the run exists under none of the roots.
    """
    search_roots = (RUNS_ROOT, LEGACY_CHECKPOINTS_ROOT, LEGACY_LOGS_ROOT)
    for root in search_roots:
        candidate = os.path.join(root, run_timestamp)
        if os.path.isdir(candidate):
            return candidate

    raise FileNotFoundError(
        f"Run '{run_timestamp}' not found under '{RUNS_ROOT}' or legacy multi-model directories."
    )
def resolve_checkpoint_root(input_path: Optional[str]) -> str:
    """Return the absolute directory that holds checkpoints for a run.

    Args:
        input_path: Run root or checkpoint directory. When empty or ``None``
            the result of ``find_latest_run_root()`` is used instead.

    Returns:
        ``<path>/checkpoints`` if that sub-directory exists, otherwise the
        absolute form of the path itself.
    """
    base = os.path.abspath(input_path if input_path else find_latest_run_root())
    nested = os.path.join(base, "checkpoints")
    return nested if os.path.isdir(nested) else base
def resolve_log_root(input_path: Optional[str]) -> str:
    """Return the absolute directory that holds logs for a run.

    Args:
        input_path: Run root or log directory. When empty or ``None`` the
            result of ``find_latest_run_root()`` is used instead.

    Returns:
        ``<path>/logs`` if that sub-directory exists, otherwise the absolute
        form of the path itself.
    """
    # NOTE(review): find_latest_run_root() can fall back to the legacy
    # *checkpoints* root, so with a legacy layout and no input path this may
    # return a checkpoint directory rather than a log directory -- confirm
    # that callers expect this.
    base = os.path.abspath(input_path if input_path else find_latest_run_root())
    nested = os.path.join(base, "logs")
    return nested if os.path.isdir(nested) else base
def find_shared_config_path(
    checkpoint_dir: Optional[str] = None,
    fallback_config_path: Optional[str] = None,
) -> Optional[str]:
    """Locate the ``config.yaml`` that belongs to a checkpoint directory.

    Lookup order:

    1. ``<checkpoint_dir>/config.yaml``.
    2. ``<run_root>/config.yaml`` where ``<run_root>`` is the directory
       containing a ``checkpoints`` folder (case-insensitive). That folder may
       be the parent of ``checkpoint_dir``, ``checkpoint_dir`` itself, or its
       grandparent, tried in that order.
    3. ``fallback_config_path`` if it points to an existing file.

    Args:
        checkpoint_dir: Checkpoint directory to start from; skipped when
            empty or ``None``.
        fallback_config_path: Config file used when nothing is found near
            ``checkpoint_dir``.

    Returns:
        Path of the first existing config file, or ``None`` if none exists.
    """
    if checkpoint_dir:
        normalized = os.path.normpath(os.path.abspath(checkpoint_dir))
        local_config = os.path.join(normalized, "config.yaml")
        if os.path.isfile(local_config):
            return local_config

        parent = os.path.dirname(normalized)
        grandparent = os.path.dirname(parent)
        # The parent is checked first to keep the original precedence. The
        # directory itself covers the "<run>/checkpoints" path returned by
        # resolve_checkpoint_root; the grandparent covers nested layouts such
        # as "<run>/checkpoints/<model>/<sub>", which
        # infer_run_root_from_paths also recognizes when the config is written.
        for marker_dir in (parent, normalized, grandparent):
            if os.path.basename(marker_dir).lower() != "checkpoints":
                continue
            shared_config = os.path.join(os.path.dirname(marker_dir), "config.yaml")
            if os.path.isfile(shared_config):
                return shared_config

    if fallback_config_path and os.path.isfile(fallback_config_path):
        return fallback_config_path
    return None