统一模型训练时存储路径
This commit is contained in:
parent
9f3ce242b2
commit
cea9d42397
|
|
@ -1,14 +1,18 @@
|
|||
"""Unified training launcher for one or more models."""
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from training.registry import DEFAULT_MODELS, normalize_model_list
|
||||
from utils.run_dirs import resolve_run_root
|
||||
|
||||
|
||||
PLOT_REFRESH_INTERVAL = 30
|
||||
PROCESS_POLL_INTERVAL = 5
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
|
@ -69,7 +73,7 @@ def launch_model_process(model_name, run_timestamp, run_log_root, run_ckpt_root)
|
|||
stdout=stdout_handle,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
return process, stdout_handle
|
||||
return process, stdout_handle, stdout_path
|
||||
|
||||
|
||||
def launch_plotter(run_timestamp, run_log_root, interval):
|
||||
|
|
@ -90,42 +94,166 @@ def launch_plotter(run_timestamp, run_log_root, interval):
|
|||
stdout=stdout_handle,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
return process, stdout_handle
|
||||
return process, stdout_handle, stdout_path
|
||||
|
||||
|
||||
def now_str():
|
||||
return datetime.now().strftime("%H:%M:%S")
|
||||
|
||||
|
||||
def read_log_tail(log_path, max_lines=40):
|
||||
if not log_path or not os.path.isfile(log_path):
|
||||
return []
|
||||
with open(log_path, "r", encoding="utf-8", errors="replace") as f:
|
||||
lines = f.readlines()
|
||||
return [line.rstrip("\n") for line in lines[-max_lines:]]
|
||||
|
||||
|
||||
def write_run_status(run_root, run_timestamp, model_states, plotter_state=None):
|
||||
payload = {
|
||||
"run_timestamp": run_timestamp,
|
||||
"updated_at": datetime.now().isoformat(timespec="seconds"),
|
||||
"models": model_states,
|
||||
}
|
||||
if plotter_state is not None:
|
||||
payload["plotter"] = plotter_state
|
||||
|
||||
status_path = os.path.join(run_root, "run_status.json")
|
||||
with open(status_path, "w", encoding="utf-8") as f:
|
||||
json.dump(payload, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def print_failure_excerpt(name, log_path, returncode):
|
||||
print(f"[{now_str()}] {name} failed with code={returncode}")
|
||||
print(f" Log: {log_path}")
|
||||
tail_lines = read_log_tail(log_path)
|
||||
if tail_lines:
|
||||
print(" Last log lines:")
|
||||
for line in tail_lines:
|
||||
print(f" {line}")
|
||||
else:
|
||||
print(" No readable log content found.")
|
||||
|
||||
|
||||
def terminate_process(process):
|
||||
if process is None or process.poll() is not None:
|
||||
return
|
||||
process.terminate()
|
||||
try:
|
||||
process.wait(timeout=10)
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
|
||||
|
||||
def monitor_processes(processes, run_root, run_timestamp, plotter_info=None):
|
||||
pending = set(processes.keys())
|
||||
model_states = {}
|
||||
for model_name, info in processes.items():
|
||||
model_states[model_name] = {
|
||||
"status": "running",
|
||||
"pid": info["process"].pid,
|
||||
"log_dir": info["log_dir"],
|
||||
"stdout_path": info["stdout_path"],
|
||||
"checkpoint_dir": info["checkpoint_dir"],
|
||||
"started_at": info["started_at"],
|
||||
}
|
||||
|
||||
plotter_state = None
|
||||
if plotter_info is not None:
|
||||
plotter_state = {
|
||||
"status": "running",
|
||||
"pid": plotter_info["process"].pid,
|
||||
"stdout_path": plotter_info["stdout_path"],
|
||||
"started_at": plotter_info["started_at"],
|
||||
}
|
||||
|
||||
write_run_status(run_root, run_timestamp, model_states, plotter_state)
|
||||
|
||||
while pending:
|
||||
for model_name in list(pending):
|
||||
info = processes[model_name]
|
||||
process = info["process"]
|
||||
returncode = process.poll()
|
||||
if returncode is None:
|
||||
continue
|
||||
|
||||
pending.remove(model_name)
|
||||
status = "completed" if returncode == 0 else "failed"
|
||||
model_states[model_name]["status"] = status
|
||||
model_states[model_name]["returncode"] = returncode
|
||||
model_states[model_name]["finished_at"] = datetime.now().isoformat(timespec="seconds")
|
||||
write_run_status(run_root, run_timestamp, model_states, plotter_state)
|
||||
|
||||
if returncode == 0:
|
||||
print(f"[{now_str()}] [{model_name.upper()}] completed")
|
||||
else:
|
||||
print_failure_excerpt(model_name.upper(), info["stdout_path"], returncode)
|
||||
|
||||
if plotter_info is not None and plotter_state is not None and plotter_state["status"] == "running":
|
||||
plotter_returncode = plotter_info["process"].poll()
|
||||
if plotter_returncode is not None:
|
||||
plotter_state["status"] = "completed" if plotter_returncode == 0 else "failed"
|
||||
plotter_state["returncode"] = plotter_returncode
|
||||
plotter_state["finished_at"] = datetime.now().isoformat(timespec="seconds")
|
||||
write_run_status(run_root, run_timestamp, model_states, plotter_state)
|
||||
if plotter_returncode != 0:
|
||||
print_failure_excerpt("PLOTTER", plotter_info["stdout_path"], plotter_returncode)
|
||||
|
||||
if pending:
|
||||
time.sleep(PROCESS_POLL_INTERVAL)
|
||||
|
||||
return model_states, plotter_state
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
models = normalize_model_list(args.models)
|
||||
run_timestamp = args.run_timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
run_log_root = os.path.join("logs", "multi-model", run_timestamp)
|
||||
run_ckpt_root = os.path.join("checkpoints", "multi-model", run_timestamp)
|
||||
run_timestamp, run_root = resolve_run_root(args.run_timestamp or datetime.now().strftime("%Y%m%d_%H%M%S"))
|
||||
run_log_root = os.path.join(run_root, "logs")
|
||||
run_ckpt_root = os.path.join(run_root, "checkpoints")
|
||||
os.makedirs(run_log_root, exist_ok=True)
|
||||
os.makedirs(run_ckpt_root, exist_ok=True)
|
||||
|
||||
processes = {}
|
||||
handles = []
|
||||
model_states = None
|
||||
plotter_state = None
|
||||
|
||||
try:
|
||||
for model_name in models:
|
||||
print(f"[{datetime.now().strftime('%H:%M:%S')}] Launching {model_name.upper()} training...")
|
||||
process, stdout_handle = launch_model_process(
|
||||
process, stdout_handle, stdout_path = launch_model_process(
|
||||
model_name,
|
||||
run_timestamp,
|
||||
run_log_root,
|
||||
run_ckpt_root,
|
||||
)
|
||||
processes[model_name] = process
|
||||
processes[model_name] = {
|
||||
"process": process,
|
||||
"stdout_path": stdout_path,
|
||||
"stdout_handle": stdout_handle,
|
||||
"log_dir": os.path.join(run_log_root, model_name),
|
||||
"checkpoint_dir": os.path.join(run_ckpt_root, model_name),
|
||||
"started_at": datetime.now().isoformat(timespec="seconds"),
|
||||
}
|
||||
handles.append(stdout_handle)
|
||||
print(f" PID: {process.pid}")
|
||||
|
||||
plotter_process = None
|
||||
plotter_info = None
|
||||
if not args.no_plotter:
|
||||
plotter_process, plotter_stdout = launch_plotter(
|
||||
plotter_process, plotter_stdout, plotter_stdout_path = launch_plotter(
|
||||
run_timestamp,
|
||||
run_log_root,
|
||||
args.interval,
|
||||
)
|
||||
handles.append(plotter_stdout)
|
||||
plotter_info = {
|
||||
"process": plotter_process,
|
||||
"stdout_path": plotter_stdout_path,
|
||||
"stdout_handle": plotter_stdout,
|
||||
"started_at": datetime.now().isoformat(timespec="seconds"),
|
||||
}
|
||||
print(f"[{datetime.now().strftime('%H:%M:%S')}] Launching live plotter...")
|
||||
print(f" PID: {plotter_process.pid}")
|
||||
|
||||
|
|
@ -135,17 +263,33 @@ def main():
|
|||
print(f"Models: {', '.join(models)}")
|
||||
print()
|
||||
|
||||
for model_name, process in processes.items():
|
||||
process.wait()
|
||||
status = "completed" if process.returncode == 0 else f"failed (code={process.returncode})"
|
||||
print(f"[{model_name.upper()}] {status}")
|
||||
model_states, plotter_state = monitor_processes(
|
||||
processes,
|
||||
run_root,
|
||||
run_timestamp,
|
||||
plotter_info=plotter_info,
|
||||
)
|
||||
failed_models = [
|
||||
model_name for model_name, state in model_states.items()
|
||||
if state["status"] != "completed"
|
||||
]
|
||||
if failed_models:
|
||||
print()
|
||||
print(f"[{now_str()}] Failed models: {', '.join(failed_models)}")
|
||||
print(f"[{now_str()}] See run status: {os.path.join(run_root, 'run_status.json')}")
|
||||
sys.exit(1)
|
||||
finally:
|
||||
if "plotter_process" in locals() and plotter_process is not None and plotter_process.poll() is None:
|
||||
plotter_process.terminate()
|
||||
try:
|
||||
plotter_process.wait(timeout=10)
|
||||
except subprocess.TimeoutExpired:
|
||||
plotter_process.kill()
|
||||
if "plotter_process" in locals():
|
||||
terminate_process(plotter_process)
|
||||
if plotter_state is not None and plotter_state.get("status") == "running":
|
||||
plotter_state["status"] = "terminated"
|
||||
plotter_state["finished_at"] = datetime.now().isoformat(timespec="seconds")
|
||||
|
||||
for info in processes.values():
|
||||
terminate_process(info["process"])
|
||||
|
||||
if model_states is not None:
|
||||
write_run_status(run_root, run_timestamp, model_states, plotter_state)
|
||||
|
||||
for handle in handles:
|
||||
handle.close()
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ from agents.tcamappo_agent import TCAMAPPOAgent
|
|||
from agents.td3_agent import TD3Agent
|
||||
from envs.edge_vsl_env import SUMOEdgeVSLEnvironment
|
||||
from utils.config import get_agent_config
|
||||
from utils.run_dirs import find_shared_config_path, resolve_checkpoint_root
|
||||
|
||||
|
||||
MODEL_ORDER = ["ppo", "gpro", "appo", "mappo", "tcamappo", "dcmappo", "dqn", "ddpg", "sac", "td3", "sctd3"]
|
||||
|
|
@ -59,7 +60,7 @@ def parse_args():
|
|||
"--checkpoint-root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Checkpoint root directory. Default: latest checkpoints/multi-model/<timestamp>.",
|
||||
help="Checkpoint root or run root. Default: latest under runs/<timestamp>.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
|
|
@ -71,7 +72,7 @@ def parse_args():
|
|||
"--config",
|
||||
type=str,
|
||||
default="config_sumo_vsl.yaml",
|
||||
help="Fallback config path when checkpoint config.yaml is missing.",
|
||||
help="Fallback config path when the shared run config is unavailable.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
|
|
@ -119,21 +120,6 @@ def normalize_model_name(name: str) -> str:
|
|||
return lowered
|
||||
|
||||
|
||||
def find_latest_multi_model_root() -> str:
|
||||
base_dir = os.path.join("checkpoints", "multi-model")
|
||||
if not os.path.isdir(base_dir):
|
||||
raise FileNotFoundError(f"Multi-model checkpoint dir not found: {base_dir}")
|
||||
|
||||
candidates = [
|
||||
os.path.join(base_dir, item)
|
||||
for item in os.listdir(base_dir)
|
||||
if os.path.isdir(os.path.join(base_dir, item))
|
||||
]
|
||||
if not candidates:
|
||||
raise FileNotFoundError(f"No run directories found under: {base_dir}")
|
||||
return max(candidates)
|
||||
|
||||
|
||||
def discover_model_dirs(checkpoint_root: str, requested_models: List[str] = None) -> Dict[str, str]:
|
||||
checkpoint_root = os.path.abspath(checkpoint_root)
|
||||
requested = [normalize_model_name(m) for m in requested_models] if requested_models else None
|
||||
|
|
@ -149,11 +135,23 @@ def discover_model_dirs(checkpoint_root: str, requested_models: List[str] = None
|
|||
return discovered
|
||||
return {k: v for k, v in discovered.items() if k in requested}
|
||||
|
||||
base_name = os.path.basename(checkpoint_root).lower()
|
||||
parent_name = os.path.basename(os.path.dirname(checkpoint_root)).lower()
|
||||
grandparent_name = os.path.basename(os.path.dirname(os.path.dirname(checkpoint_root))).lower()
|
||||
if base_name in MODEL_LABELS and (
|
||||
parent_name == "checkpoints" or grandparent_name in {"checkpoints", "multi-model"}
|
||||
):
|
||||
model_name = base_name
|
||||
if requested is not None and model_name not in requested:
|
||||
return {}
|
||||
return {model_name: checkpoint_root}
|
||||
|
||||
if os.path.isfile(os.path.join(checkpoint_root, "config.yaml")):
|
||||
model_name = None
|
||||
parent_name = os.path.basename(os.path.dirname(checkpoint_root)).lower()
|
||||
if parent_name in MODEL_LABELS:
|
||||
model_name = parent_name
|
||||
elif base_name in MODEL_LABELS:
|
||||
model_name = base_name
|
||||
elif requested and len(requested) == 1:
|
||||
model_name = requested[0]
|
||||
if model_name is None:
|
||||
|
|
@ -166,14 +164,37 @@ def discover_model_dirs(checkpoint_root: str, requested_models: List[str] = None
|
|||
raise FileNotFoundError(f"No model checkpoint directories found in: {checkpoint_root}")
|
||||
|
||||
|
||||
def infer_eval_run_name(checkpoint_root: str) -> str:
|
||||
normalized_root = os.path.abspath(checkpoint_root)
|
||||
base_name = os.path.basename(normalized_root)
|
||||
parent_dir = os.path.dirname(normalized_root)
|
||||
parent_name = os.path.basename(parent_dir)
|
||||
grandparent_dir = os.path.dirname(parent_dir)
|
||||
grandparent_name = os.path.basename(grandparent_dir)
|
||||
|
||||
if base_name == "checkpoints":
|
||||
return parent_name
|
||||
|
||||
if parent_name == "checkpoints":
|
||||
return f"{base_name}_{grandparent_name}"
|
||||
|
||||
if grandparent_name == "multi-model":
|
||||
return f"{base_name}_{parent_name}"
|
||||
|
||||
if parent_name == "multi-model":
|
||||
return base_name
|
||||
|
||||
if parent_name in MODEL_ORDER:
|
||||
return f"{parent_name}_{base_name}"
|
||||
|
||||
return base_name
|
||||
|
||||
|
||||
def resolve_eval_output_dir(output_dir: str, checkpoint_root: str) -> str:
|
||||
if output_dir:
|
||||
return output_dir
|
||||
|
||||
run_name = os.path.basename(os.path.abspath(checkpoint_root))
|
||||
parent_name = os.path.basename(os.path.dirname(os.path.abspath(checkpoint_root)))
|
||||
if parent_name in ("multi-model", *MODEL_ORDER):
|
||||
run_name = f"{parent_name}_{run_name}"
|
||||
run_name = infer_eval_run_name(checkpoint_root)
|
||||
return os.path.join("results", "evaluations", run_name)
|
||||
|
||||
|
||||
|
|
@ -181,7 +202,7 @@ def load_config_for_checkpoint(checkpoint_dir: Optional[str], fallback_config_pa
|
|||
with open(fallback_config_path, "r", encoding="utf-8") as f:
|
||||
base_config = yaml.safe_load(f)
|
||||
|
||||
checkpoint_config = os.path.join(checkpoint_dir, "config.yaml") if checkpoint_dir else ""
|
||||
checkpoint_config = find_shared_config_path(checkpoint_dir, fallback_config_path)
|
||||
if checkpoint_config and os.path.isfile(checkpoint_config):
|
||||
with open(checkpoint_config, "r", encoding="utf-8") as f:
|
||||
checkpoint_loaded = yaml.safe_load(f)
|
||||
|
|
@ -852,7 +873,7 @@ def print_summary(summary_df: pd.DataFrame, output_dir: str):
|
|||
|
||||
def main():
|
||||
args = parse_args()
|
||||
checkpoint_root = args.checkpoint_root or find_latest_multi_model_root()
|
||||
checkpoint_root = resolve_checkpoint_root(args.checkpoint_root)
|
||||
model_dirs = discover_model_dirs(checkpoint_root, args.models)
|
||||
if not model_dirs:
|
||||
raise FileNotFoundError("No models matched the requested selection.")
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ import pandas as pd
|
|||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from utils.run_dirs import find_latest_run_root, find_run_root_by_timestamp
|
||||
|
||||
|
||||
MODEL_ORDER = ["ppo", "gpro", "appo", "mappo", "tcamappo", "dcmappo", "dqn", "ddpg", "sac", "td3", "sctd3"]
|
||||
MODEL_LABELS = {
|
||||
|
|
@ -47,7 +49,7 @@ MODEL_COLORS = {
|
|||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Plot live training progress from multi-model logs.")
|
||||
parser = argparse.ArgumentParser(description="Plot live training progress from run logs.")
|
||||
parser.add_argument("--model", default=None, help="Model name, e.g. ppo/gpro/appo/mappo/tcamappo/dcmappo/dqn/ddpg/sac/td3/sctd3")
|
||||
parser.add_argument(
|
||||
"--all-models",
|
||||
|
|
@ -57,12 +59,12 @@ def parse_args():
|
|||
parser.add_argument(
|
||||
"--run",
|
||||
default=None,
|
||||
help="Multi-model run timestamp. Default: latest under logs/multi-model/",
|
||||
help="Run timestamp. Default: latest under runs/<timestamp>/logs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-root",
|
||||
default=os.path.join("logs", "multi-model"),
|
||||
help="Multi-model log root directory.",
|
||||
default=None,
|
||||
help="Log root directory or run root. Default: latest runs/<timestamp>/logs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--csv-path",
|
||||
|
|
@ -131,15 +133,41 @@ def find_latest_run(log_root: str) -> str:
|
|||
return max(candidates)
|
||||
|
||||
|
||||
def resolve_paths(args, model_name: str):
|
||||
def resolve_run_logs_dir(log_root: Optional[str], run_name: Optional[str]) -> tuple[str, str]:
|
||||
if log_root is None:
|
||||
run_root = find_run_root_by_timestamp(run_name) if run_name else find_latest_run_root()
|
||||
resolved_log_root = (
|
||||
os.path.join(run_root, "logs")
|
||||
if os.path.isdir(os.path.join(run_root, "logs"))
|
||||
else run_root
|
||||
)
|
||||
else:
|
||||
resolved_log_root = os.path.abspath(log_root)
|
||||
|
||||
if os.path.basename(resolved_log_root).lower() == "logs":
|
||||
inferred_run_name = os.path.basename(os.path.dirname(resolved_log_root))
|
||||
return inferred_run_name, resolved_log_root
|
||||
if os.path.isdir(os.path.join(resolved_log_root, "logs")):
|
||||
inferred_run_name = os.path.basename(resolved_log_root)
|
||||
return inferred_run_name, os.path.join(resolved_log_root, "logs")
|
||||
|
||||
chosen_run_name = run_name or find_latest_run(resolved_log_root)
|
||||
candidate_run_dir = os.path.join(resolved_log_root, chosen_run_name)
|
||||
candidate_logs_dir = os.path.join(candidate_run_dir, "logs")
|
||||
if os.path.isdir(candidate_logs_dir):
|
||||
return chosen_run_name, candidate_logs_dir
|
||||
if os.path.isdir(candidate_run_dir):
|
||||
return chosen_run_name, candidate_run_dir
|
||||
|
||||
raise FileNotFoundError(f"Run directory not found: {candidate_run_dir}")
|
||||
|
||||
|
||||
def resolve_paths(args, model_name: str, run_name: str, run_logs_dir: str):
|
||||
if args.csv_path:
|
||||
csv_path = args.csv_path
|
||||
model_dir = os.path.dirname(os.path.abspath(csv_path))
|
||||
run_dir = os.path.dirname(model_dir)
|
||||
run_name = os.path.basename(run_dir)
|
||||
else:
|
||||
run_name = args.run or find_latest_run(args.log_root)
|
||||
model_dir = os.path.join(args.log_root, run_name, model_name)
|
||||
model_dir = os.path.join(run_logs_dir, model_name)
|
||||
csv_path = os.path.join(model_dir, f"{model_name}_training_log.csv")
|
||||
|
||||
if not os.path.isfile(csv_path):
|
||||
|
|
@ -150,11 +178,12 @@ def resolve_paths(args, model_name: str):
|
|||
return run_name, model_dir, csv_path, output_path, compare_output
|
||||
|
||||
|
||||
def resolve_run_dir(args):
|
||||
run_name = args.run or find_latest_run(args.log_root)
|
||||
run_dir = os.path.join(args.log_root, run_name)
|
||||
if not os.path.isdir(run_dir):
|
||||
raise FileNotFoundError(f"Run directory not found: {run_dir}")
|
||||
def resolve_run_dir(args, run_name: str, run_logs_dir: str):
|
||||
run_dir = (
|
||||
os.path.dirname(run_logs_dir)
|
||||
if os.path.basename(run_logs_dir).lower() == "logs"
|
||||
else run_logs_dir
|
||||
)
|
||||
overview_output = args.overview_output or os.path.join(run_dir, "all_models_training_snapshot.png")
|
||||
return run_name, run_dir, overview_output
|
||||
|
||||
|
|
@ -315,10 +344,10 @@ def plot_detailed_snapshot(
|
|||
plt.close()
|
||||
|
||||
|
||||
def load_run_model_logs(log_root: str, run_name: str) -> Dict[str, pd.DataFrame]:
|
||||
def load_run_model_logs(run_logs_dir: str) -> Dict[str, pd.DataFrame]:
|
||||
run_logs = {}
|
||||
for model_name in MODEL_ORDER:
|
||||
csv_path = os.path.join(log_root, run_name, model_name, f"{model_name}_training_log.csv")
|
||||
csv_path = os.path.join(run_logs_dir, model_name, f"{model_name}_training_log.csv")
|
||||
if os.path.isfile(csv_path):
|
||||
try:
|
||||
run_logs[model_name] = safe_read_csv(csv_path)
|
||||
|
|
@ -506,7 +535,7 @@ def plot_all_models_overview(
|
|||
for ax, (column, title, ylabel) in zip(axes, metrics):
|
||||
plot_metric_overlay(ax, run_logs, column, title, ylabel, window, show_ma)
|
||||
|
||||
fig.suptitle(f"Multi-Model Training Overview | {run_name}", fontsize=16, y=0.995)
|
||||
fig.suptitle(f"Training Overview | {run_name}", fontsize=16, y=0.995)
|
||||
plt.tight_layout(rect=[0, 0, 1, 0.98])
|
||||
plt.savefig(output_path, dpi=170, bbox_inches="tight")
|
||||
plt.close()
|
||||
|
|
@ -517,9 +546,10 @@ def plot_all_models_overview(
|
|||
|
||||
|
||||
def run_once(args):
|
||||
run_name, run_logs_dir = resolve_run_logs_dir(args.log_root, args.run)
|
||||
if args.all_models:
|
||||
run_name, _, overview_output = resolve_run_dir(args)
|
||||
run_logs = load_run_model_logs(args.log_root, run_name)
|
||||
_, _, overview_output = resolve_run_dir(args, run_name, run_logs_dir)
|
||||
run_logs = load_run_model_logs(run_logs_dir)
|
||||
if not run_logs:
|
||||
plot_waiting_placeholder(
|
||||
overview_output,
|
||||
|
|
@ -539,7 +569,9 @@ def run_once(args):
|
|||
return
|
||||
|
||||
model_name = normalize_model_name(args.model)
|
||||
run_name, _, csv_path, output_path, compare_output = resolve_paths(args, model_name)
|
||||
run_name, _, csv_path, output_path, compare_output = resolve_paths(
|
||||
args, model_name, run_name, run_logs_dir
|
||||
)
|
||||
df = safe_read_csv(csv_path)
|
||||
plot_detailed_snapshot(
|
||||
model_name=model_name,
|
||||
|
|
@ -551,7 +583,7 @@ def run_once(args):
|
|||
show_ma=args.show_ma,
|
||||
)
|
||||
|
||||
run_logs = load_run_model_logs(args.log_root, run_name)
|
||||
run_logs = load_run_model_logs(run_logs_dir)
|
||||
plot_run_comparison(
|
||||
target_model=model_name,
|
||||
run_logs=run_logs,
|
||||
|
|
|
|||
|
|
@ -1,10 +1,37 @@
|
|||
"""Unified worker entrypoint for launching one model training job."""
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
|
||||
from training.registry import TRAINERS, normalize_model_name
|
||||
from utils.run_dirs import add_run_dir_args
|
||||
|
||||
|
||||
def persist_training_error(model_name: str, log_dir: str, exc: Exception) -> None:
|
||||
if not log_dir:
|
||||
return
|
||||
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
tb_text = traceback.format_exc()
|
||||
error_txt_path = os.path.join(log_dir, "error_traceback.txt")
|
||||
with open(error_txt_path, "w", encoding="utf-8") as f:
|
||||
f.write(tb_text)
|
||||
|
||||
error_json_path = os.path.join(log_dir, "error_summary.json")
|
||||
payload = {
|
||||
"model": model_name,
|
||||
"failed_at": datetime.now().isoformat(timespec="seconds"),
|
||||
"exception_type": type(exc).__name__,
|
||||
"message": str(exc),
|
||||
"traceback_file": os.path.abspath(error_txt_path),
|
||||
}
|
||||
with open(error_json_path, "w", encoding="utf-8") as f:
|
||||
json.dump(payload, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def main():
|
||||
parser = add_run_dir_args(argparse.ArgumentParser())
|
||||
parser.add_argument("--model", required=True, help="Model name to train.")
|
||||
|
|
@ -12,11 +39,20 @@ def main():
|
|||
|
||||
model_name = normalize_model_name(args.model)
|
||||
trainer = TRAINERS[model_name]
|
||||
trainer(
|
||||
log_dir=args.log_dir,
|
||||
checkpoint_dir=args.checkpoint_dir,
|
||||
run_timestamp=args.run_timestamp,
|
||||
)
|
||||
try:
|
||||
trainer(
|
||||
log_dir=args.log_dir,
|
||||
checkpoint_dir=args.checkpoint_dir,
|
||||
run_timestamp=args.run_timestamp,
|
||||
)
|
||||
except Exception as exc:
|
||||
persist_training_error(model_name, args.log_dir, exc)
|
||||
print(
|
||||
f"[{datetime.now().strftime('%H:%M:%S')}] {model_name.upper()} training failed: {type(exc).__name__}: {exc}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from utils.config import get_agent_config, get_training_config
|
|||
from utils.episode_artifacts import save_training_episode_artifacts
|
||||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import resolve_run_dirs
|
||||
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
|
||||
|
||||
|
||||
def train_sumo_appo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||
|
|
@ -42,8 +42,12 @@ def train_sumo_appo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
os.makedirs(log_dir, exist_ok=True)
|
||||
runtime_config = copy.deepcopy(config)
|
||||
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
|
||||
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
|
||||
yaml.dump(runtime_config, f)
|
||||
write_shared_run_config(
|
||||
runtime_config,
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
run_timestamp=run_timestamp,
|
||||
)
|
||||
|
||||
logger = TrainingLogger(log_dir, "appo")
|
||||
env = SUMOEdgeVSLEnvironment(runtime_config)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from utils.config import get_agent_config, get_training_config
|
|||
from utils.episode_artifacts import save_training_episode_artifacts
|
||||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import resolve_run_dirs
|
||||
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
|
||||
|
||||
|
||||
def train_sumo_dcmappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||
|
|
@ -36,8 +36,12 @@ def train_sumo_dcmappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
os.makedirs(log_dir, exist_ok=True)
|
||||
runtime_config = copy.deepcopy(config)
|
||||
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
|
||||
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
|
||||
yaml.dump(runtime_config, f)
|
||||
write_shared_run_config(
|
||||
runtime_config,
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
run_timestamp=run_timestamp,
|
||||
)
|
||||
|
||||
logger = TrainingLogger(log_dir, "dcmappo")
|
||||
env = SUMOEdgeVSLEnvironment(runtime_config)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from utils.config import get_agent_config, get_training_config
|
|||
from utils.episode_artifacts import save_training_episode_artifacts
|
||||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import resolve_run_dirs
|
||||
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
|
||||
|
||||
|
||||
def train_sumo_ddpg(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||
|
|
@ -38,8 +38,12 @@ def train_sumo_ddpg(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
os.makedirs(log_dir, exist_ok=True)
|
||||
runtime_config = copy.deepcopy(config)
|
||||
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
|
||||
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
|
||||
yaml.dump(runtime_config, f)
|
||||
write_shared_run_config(
|
||||
runtime_config,
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
run_timestamp=run_timestamp,
|
||||
)
|
||||
|
||||
logger = TrainingLogger(log_dir, "ddpg")
|
||||
env = SUMOEdgeVSLEnvironment(runtime_config)
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from utils.config import get_agent_config, get_training_config
|
|||
from utils.episode_artifacts import save_training_episode_artifacts
|
||||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import resolve_run_dirs
|
||||
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
|
||||
|
||||
|
||||
def train_sumo_dqn(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||
|
|
@ -39,8 +39,12 @@ def train_sumo_dqn(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
os.makedirs(log_dir, exist_ok=True)
|
||||
runtime_config = copy.deepcopy(config)
|
||||
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
|
||||
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
|
||||
yaml.dump(runtime_config, f)
|
||||
write_shared_run_config(
|
||||
runtime_config,
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
run_timestamp=run_timestamp,
|
||||
)
|
||||
|
||||
logger = TrainingLogger(log_dir, "dqn")
|
||||
env = SUMOEdgeVSLEnvironment(runtime_config)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from utils.config import get_agent_config, get_training_config
|
|||
from utils.episode_artifacts import save_training_episode_artifacts
|
||||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import resolve_run_dirs
|
||||
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
|
||||
|
||||
|
||||
def train_sumo_gpro(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||
|
|
@ -35,8 +35,12 @@ def train_sumo_gpro(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
os.makedirs(log_dir, exist_ok=True)
|
||||
runtime_config = copy.deepcopy(config)
|
||||
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
|
||||
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
|
||||
yaml.dump(runtime_config, f)
|
||||
write_shared_run_config(
|
||||
runtime_config,
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
run_timestamp=run_timestamp,
|
||||
)
|
||||
|
||||
logger = TrainingLogger(log_dir, "gpro")
|
||||
env = SUMOEdgeVSLEnvironment(runtime_config)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from utils.config import get_agent_config, get_training_config
|
|||
from utils.episode_artifacts import save_training_episode_artifacts
|
||||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import resolve_run_dirs
|
||||
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
|
||||
|
||||
|
||||
def train_sumo_mappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||
|
|
@ -37,8 +37,12 @@ def train_sumo_mappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
os.makedirs(log_dir, exist_ok=True)
|
||||
runtime_config = copy.deepcopy(config)
|
||||
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
|
||||
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
|
||||
yaml.dump(runtime_config, f)
|
||||
write_shared_run_config(
|
||||
runtime_config,
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
run_timestamp=run_timestamp,
|
||||
)
|
||||
|
||||
logger = TrainingLogger(log_dir, "mappo")
|
||||
env = SUMOEdgeVSLEnvironment(runtime_config)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from utils.config import get_agent_config, get_training_config
|
|||
from utils.episode_artifacts import save_training_episode_artifacts
|
||||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import resolve_run_dirs
|
||||
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
|
||||
|
||||
|
||||
def train_sumo_ppo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||
|
|
@ -43,8 +43,12 @@ def train_sumo_ppo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
os.makedirs(log_dir, exist_ok=True)
|
||||
runtime_config = copy.deepcopy(config)
|
||||
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
|
||||
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
|
||||
yaml.dump(runtime_config, f)
|
||||
write_shared_run_config(
|
||||
runtime_config,
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
run_timestamp=run_timestamp,
|
||||
)
|
||||
|
||||
logger = TrainingLogger(log_dir, "ppo")
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from utils.config import get_agent_config, get_training_config
|
|||
from utils.episode_artifacts import save_training_episode_artifacts
|
||||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import resolve_run_dirs
|
||||
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
|
||||
|
||||
|
||||
def train_sumo_sac(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||
|
|
@ -35,8 +35,12 @@ def train_sumo_sac(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
os.makedirs(log_dir, exist_ok=True)
|
||||
runtime_config = copy.deepcopy(config)
|
||||
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
|
||||
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
|
||||
yaml.dump(runtime_config, f)
|
||||
write_shared_run_config(
|
||||
runtime_config,
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
run_timestamp=run_timestamp,
|
||||
)
|
||||
|
||||
logger = TrainingLogger(log_dir, "sac")
|
||||
env = SUMOEdgeVSLEnvironment(runtime_config)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from utils.config import get_agent_config, get_training_config
|
|||
from utils.episode_artifacts import save_training_episode_artifacts
|
||||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import resolve_run_dirs
|
||||
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
|
||||
|
||||
|
||||
def train_sumo_tcamappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||
|
|
@ -37,8 +37,12 @@ def train_sumo_tcamappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
|||
os.makedirs(log_dir, exist_ok=True)
|
||||
runtime_config = copy.deepcopy(config)
|
||||
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
|
||||
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
|
||||
yaml.dump(runtime_config, f)
|
||||
write_shared_run_config(
|
||||
runtime_config,
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
run_timestamp=run_timestamp,
|
||||
)
|
||||
|
||||
logger = TrainingLogger(log_dir, "tcamappo")
|
||||
env = SUMOEdgeVSLEnvironment(runtime_config)
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from utils.config import get_agent_config, get_training_config
|
|||
from utils.episode_artifacts import save_training_episode_artifacts
|
||||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
from utils.run_dirs import resolve_run_dirs
|
||||
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
|
||||
|
||||
|
||||
def train_sumo_td3(
|
||||
|
|
@ -47,8 +47,12 @@ def train_sumo_td3(
|
|||
os.makedirs(log_dir, exist_ok=True)
|
||||
runtime_config = copy.deepcopy(config)
|
||||
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
|
||||
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
|
||||
yaml.dump(runtime_config, f)
|
||||
write_shared_run_config(
|
||||
runtime_config,
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
run_timestamp=run_timestamp,
|
||||
)
|
||||
|
||||
logger = TrainingLogger(log_dir, model_name)
|
||||
env = SUMOEdgeVSLEnvironment(runtime_config)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,13 @@ import os
|
|||
from datetime import datetime
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
RUNS_ROOT = "runs"
|
||||
LEGACY_LOGS_ROOT = os.path.join("logs", "multi-model")
|
||||
LEGACY_CHECKPOINTS_ROOT = os.path.join("checkpoints", "multi-model")
|
||||
|
||||
|
||||
def add_run_dir_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
parser.add_argument("--log-dir", type=str, default=None, help="Training log output directory.")
|
||||
|
|
@ -11,18 +18,162 @@ def add_run_dir_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser
|
|||
return parser
|
||||
|
||||
|
||||
def default_run_timestamp() -> str:
|
||||
return datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
|
||||
def resolve_run_root(run_timestamp: Optional[str] = None) -> Tuple[str, str]:
|
||||
timestamp = run_timestamp or default_run_timestamp()
|
||||
return timestamp, os.path.join(RUNS_ROOT, timestamp)
|
||||
|
||||
|
||||
def resolve_run_dirs(
|
||||
model_name: str,
|
||||
log_dir: Optional[str] = None,
|
||||
checkpoint_dir: Optional[str] = None,
|
||||
run_timestamp: Optional[str] = None,
|
||||
) -> Tuple[str, str, str]:
|
||||
"""Resolve output directories from runtime args, falling back to per-model defaults."""
|
||||
timestamp = run_timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
"""Resolve output directories from runtime args, defaulting to runs/<timestamp>/..."""
|
||||
timestamp = run_timestamp or default_run_timestamp()
|
||||
|
||||
if checkpoint_dir is None:
|
||||
checkpoint_dir = os.path.join("checkpoints", model_name, timestamp)
|
||||
if log_dir is None:
|
||||
log_dir = os.path.join("logs", model_name, timestamp)
|
||||
if checkpoint_dir is None or log_dir is None:
|
||||
_, run_root = resolve_run_root(timestamp)
|
||||
if checkpoint_dir is None:
|
||||
checkpoint_dir = os.path.join(run_root, "checkpoints", model_name)
|
||||
if log_dir is None:
|
||||
log_dir = os.path.join(run_root, "logs", model_name)
|
||||
|
||||
return timestamp, checkpoint_dir, log_dir
|
||||
|
||||
|
||||
def infer_run_root_from_paths(
|
||||
log_dir: Optional[str] = None,
|
||||
checkpoint_dir: Optional[str] = None,
|
||||
run_timestamp: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
candidate_paths = [path for path in (log_dir, checkpoint_dir) if path]
|
||||
for path in candidate_paths:
|
||||
normalized = os.path.normpath(os.path.abspath(path))
|
||||
parent = os.path.dirname(normalized)
|
||||
basename = os.path.basename(parent).lower()
|
||||
if basename in {"logs", "checkpoints"}:
|
||||
return os.path.dirname(parent)
|
||||
grandparent = os.path.basename(os.path.dirname(parent)).lower()
|
||||
if grandparent in {"logs", "checkpoints"}:
|
||||
return os.path.dirname(os.path.dirname(parent))
|
||||
|
||||
if run_timestamp:
|
||||
return os.path.join(RUNS_ROOT, run_timestamp)
|
||||
return None
|
||||
|
||||
|
||||
def write_shared_run_config(
|
||||
runtime_config: dict,
|
||||
*,
|
||||
run_root: Optional[str] = None,
|
||||
log_dir: Optional[str] = None,
|
||||
checkpoint_dir: Optional[str] = None,
|
||||
run_timestamp: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
resolved_run_root = run_root or infer_run_root_from_paths(
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
run_timestamp=run_timestamp,
|
||||
)
|
||||
if not resolved_run_root:
|
||||
return None
|
||||
|
||||
os.makedirs(resolved_run_root, exist_ok=True)
|
||||
config_path = os.path.join(resolved_run_root, "config.yaml")
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
yaml.dump(runtime_config, f)
|
||||
return config_path
|
||||
|
||||
|
||||
def find_latest_run_root() -> str:
|
||||
candidates = []
|
||||
|
||||
if os.path.isdir(RUNS_ROOT):
|
||||
candidates.extend(
|
||||
os.path.join(RUNS_ROOT, item)
|
||||
for item in os.listdir(RUNS_ROOT)
|
||||
if os.path.isdir(os.path.join(RUNS_ROOT, item))
|
||||
)
|
||||
|
||||
if candidates:
|
||||
return max(candidates)
|
||||
|
||||
legacy_candidates = []
|
||||
if os.path.isdir(LEGACY_CHECKPOINTS_ROOT):
|
||||
legacy_candidates.extend(
|
||||
os.path.join(LEGACY_CHECKPOINTS_ROOT, item)
|
||||
for item in os.listdir(LEGACY_CHECKPOINTS_ROOT)
|
||||
if os.path.isdir(os.path.join(LEGACY_CHECKPOINTS_ROOT, item))
|
||||
)
|
||||
if legacy_candidates:
|
||||
return max(legacy_candidates)
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"No run directories found under '{RUNS_ROOT}' or legacy '{LEGACY_CHECKPOINTS_ROOT}'."
|
||||
)
|
||||
|
||||
|
||||
def find_run_root_by_timestamp(run_timestamp: str) -> str:
|
||||
new_run_root = os.path.join(RUNS_ROOT, run_timestamp)
|
||||
if os.path.isdir(new_run_root):
|
||||
return new_run_root
|
||||
|
||||
legacy_checkpoint_root = os.path.join(LEGACY_CHECKPOINTS_ROOT, run_timestamp)
|
||||
if os.path.isdir(legacy_checkpoint_root):
|
||||
return legacy_checkpoint_root
|
||||
|
||||
legacy_log_root = os.path.join(LEGACY_LOGS_ROOT, run_timestamp)
|
||||
if os.path.isdir(legacy_log_root):
|
||||
return legacy_log_root
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"Run '{run_timestamp}' not found under '{RUNS_ROOT}' or legacy multi-model directories."
|
||||
)
|
||||
|
||||
|
||||
def resolve_checkpoint_root(input_path: Optional[str]) -> str:
|
||||
if input_path:
|
||||
abs_path = os.path.abspath(input_path)
|
||||
else:
|
||||
abs_path = os.path.abspath(find_latest_run_root())
|
||||
|
||||
if os.path.isdir(os.path.join(abs_path, "checkpoints")):
|
||||
return os.path.join(abs_path, "checkpoints")
|
||||
return abs_path
|
||||
|
||||
|
||||
def resolve_log_root(input_path: Optional[str]) -> str:
|
||||
if input_path:
|
||||
abs_path = os.path.abspath(input_path)
|
||||
else:
|
||||
abs_path = os.path.abspath(find_latest_run_root())
|
||||
|
||||
if os.path.isdir(os.path.join(abs_path, "logs")):
|
||||
return os.path.join(abs_path, "logs")
|
||||
return abs_path
|
||||
|
||||
|
||||
def find_shared_config_path(
|
||||
checkpoint_dir: Optional[str] = None,
|
||||
fallback_config_path: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
if checkpoint_dir:
|
||||
normalized = os.path.normpath(os.path.abspath(checkpoint_dir))
|
||||
local_config = os.path.join(normalized, "config.yaml")
|
||||
if os.path.isfile(local_config):
|
||||
return local_config
|
||||
|
||||
parent = os.path.dirname(normalized)
|
||||
if os.path.basename(parent).lower() == "checkpoints":
|
||||
shared_config = os.path.join(os.path.dirname(parent), "config.yaml")
|
||||
if os.path.isfile(shared_config):
|
||||
return shared_config
|
||||
|
||||
if fallback_config_path and os.path.isfile(fallback_config_path):
|
||||
return fallback_config_path
|
||||
return None
|
||||
|
|
|
|||
Loading…
Reference in New Issue