diff --git a/run_all_training.py b/run_all_training.py index 91d9db2..9b65efb 100644 --- a/run_all_training.py +++ b/run_all_training.py @@ -1,14 +1,18 @@ """Unified training launcher for one or more models.""" import argparse +import json import os import subprocess import sys +import time from datetime import datetime from training.registry import DEFAULT_MODELS, normalize_model_list +from utils.run_dirs import resolve_run_root PLOT_REFRESH_INTERVAL = 30 +PROCESS_POLL_INTERVAL = 5 def parse_args(): @@ -69,7 +73,7 @@ def launch_model_process(model_name, run_timestamp, run_log_root, run_ckpt_root) stdout=stdout_handle, stderr=subprocess.STDOUT, ) - return process, stdout_handle + return process, stdout_handle, stdout_path def launch_plotter(run_timestamp, run_log_root, interval): @@ -90,42 +94,166 @@ def launch_plotter(run_timestamp, run_log_root, interval): stdout=stdout_handle, stderr=subprocess.STDOUT, ) - return process, stdout_handle + return process, stdout_handle, stdout_path + + +def now_str(): + return datetime.now().strftime("%H:%M:%S") + + +def read_log_tail(log_path, max_lines=40): + if not log_path or not os.path.isfile(log_path): + return [] + with open(log_path, "r", encoding="utf-8", errors="replace") as f: + lines = f.readlines() + return [line.rstrip("\n") for line in lines[-max_lines:]] + + +def write_run_status(run_root, run_timestamp, model_states, plotter_state=None): + payload = { + "run_timestamp": run_timestamp, + "updated_at": datetime.now().isoformat(timespec="seconds"), + "models": model_states, + } + if plotter_state is not None: + payload["plotter"] = plotter_state + + status_path = os.path.join(run_root, "run_status.json") + with open(status_path, "w", encoding="utf-8") as f: + json.dump(payload, f, ensure_ascii=False, indent=2) + + +def print_failure_excerpt(name, log_path, returncode): + print(f"[{now_str()}] {name} failed with code={returncode}") + print(f" Log: {log_path}") + tail_lines = read_log_tail(log_path) + if tail_lines: + print(" Last log lines:") + for line in tail_lines: + print(f" {line}") + else: + print(" No readable log content found.") + + +def terminate_process(process): + if process is None or process.poll() is not None: + return + process.terminate() + try: + process.wait(timeout=10) + except subprocess.TimeoutExpired: + process.kill() + + +def monitor_processes(processes, run_root, run_timestamp, plotter_info=None): + pending = set(processes.keys()) + model_states = {} + for model_name, info in processes.items(): + model_states[model_name] = { + "status": "running", + "pid": info["process"].pid, + "log_dir": info["log_dir"], + "stdout_path": info["stdout_path"], + "checkpoint_dir": info["checkpoint_dir"], + "started_at": info["started_at"], + } + + plotter_state = None + if plotter_info is not None: + plotter_state = { + "status": "running", + "pid": plotter_info["process"].pid, + "stdout_path": plotter_info["stdout_path"], + "started_at": plotter_info["started_at"], + } + + write_run_status(run_root, run_timestamp, model_states, plotter_state) + + while pending: + for model_name in list(pending): + info = processes[model_name] + process = info["process"] + returncode = process.poll() + if returncode is None: + continue + + pending.remove(model_name) + status = "completed" if returncode == 0 else "failed" + model_states[model_name]["status"] = status + model_states[model_name]["returncode"] = returncode + model_states[model_name]["finished_at"] = datetime.now().isoformat(timespec="seconds") + write_run_status(run_root, run_timestamp, model_states, plotter_state) + + if returncode == 0: + print(f"[{now_str()}] [{model_name.upper()}] completed") + else: + print_failure_excerpt(model_name.upper(), info["stdout_path"], returncode) + + if plotter_info is not None and plotter_state is not None and plotter_state["status"] == "running": + plotter_returncode = plotter_info["process"].poll() + if plotter_returncode is not None: + plotter_state["status"] = "completed" if plotter_returncode == 0 else "failed" + plotter_state["returncode"] = plotter_returncode + plotter_state["finished_at"] = datetime.now().isoformat(timespec="seconds") + write_run_status(run_root, run_timestamp, model_states, plotter_state) + if plotter_returncode != 0: + print_failure_excerpt("PLOTTER", plotter_info["stdout_path"], plotter_returncode) + + if pending: + time.sleep(PROCESS_POLL_INTERVAL) + + return model_states, plotter_state def main(): args = parse_args() models = normalize_model_list(args.models) - run_timestamp = args.run_timestamp or datetime.now().strftime("%Y%m%d_%H%M%S") - run_log_root = os.path.join("logs", "multi-model", run_timestamp) - run_ckpt_root = os.path.join("checkpoints", "multi-model", run_timestamp) + run_timestamp, run_root = resolve_run_root(args.run_timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")) + run_log_root = os.path.join(run_root, "logs") + run_ckpt_root = os.path.join(run_root, "checkpoints") os.makedirs(run_log_root, exist_ok=True) os.makedirs(run_ckpt_root, exist_ok=True) processes = {} handles = [] + model_states = None + plotter_state = None try: for model_name in models: print(f"[{datetime.now().strftime('%H:%M:%S')}] Launching {model_name.upper()} training...") - process, stdout_handle = launch_model_process( + process, stdout_handle, stdout_path = launch_model_process( model_name, run_timestamp, run_log_root, run_ckpt_root, ) - processes[model_name] = process + processes[model_name] = { + "process": process, + "stdout_path": stdout_path, + "stdout_handle": stdout_handle, + "log_dir": os.path.join(run_log_root, model_name), + "checkpoint_dir": os.path.join(run_ckpt_root, model_name), + "started_at": datetime.now().isoformat(timespec="seconds"), + } handles.append(stdout_handle) print(f" PID: {process.pid}") plotter_process = None + plotter_info = None if not args.no_plotter: - plotter_process, plotter_stdout = launch_plotter( + plotter_process, plotter_stdout, plotter_stdout_path = launch_plotter( run_timestamp, run_log_root, args.interval, ) handles.append(plotter_stdout) + plotter_info = { + "process": plotter_process, + "stdout_path": plotter_stdout_path, + "stdout_handle": plotter_stdout, + "started_at": datetime.now().isoformat(timespec="seconds"), + } print(f"[{datetime.now().strftime('%H:%M:%S')}] Launching live plotter...") print(f" PID: {plotter_process.pid}") @@ -135,17 +263,33 @@ def main(): print(f"Models: {', '.join(models)}") print() - for model_name, process in processes.items(): - process.wait() - status = "completed" if process.returncode == 0 else f"failed (code={process.returncode})" - print(f"[{model_name.upper()}] {status}") + model_states, plotter_state = monitor_processes( + processes, + run_root, + run_timestamp, + plotter_info=plotter_info, + ) + failed_models = [ + model_name for model_name, state in model_states.items() + if state["status"] != "completed" + ] + if failed_models: + print() + print(f"[{now_str()}] Failed models: {', '.join(failed_models)}") + print(f"[{now_str()}] See run status: {os.path.join(run_root, 'run_status.json')}") + sys.exit(1) finally: - if "plotter_process" in locals() and plotter_process is not None and plotter_process.poll() is None: - plotter_process.terminate() - try: - plotter_process.wait(timeout=10) - except subprocess.TimeoutExpired: - plotter_process.kill() + if "plotter_process" in locals(): + terminate_process(plotter_process) + if plotter_state is not None and plotter_state.get("status") == "running": + plotter_state["status"] = "terminated" + plotter_state["finished_at"] = datetime.now().isoformat(timespec="seconds") + + for info in processes.values(): + terminate_process(info["process"]) + + if model_states is not None: + write_run_status(run_root, run_timestamp, model_states, plotter_state) for handle in handles: handle.close() diff --git a/scripts/evaluate_models.py b/scripts/evaluate_models.py index 0b926d9..8d35ec4 100644 --- a/scripts/evaluate_models.py +++ b/scripts/evaluate_models.py @@ -32,6 +32,7 @@ from agents.tcamappo_agent import TCAMAPPOAgent from agents.td3_agent import TD3Agent from envs.edge_vsl_env import SUMOEdgeVSLEnvironment from utils.config import get_agent_config +from utils.run_dirs import find_shared_config_path, resolve_checkpoint_root MODEL_ORDER = ["ppo", "gpro", "appo", "mappo", "tcamappo", "dcmappo", "dqn", "ddpg", "sac", "td3", "sctd3"] @@ -59,7 +60,7 @@ def parse_args(): "--checkpoint-root", type=str, default=None, - help="Checkpoint root directory. Default: latest checkpoints/multi-model/.", + help="Checkpoint root or run root. Default: latest under runs/.", ) parser.add_argument( "--output-dir", @@ -71,7 +72,7 @@ def parse_args(): "--config", type=str, default="config_sumo_vsl.yaml", - help="Fallback config path when checkpoint config.yaml is missing.", + help="Fallback config path when the shared run config is unavailable.", ) parser.add_argument( "--models", @@ -119,21 +120,6 @@ def normalize_model_name(name: str) -> str: return lowered -def find_latest_multi_model_root() -> str: - base_dir = os.path.join("checkpoints", "multi-model") - if not os.path.isdir(base_dir): - raise FileNotFoundError(f"Multi-model checkpoint dir not found: {base_dir}") - - candidates = [ - os.path.join(base_dir, item) - for item in os.listdir(base_dir) - if os.path.isdir(os.path.join(base_dir, item)) - ] - if not candidates: - raise FileNotFoundError(f"No run directories found under: {base_dir}") - return max(candidates) - - def discover_model_dirs(checkpoint_root: str, requested_models: List[str] = None) -> Dict[str, str]: checkpoint_root = os.path.abspath(checkpoint_root) requested = [normalize_model_name(m) for m in requested_models] if requested_models else None @@ -149,11 +135,23 @@ def discover_model_dirs(checkpoint_root: str, requested_models: List[str] = None return discovered return {k: v for k, v in discovered.items() if k in requested} + base_name = os.path.basename(checkpoint_root).lower() + parent_name = os.path.basename(os.path.dirname(checkpoint_root)).lower() + grandparent_name = os.path.basename(os.path.dirname(os.path.dirname(checkpoint_root))).lower() + if base_name in MODEL_LABELS and ( + parent_name == "checkpoints" or grandparent_name in {"checkpoints", "multi-model"} + ): + model_name = base_name + if requested is not None and model_name not in requested: + return {} + return {model_name: checkpoint_root} + if os.path.isfile(os.path.join(checkpoint_root, "config.yaml")): model_name = None - parent_name = os.path.basename(os.path.dirname(checkpoint_root)).lower() if parent_name in MODEL_LABELS: model_name = parent_name + elif base_name in MODEL_LABELS: + model_name = base_name elif requested and len(requested) == 1: model_name = requested[0] if model_name is None: @@ -166,14 +164,37 @@ def discover_model_dirs(checkpoint_root: str, requested_models: List[str] = None raise FileNotFoundError(f"No model checkpoint directories found in: {checkpoint_root}") +def infer_eval_run_name(checkpoint_root: str) -> str: + normalized_root = os.path.abspath(checkpoint_root) + base_name = os.path.basename(normalized_root) + parent_dir = os.path.dirname(normalized_root) + parent_name = os.path.basename(parent_dir) + grandparent_dir = os.path.dirname(parent_dir) + grandparent_name = os.path.basename(grandparent_dir) + + if base_name == "checkpoints": + return parent_name + + if parent_name == "checkpoints": + return f"{base_name}_{grandparent_name}" + + if grandparent_name == "multi-model": + return f"{base_name}_{parent_name}" + + if parent_name == "multi-model": + return base_name + + if parent_name in MODEL_ORDER: + return f"{parent_name}_{base_name}" + + return base_name + + def resolve_eval_output_dir(output_dir: str, checkpoint_root: str) -> str: if output_dir: return output_dir - run_name = os.path.basename(os.path.abspath(checkpoint_root)) - parent_name = os.path.basename(os.path.dirname(os.path.abspath(checkpoint_root))) - if parent_name in ("multi-model", *MODEL_ORDER): - run_name = f"{parent_name}_{run_name}" + run_name = infer_eval_run_name(checkpoint_root) return os.path.join("results", "evaluations", run_name) @@ -181,7 +202,7 @@ def load_config_for_checkpoint(checkpoint_dir: Optional[str], fallback_config_pa with open(fallback_config_path, "r", encoding="utf-8") as f: base_config = yaml.safe_load(f) - checkpoint_config = os.path.join(checkpoint_dir, "config.yaml") if checkpoint_dir else "" + checkpoint_config = find_shared_config_path(checkpoint_dir, fallback_config_path) if checkpoint_config and os.path.isfile(checkpoint_config): with open(checkpoint_config, "r", encoding="utf-8") as f: checkpoint_loaded = yaml.safe_load(f) @@ -852,7 +873,7 @@ def print_summary(summary_df: pd.DataFrame, output_dir: str): def main(): args = parse_args() - checkpoint_root = args.checkpoint_root or find_latest_multi_model_root() + checkpoint_root = resolve_checkpoint_root(args.checkpoint_root) model_dirs = discover_model_dirs(checkpoint_root, args.models) if not model_dirs: raise FileNotFoundError("No models matched the requested selection.") diff --git a/scripts/plot_live_training.py b/scripts/plot_live_training.py index 428c3a4..945b515 100644 --- a/scripts/plot_live_training.py +++ b/scripts/plot_live_training.py @@ -16,6 +16,8 @@ import pandas as pd matplotlib.use("Agg") import matplotlib.pyplot as plt +from utils.run_dirs import find_latest_run_root, find_run_root_by_timestamp + MODEL_ORDER = ["ppo", "gpro", "appo", "mappo", "tcamappo", "dcmappo", "dqn", "ddpg", "sac", "td3", "sctd3"] MODEL_LABELS = { @@ -47,7 +49,7 @@ MODEL_COLORS = { def parse_args(): - parser = argparse.ArgumentParser(description="Plot live training progress from multi-model logs.") + parser = argparse.ArgumentParser(description="Plot live training progress from run logs.") parser.add_argument("--model", default=None, help="Model name, e.g. ppo/gpro/appo/mappo/tcamappo/dcmappo/dqn/ddpg/sac/td3/sctd3") parser.add_argument( "--all-models", @@ -57,12 +59,12 @@ def parse_args(): parser.add_argument( "--run", default=None, - help="Multi-model run timestamp. Default: latest under logs/multi-model/", + help="Run timestamp. Default: latest under runs//logs.", ) parser.add_argument( "--log-root", - default=os.path.join("logs", "multi-model"), - help="Multi-model log root directory.", + default=None, + help="Log root directory or run root. Default: latest runs//logs.", ) parser.add_argument( "--csv-path", @@ -131,15 +133,41 @@ def find_latest_run(log_root: str) -> str: return max(candidates) -def resolve_paths(args, model_name: str): +def resolve_run_logs_dir(log_root: Optional[str], run_name: Optional[str]) -> tuple[str, str]: + if log_root is None: + run_root = find_run_root_by_timestamp(run_name) if run_name else find_latest_run_root() + resolved_log_root = ( + os.path.join(run_root, "logs") + if os.path.isdir(os.path.join(run_root, "logs")) + else run_root + ) + else: + resolved_log_root = os.path.abspath(log_root) + + if os.path.basename(resolved_log_root).lower() == "logs": + inferred_run_name = os.path.basename(os.path.dirname(resolved_log_root)) + return inferred_run_name, resolved_log_root + if os.path.isdir(os.path.join(resolved_log_root, "logs")): + inferred_run_name = os.path.basename(resolved_log_root) + return inferred_run_name, os.path.join(resolved_log_root, "logs") + + chosen_run_name = run_name or find_latest_run(resolved_log_root) + candidate_run_dir = os.path.join(resolved_log_root, chosen_run_name) + candidate_logs_dir = os.path.join(candidate_run_dir, "logs") + if os.path.isdir(candidate_logs_dir): + return chosen_run_name, candidate_logs_dir + if os.path.isdir(candidate_run_dir): + return chosen_run_name, candidate_run_dir + + raise FileNotFoundError(f"Run directory not found: {candidate_run_dir}") + + +def resolve_paths(args, model_name: str, run_name: str, run_logs_dir: str): if args.csv_path: csv_path = args.csv_path model_dir = os.path.dirname(os.path.abspath(csv_path)) - run_dir = os.path.dirname(model_dir) - run_name = os.path.basename(run_dir) else: - run_name = args.run or find_latest_run(args.log_root) - model_dir = os.path.join(args.log_root, run_name, model_name) + model_dir = os.path.join(run_logs_dir, model_name) csv_path = os.path.join(model_dir, f"{model_name}_training_log.csv") if not os.path.isfile(csv_path): @@ -150,11 +178,12 @@ def resolve_paths(args, model_name: str): return run_name, model_dir, csv_path, output_path, compare_output -def resolve_run_dir(args): - run_name = args.run or find_latest_run(args.log_root) - run_dir = os.path.join(args.log_root, run_name) - if not os.path.isdir(run_dir): - raise FileNotFoundError(f"Run directory not found: {run_dir}") +def resolve_run_dir(args, run_name: str, run_logs_dir: str): + run_dir = ( + os.path.dirname(run_logs_dir) + if os.path.basename(run_logs_dir).lower() == "logs" + else run_logs_dir + ) overview_output = args.overview_output or os.path.join(run_dir, "all_models_training_snapshot.png") return run_name, run_dir, overview_output @@ -315,10 +344,10 @@ def plot_detailed_snapshot( plt.close() -def load_run_model_logs(log_root: str, run_name: str) -> Dict[str, pd.DataFrame]: +def load_run_model_logs(run_logs_dir: str) -> Dict[str, pd.DataFrame]: run_logs = {} for model_name in MODEL_ORDER: - csv_path = os.path.join(log_root, run_name, model_name, f"{model_name}_training_log.csv") + csv_path = os.path.join(run_logs_dir, model_name, f"{model_name}_training_log.csv") if os.path.isfile(csv_path): try: run_logs[model_name] = safe_read_csv(csv_path) @@ -506,7 +535,7 @@ def plot_all_models_overview( for ax, (column, title, ylabel) in zip(axes, metrics): plot_metric_overlay(ax, run_logs, column, title, ylabel, window, show_ma) - fig.suptitle(f"Multi-Model Training Overview | {run_name}", fontsize=16, y=0.995) + fig.suptitle(f"Training Overview | {run_name}", fontsize=16, y=0.995) plt.tight_layout(rect=[0, 0, 1, 0.98]) plt.savefig(output_path, dpi=170, bbox_inches="tight") plt.close() @@ -517,9 +546,10 @@ def plot_all_models_overview( def run_once(args): + run_name, run_logs_dir = resolve_run_logs_dir(args.log_root, args.run) if args.all_models: - run_name, _, overview_output = resolve_run_dir(args) - run_logs = load_run_model_logs(args.log_root, run_name) + _, _, overview_output = resolve_run_dir(args, run_name, run_logs_dir) + run_logs = load_run_model_logs(run_logs_dir) if not run_logs: plot_waiting_placeholder( overview_output, @@ -539,7 +569,9 @@ def run_once(args): return model_name = normalize_model_name(args.model) - run_name, _, csv_path, output_path, compare_output = resolve_paths(args, model_name) + run_name, _, csv_path, output_path, compare_output = resolve_paths( + args, model_name, run_name, run_logs_dir + ) df = safe_read_csv(csv_path) plot_detailed_snapshot( model_name=model_name, @@ -551,7 +583,7 @@ def run_once(args): show_ma=args.show_ma, ) - run_logs = load_run_model_logs(args.log_root, run_name) + run_logs = load_run_model_logs(run_logs_dir) plot_run_comparison( target_model=model_name, run_logs=run_logs, diff --git a/training/run_model.py b/training/run_model.py index 72fa976..238cb4c 100644 --- a/training/run_model.py +++ b/training/run_model.py @@ -1,10 +1,37 @@ """Unified worker entrypoint for launching one model training job.""" import argparse +import json +import os +import sys +import traceback +from datetime import datetime from training.registry import TRAINERS, normalize_model_name from utils.run_dirs import add_run_dir_args +def persist_training_error(model_name: str, log_dir: str, exc: Exception) -> None: + if not log_dir: + return + + os.makedirs(log_dir, exist_ok=True) + tb_text = traceback.format_exc() + error_txt_path = os.path.join(log_dir, "error_traceback.txt") + with open(error_txt_path, "w", encoding="utf-8") as f: + f.write(tb_text) + + error_json_path = os.path.join(log_dir, "error_summary.json") + payload = { + "model": model_name, + "failed_at": datetime.now().isoformat(timespec="seconds"), + "exception_type": type(exc).__name__, + "message": str(exc), + "traceback_file": os.path.abspath(error_txt_path), + } + with open(error_json_path, "w", encoding="utf-8") as f: + json.dump(payload, f, ensure_ascii=False, indent=2) + + def main(): parser = add_run_dir_args(argparse.ArgumentParser()) parser.add_argument("--model", required=True, help="Model name to train.") @@ -12,11 +39,20 @@ def main(): model_name = normalize_model_name(args.model) trainer = TRAINERS[model_name] - trainer( - log_dir=args.log_dir, - checkpoint_dir=args.checkpoint_dir, - run_timestamp=args.run_timestamp, - ) + try: + trainer( + log_dir=args.log_dir, + checkpoint_dir=args.checkpoint_dir, + run_timestamp=args.run_timestamp, + ) + except Exception as exc: + persist_training_error(model_name, args.log_dir, exc) + print( + f"[{datetime.now().strftime('%H:%M:%S')}] {model_name.upper()} training failed: {type(exc).__name__}: {exc}", + file=sys.stderr, + ) + traceback.print_exc() + raise if __name__ == "__main__": diff --git a/training/train_appo.py b/training/train_appo.py index 0390527..1e92a2d 100644 --- a/training/train_appo.py +++ b/training/train_appo.py @@ -20,7 +20,7 @@ from utils.config import get_agent_config, get_training_config from utils.episode_artifacts import save_training_episode_artifacts from utils.logger import TrainingLogger from utils.plot import plot_training_curves -from utils.run_dirs import resolve_run_dirs +from utils.run_dirs import resolve_run_dirs, write_shared_run_config def train_sumo_appo(log_dir=None, checkpoint_dir=None, run_timestamp=None): @@ -42,8 +42,12 @@ def train_sumo_appo(log_dir=None, checkpoint_dir=None, run_timestamp=None): os.makedirs(log_dir, exist_ok=True) runtime_config = copy.deepcopy(config) runtime_config.setdefault("runtime", {})["output_dir"] = log_dir - with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f: - yaml.dump(runtime_config, f) + write_shared_run_config( + runtime_config, + log_dir=log_dir, + checkpoint_dir=checkpoint_dir, + run_timestamp=run_timestamp, + ) logger = TrainingLogger(log_dir, "appo") env = SUMOEdgeVSLEnvironment(runtime_config) diff --git a/training/train_dcmappo.py b/training/train_dcmappo.py index c29f3b1..c3b15e9 100644 --- a/training/train_dcmappo.py +++ b/training/train_dcmappo.py @@ -16,7 +16,7 @@ from utils.config import get_agent_config, get_training_config from utils.episode_artifacts import save_training_episode_artifacts from utils.logger import TrainingLogger from utils.plot import plot_training_curves -from utils.run_dirs import resolve_run_dirs +from utils.run_dirs import resolve_run_dirs, write_shared_run_config def train_sumo_dcmappo(log_dir=None, checkpoint_dir=None, run_timestamp=None): @@ -36,8 +36,12 @@ def train_sumo_dcmappo(log_dir=None, checkpoint_dir=None, run_timestamp=None): os.makedirs(log_dir, exist_ok=True) runtime_config = copy.deepcopy(config) runtime_config.setdefault("runtime", {})["output_dir"] = log_dir - with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f: - yaml.dump(runtime_config, f) + write_shared_run_config( + runtime_config, + log_dir=log_dir, + checkpoint_dir=checkpoint_dir, + run_timestamp=run_timestamp, + ) logger = TrainingLogger(log_dir, "dcmappo") env = SUMOEdgeVSLEnvironment(runtime_config) diff --git a/training/train_ddpg.py b/training/train_ddpg.py index ee20543..ff642ba 100644 --- a/training/train_ddpg.py +++ b/training/train_ddpg.py @@ -17,7 +17,7 @@ from utils.config import get_agent_config, get_training_config from utils.episode_artifacts import save_training_episode_artifacts from utils.logger import TrainingLogger from utils.plot import plot_training_curves -from utils.run_dirs import resolve_run_dirs +from utils.run_dirs import resolve_run_dirs, write_shared_run_config def train_sumo_ddpg(log_dir=None, checkpoint_dir=None, run_timestamp=None): @@ -38,8 +38,12 @@ def train_sumo_ddpg(log_dir=None, checkpoint_dir=None, run_timestamp=None): os.makedirs(log_dir, exist_ok=True) runtime_config = copy.deepcopy(config) runtime_config.setdefault("runtime", {})["output_dir"] = log_dir - with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f: - yaml.dump(runtime_config, f) + write_shared_run_config( + runtime_config, + log_dir=log_dir, + checkpoint_dir=checkpoint_dir, + run_timestamp=run_timestamp, + ) logger = TrainingLogger(log_dir, "ddpg") env = SUMOEdgeVSLEnvironment(runtime_config) diff --git a/training/train_dqn.py b/training/train_dqn.py index d980405..e780a87 100644 --- a/training/train_dqn.py +++ b/training/train_dqn.py @@ -18,7 +18,7 @@ from utils.config import get_agent_config, get_training_config from utils.episode_artifacts import save_training_episode_artifacts from utils.logger import TrainingLogger from utils.plot import plot_training_curves -from utils.run_dirs import resolve_run_dirs +from utils.run_dirs import resolve_run_dirs, write_shared_run_config def train_sumo_dqn(log_dir=None, checkpoint_dir=None, run_timestamp=None): @@ -39,8 +39,12 @@ def train_sumo_dqn(log_dir=None, checkpoint_dir=None, run_timestamp=None): os.makedirs(log_dir, exist_ok=True) runtime_config = copy.deepcopy(config) runtime_config.setdefault("runtime", {})["output_dir"] = log_dir - with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f: - yaml.dump(runtime_config, f) + write_shared_run_config( + runtime_config, + log_dir=log_dir, + checkpoint_dir=checkpoint_dir, + run_timestamp=run_timestamp, + ) logger = TrainingLogger(log_dir, "dqn") env = SUMOEdgeVSLEnvironment(runtime_config) diff --git a/training/train_gpro.py b/training/train_gpro.py index 9348af9..298124b 100644 --- a/training/train_gpro.py +++ b/training/train_gpro.py @@ -14,7 +14,7 @@ from utils.config import get_agent_config, get_training_config from utils.episode_artifacts import save_training_episode_artifacts from utils.logger import TrainingLogger from utils.plot import plot_training_curves -from utils.run_dirs import resolve_run_dirs +from utils.run_dirs import resolve_run_dirs, write_shared_run_config def train_sumo_gpro(log_dir=None, checkpoint_dir=None, run_timestamp=None): @@ -35,8 +35,12 @@ def train_sumo_gpro(log_dir=None, checkpoint_dir=None, run_timestamp=None): os.makedirs(log_dir, exist_ok=True) runtime_config = copy.deepcopy(config) runtime_config.setdefault("runtime", {})["output_dir"] = log_dir - with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f: - yaml.dump(runtime_config, f) + write_shared_run_config( + runtime_config, + log_dir=log_dir, + checkpoint_dir=checkpoint_dir, + run_timestamp=run_timestamp, + ) logger = TrainingLogger(log_dir, "gpro") env = SUMOEdgeVSLEnvironment(runtime_config) diff --git a/training/train_mappo.py b/training/train_mappo.py index 97a866b..fa68f1e 100644 --- a/training/train_mappo.py +++ b/training/train_mappo.py @@ -17,7 +17,7 @@ from utils.config import get_agent_config, get_training_config from utils.episode_artifacts import save_training_episode_artifacts from utils.logger import TrainingLogger from utils.plot import plot_training_curves -from utils.run_dirs import resolve_run_dirs +from utils.run_dirs import resolve_run_dirs, write_shared_run_config def train_sumo_mappo(log_dir=None, checkpoint_dir=None, run_timestamp=None): @@ -37,8 +37,12 @@ def train_sumo_mappo(log_dir=None, checkpoint_dir=None, run_timestamp=None): os.makedirs(log_dir, exist_ok=True) runtime_config = copy.deepcopy(config) runtime_config.setdefault("runtime", {})["output_dir"] = log_dir - with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f: - yaml.dump(runtime_config, f) + write_shared_run_config( + runtime_config, + log_dir=log_dir, + checkpoint_dir=checkpoint_dir, + run_timestamp=run_timestamp, + ) logger = TrainingLogger(log_dir, "mappo") env = SUMOEdgeVSLEnvironment(runtime_config) diff --git a/training/train_ppo.py b/training/train_ppo.py index f058f19..4c1ffe2 100644 --- a/training/train_ppo.py +++ b/training/train_ppo.py @@ -20,7 +20,7 @@ from utils.config import get_agent_config, get_training_config from utils.episode_artifacts import save_training_episode_artifacts from utils.logger import TrainingLogger from utils.plot import plot_training_curves -from utils.run_dirs import resolve_run_dirs +from utils.run_dirs import resolve_run_dirs, write_shared_run_config def train_sumo_ppo(log_dir=None, checkpoint_dir=None, run_timestamp=None): @@ -43,8 +43,12 @@ def train_sumo_ppo(log_dir=None, checkpoint_dir=None, run_timestamp=None): os.makedirs(log_dir, exist_ok=True) runtime_config = copy.deepcopy(config) runtime_config.setdefault("runtime", {})["output_dir"] = log_dir - with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f: - yaml.dump(runtime_config, f) + write_shared_run_config( + runtime_config, + log_dir=log_dir, + checkpoint_dir=checkpoint_dir, + run_timestamp=run_timestamp, + ) logger = TrainingLogger(log_dir, "ppo") diff --git a/training/train_sac.py b/training/train_sac.py index 6b54150..557ce0e 100644 --- a/training/train_sac.py +++ b/training/train_sac.py @@ -14,7 +14,7 @@ from utils.config import get_agent_config, get_training_config from utils.episode_artifacts import save_training_episode_artifacts from utils.logger import TrainingLogger from utils.plot import plot_training_curves -from utils.run_dirs import resolve_run_dirs +from utils.run_dirs import resolve_run_dirs, write_shared_run_config def train_sumo_sac(log_dir=None, checkpoint_dir=None, run_timestamp=None): @@ -35,8 +35,12 @@ def train_sumo_sac(log_dir=None, checkpoint_dir=None, run_timestamp=None): os.makedirs(log_dir, exist_ok=True) runtime_config = copy.deepcopy(config) runtime_config.setdefault("runtime", {})["output_dir"] = log_dir - with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f: - yaml.dump(runtime_config, f) + write_shared_run_config( + runtime_config, + log_dir=log_dir, + checkpoint_dir=checkpoint_dir, + run_timestamp=run_timestamp, + ) logger = TrainingLogger(log_dir, "sac") env = SUMOEdgeVSLEnvironment(runtime_config) diff --git a/training/train_tcamappo.py b/training/train_tcamappo.py index 1c4e2cc..d07b028 100644 --- a/training/train_tcamappo.py +++ b/training/train_tcamappo.py @@ -17,7 +17,7 @@ from utils.config import get_agent_config, get_training_config from utils.episode_artifacts import save_training_episode_artifacts from utils.logger import TrainingLogger from utils.plot import plot_training_curves -from utils.run_dirs import resolve_run_dirs +from utils.run_dirs import resolve_run_dirs, write_shared_run_config def train_sumo_tcamappo(log_dir=None, checkpoint_dir=None, run_timestamp=None): @@ -37,8 +37,12 @@ def train_sumo_tcamappo(log_dir=None, checkpoint_dir=None, run_timestamp=None): os.makedirs(log_dir, exist_ok=True) runtime_config = copy.deepcopy(config) runtime_config.setdefault("runtime", {})["output_dir"] = log_dir - with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f: - yaml.dump(runtime_config, f) + write_shared_run_config( + runtime_config, + log_dir=log_dir, + checkpoint_dir=checkpoint_dir, + run_timestamp=run_timestamp, + ) logger = TrainingLogger(log_dir, "tcamappo") env = SUMOEdgeVSLEnvironment(runtime_config) diff --git a/training/train_td3.py b/training/train_td3.py index 55f2907..9a77394 100644 --- a/training/train_td3.py +++ b/training/train_td3.py @@ -18,7 +18,7 @@ from utils.config import get_agent_config, get_training_config from utils.episode_artifacts import save_training_episode_artifacts from utils.logger import TrainingLogger from utils.plot import plot_training_curves -from utils.run_dirs import resolve_run_dirs +from utils.run_dirs import resolve_run_dirs, write_shared_run_config def train_sumo_td3( @@ -47,8 +47,12 @@ def train_sumo_td3( os.makedirs(log_dir, exist_ok=True) runtime_config = copy.deepcopy(config) runtime_config.setdefault("runtime", {})["output_dir"] = log_dir - with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f: - yaml.dump(runtime_config, f) + write_shared_run_config( + runtime_config, + log_dir=log_dir, + checkpoint_dir=checkpoint_dir, + run_timestamp=run_timestamp, + ) logger = TrainingLogger(log_dir, model_name) env = SUMOEdgeVSLEnvironment(runtime_config) diff --git a/utils/run_dirs.py b/utils/run_dirs.py index 2a31b1f..d08b595 100644 --- a/utils/run_dirs.py +++ b/utils/run_dirs.py @@ -3,6 +3,13 @@ import os from datetime import datetime from typing import Optional, Tuple +import yaml + + +RUNS_ROOT = "runs" +LEGACY_LOGS_ROOT = os.path.join("logs", "multi-model") +LEGACY_CHECKPOINTS_ROOT = os.path.join("checkpoints", "multi-model") + def add_run_dir_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.add_argument("--log-dir", type=str, default=None, help="Training log output directory.") @@ -11,18 +18,162 @@ def add_run_dir_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser return parser +def default_run_timestamp() -> str: + return datetime.now().strftime("%Y%m%d_%H%M%S") + + +def resolve_run_root(run_timestamp: Optional[str] = None) -> Tuple[str, str]: + timestamp = run_timestamp or default_run_timestamp() + return timestamp, os.path.join(RUNS_ROOT, timestamp) + + def resolve_run_dirs( model_name: str, log_dir: Optional[str] = None, checkpoint_dir: Optional[str] = None, run_timestamp: Optional[str] = None, ) -> Tuple[str, str, str]: - """Resolve output directories from runtime args, falling back to per-model defaults.""" - timestamp = run_timestamp or datetime.now().strftime("%Y%m%d_%H%M%S") + """Resolve output directories from runtime args, defaulting to runs//...""" + timestamp = run_timestamp or default_run_timestamp() - if checkpoint_dir is None: - checkpoint_dir = os.path.join("checkpoints", model_name, timestamp) - if log_dir is None: - log_dir = os.path.join("logs", model_name, timestamp) + if checkpoint_dir is None or log_dir is None: + _, run_root = resolve_run_root(timestamp) + if checkpoint_dir is None: + checkpoint_dir = os.path.join(run_root, "checkpoints", model_name) + if log_dir is None: + log_dir = os.path.join(run_root, "logs", model_name) return timestamp, checkpoint_dir, log_dir + + +def infer_run_root_from_paths( + log_dir: Optional[str] = None, + checkpoint_dir: Optional[str] = None, + run_timestamp: Optional[str] = None, +) -> Optional[str]: + candidate_paths = [path for path in (log_dir, checkpoint_dir) if path] + for path in candidate_paths: + normalized = os.path.normpath(os.path.abspath(path)) + parent = os.path.dirname(normalized) + basename = os.path.basename(parent).lower() + if basename in {"logs", "checkpoints"}: + return os.path.dirname(parent) + grandparent = os.path.basename(os.path.dirname(parent)).lower() + if grandparent in {"logs", "checkpoints"}: + return os.path.dirname(os.path.dirname(parent)) + + if run_timestamp: + return os.path.join(RUNS_ROOT, run_timestamp) + return None + + +def write_shared_run_config( + runtime_config: dict, + *, + run_root: Optional[str] = None, + log_dir: Optional[str] = None, + checkpoint_dir: Optional[str] = None, + run_timestamp: Optional[str] = None, +) -> Optional[str]: + resolved_run_root = run_root or infer_run_root_from_paths( + log_dir=log_dir, + checkpoint_dir=checkpoint_dir, + run_timestamp=run_timestamp, + ) + if not resolved_run_root: + return None + + os.makedirs(resolved_run_root, exist_ok=True) + config_path = os.path.join(resolved_run_root, "config.yaml") + with open(config_path, "w", encoding="utf-8") as f: + yaml.dump(runtime_config, f) + return config_path + + +def find_latest_run_root() -> str: + candidates = [] + + if os.path.isdir(RUNS_ROOT): + candidates.extend( + os.path.join(RUNS_ROOT, item) + for item in os.listdir(RUNS_ROOT) + if os.path.isdir(os.path.join(RUNS_ROOT, item)) + ) + + if candidates: + return max(candidates) + + legacy_candidates = [] + if os.path.isdir(LEGACY_CHECKPOINTS_ROOT): + legacy_candidates.extend( + os.path.join(LEGACY_CHECKPOINTS_ROOT, item) + for item in os.listdir(LEGACY_CHECKPOINTS_ROOT) + if os.path.isdir(os.path.join(LEGACY_CHECKPOINTS_ROOT, item)) + ) + if legacy_candidates: + return max(legacy_candidates) + + raise FileNotFoundError( + f"No run directories found under '{RUNS_ROOT}' or legacy '{LEGACY_CHECKPOINTS_ROOT}'." + ) + + +def find_run_root_by_timestamp(run_timestamp: str) -> str: + new_run_root = os.path.join(RUNS_ROOT, run_timestamp) + if os.path.isdir(new_run_root): + return new_run_root + + legacy_checkpoint_root = os.path.join(LEGACY_CHECKPOINTS_ROOT, run_timestamp) + if os.path.isdir(legacy_checkpoint_root): + return legacy_checkpoint_root + + legacy_log_root = os.path.join(LEGACY_LOGS_ROOT, run_timestamp) + if os.path.isdir(legacy_log_root): + return legacy_log_root + + raise FileNotFoundError( + f"Run '{run_timestamp}' not found under '{RUNS_ROOT}' or legacy multi-model directories." + ) + + +def resolve_checkpoint_root(input_path: Optional[str]) -> str: + if input_path: + abs_path = os.path.abspath(input_path) + else: + abs_path = os.path.abspath(find_latest_run_root()) + + if os.path.isdir(os.path.join(abs_path, "checkpoints")): + return os.path.join(abs_path, "checkpoints") + return abs_path + + +def resolve_log_root(input_path: Optional[str]) -> str: + if input_path: + abs_path = os.path.abspath(input_path) + else: + abs_path = os.path.abspath(find_latest_run_root()) + + if os.path.isdir(os.path.join(abs_path, "logs")): + return os.path.join(abs_path, "logs") + return abs_path + + +def find_shared_config_path( + checkpoint_dir: Optional[str] = None, + fallback_config_path: Optional[str] = None, +) -> Optional[str]: + if checkpoint_dir: + normalized = os.path.normpath(os.path.abspath(checkpoint_dir)) + local_config = os.path.join(normalized, "config.yaml") + if os.path.isfile(local_config): + return local_config + + parent = os.path.dirname(normalized) + if os.path.basename(parent).lower() == "checkpoints": + shared_config = os.path.join(os.path.dirname(parent), "config.yaml") + if os.path.isfile(shared_config): + return shared_config + + if fallback_config_path and os.path.isfile(fallback_config_path): + return fallback_config_path + return None