统一模型训练时存储路径

This commit is contained in:
Zihan Ye 2026-04-10 02:39:05 +08:00
parent 9f3ce242b2
commit cea9d42397
15 changed files with 528 additions and 104 deletions

View File

@ -1,14 +1,18 @@
"""Unified training launcher for one or more models."""
import argparse
import json
import os
import subprocess
import sys
import time
from datetime import datetime
from training.registry import DEFAULT_MODELS, normalize_model_list
from utils.run_dirs import resolve_run_root
PLOT_REFRESH_INTERVAL = 30
PROCESS_POLL_INTERVAL = 5
def parse_args():
@ -69,7 +73,7 @@ def launch_model_process(model_name, run_timestamp, run_log_root, run_ckpt_root)
stdout=stdout_handle,
stderr=subprocess.STDOUT,
)
return process, stdout_handle
return process, stdout_handle, stdout_path
def launch_plotter(run_timestamp, run_log_root, interval):
@ -90,42 +94,166 @@ def launch_plotter(run_timestamp, run_log_root, interval):
stdout=stdout_handle,
stderr=subprocess.STDOUT,
)
return process, stdout_handle
return process, stdout_handle, stdout_path
def now_str():
    """Return the current local wall-clock time formatted as ``HH:MM:SS``."""
    current = datetime.now()
    return current.strftime("%H:%M:%S")
def read_log_tail(log_path, max_lines=40):
    """Return the last ``max_lines`` lines of a log file, newline-stripped.

    Args:
        log_path: Path of the log file; may be ``None`` or empty.
        max_lines: Maximum number of trailing lines to return.

    Returns:
        A list of strings, oldest line first. The list is empty when the
        path is missing, is not a regular file, or ``max_lines`` is not
        positive.
    """
    from collections import deque

    if max_lines <= 0:
        # Slicing with ``[-0:]`` would select the whole file, so a
        # non-positive limit is handled explicitly.
        return []
    if not log_path or not os.path.isfile(log_path):
        return []
    with open(log_path, "r", encoding="utf-8", errors="replace") as f:
        # A bounded deque keeps only the tail while streaming the file, so a
        # large training log is never held in memory in full.
        tail = deque(f, maxlen=max_lines)
    return [line.rstrip("\n") for line in tail]
def write_run_status(run_root, run_timestamp, model_states, plotter_state=None):
    """Write the run status snapshot to ``<run_root>/run_status.json``.

    Args:
        run_root: Directory that receives ``run_status.json``.
        run_timestamp: Identifier of the run, stored under ``run_timestamp``.
        model_states: JSON-serializable mapping stored under ``models``.
        plotter_state: Optional JSON-serializable mapping stored under
            ``plotter``; the key is omitted when this is ``None``.
    """
    payload = {
        "run_timestamp": run_timestamp,
        "updated_at": datetime.now().isoformat(timespec="seconds"),
        "models": model_states,
    }
    if plotter_state is not None:
        payload["plotter"] = plotter_state
    status_path = os.path.join(run_root, "run_status.json")
    # The file is rewritten on every state change. Writing to a sibling temp
    # file and renaming it over the target means a reader never observes a
    # truncated or half-written JSON document.
    tmp_path = f"{status_path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, status_path)
    finally:
        # Only present if the dump or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
def print_failure_excerpt(name, log_path, returncode):
    """Print a failure notice for a child process plus the tail of its log.

    Args:
        name: Display name of the failed process.
        log_path: Path of the captured stdout/stderr log.
        returncode: Exit code reported by the process.
    """
    print(f"[{now_str()}] {name} failed with code={returncode}")
    print(f" Log: {log_path}")
    excerpt = read_log_tail(log_path)
    if not excerpt:
        print(" No readable log content found.")
        return
    print(" Last log lines:")
    for entry in excerpt:
        print(f" {entry}")
def terminate_process(process):
    """Stop a child process, escalating from terminate to kill.

    Args:
        process: A ``subprocess.Popen``-like object, or ``None``. Nothing is
            done when it is ``None`` or has already exited.
    """
    if process is None or process.poll() is not None:
        return
    process.terminate()
    try:
        process.wait(timeout=10)
    except subprocess.TimeoutExpired:
        process.kill()
        # Reap the killed child so it does not linger as a zombie and its
        # return code is populated.
        process.wait()
def monitor_processes(processes, run_root, run_timestamp, plotter_info=None):
    """Poll the training processes until all exit, persisting status as they finish.

    Args:
        processes: Mapping of model name to an info dict with the keys
            ``process``, ``log_dir``, ``stdout_path``, ``checkpoint_dir`` and
            ``started_at``.
        run_root: Directory handed to ``write_run_status``.
        run_timestamp: Run identifier handed to ``write_run_status``.
        plotter_info: Optional info dict for the plotter with the keys
            ``process``, ``stdout_path`` and ``started_at``.

    Returns:
        A ``(model_states, plotter_state)`` tuple; ``plotter_state`` is
        ``None`` when no plotter info was given.
    """

    def _stamp():
        return datetime.now().isoformat(timespec="seconds")

    model_states = {
        name: {
            "status": "running",
            "pid": entry["process"].pid,
            "log_dir": entry["log_dir"],
            "stdout_path": entry["stdout_path"],
            "checkpoint_dir": entry["checkpoint_dir"],
            "started_at": entry["started_at"],
        }
        for name, entry in processes.items()
    }

    plotter_state = None
    if plotter_info is not None:
        plotter_state = {
            "status": "running",
            "pid": plotter_info["process"].pid,
            "stdout_path": plotter_info["stdout_path"],
            "started_at": plotter_info["started_at"],
        }

    write_run_status(run_root, run_timestamp, model_states, plotter_state)

    pending = set(processes.keys())
    while pending:
        # Iterate over a snapshot because finished models are removed from
        # ``pending`` inside the loop.
        for model_name in list(pending):
            entry = processes[model_name]
            exit_code = entry["process"].poll()
            if exit_code is None:
                continue

            pending.remove(model_name)
            succeeded = exit_code == 0
            state = model_states[model_name]
            state["status"] = "completed" if succeeded else "failed"
            state["returncode"] = exit_code
            state["finished_at"] = _stamp()
            write_run_status(run_root, run_timestamp, model_states, plotter_state)

            if succeeded:
                print(f"[{now_str()}] [{model_name.upper()}] completed")
            else:
                print_failure_excerpt(model_name.upper(), entry["stdout_path"], exit_code)

        plotter_still_tracked = (
            plotter_info is not None
            and plotter_state is not None
            and plotter_state["status"] == "running"
        )
        if plotter_still_tracked:
            plotter_code = plotter_info["process"].poll()
            if plotter_code is not None:
                plotter_ok = plotter_code == 0
                plotter_state["status"] = "completed" if plotter_ok else "failed"
                plotter_state["returncode"] = plotter_code
                plotter_state["finished_at"] = _stamp()
                write_run_status(run_root, run_timestamp, model_states, plotter_state)
                if not plotter_ok:
                    print_failure_excerpt("PLOTTER", plotter_info["stdout_path"], plotter_code)

        if pending:
            time.sleep(PROCESS_POLL_INTERVAL)

    return model_states, plotter_state
def main():
args = parse_args()
models = normalize_model_list(args.models)
run_timestamp = args.run_timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")
run_log_root = os.path.join("logs", "multi-model", run_timestamp)
run_ckpt_root = os.path.join("checkpoints", "multi-model", run_timestamp)
run_timestamp, run_root = resolve_run_root(args.run_timestamp or datetime.now().strftime("%Y%m%d_%H%M%S"))
run_log_root = os.path.join(run_root, "logs")
run_ckpt_root = os.path.join(run_root, "checkpoints")
os.makedirs(run_log_root, exist_ok=True)
os.makedirs(run_ckpt_root, exist_ok=True)
processes = {}
handles = []
model_states = None
plotter_state = None
try:
for model_name in models:
print(f"[{datetime.now().strftime('%H:%M:%S')}] Launching {model_name.upper()} training...")
process, stdout_handle = launch_model_process(
process, stdout_handle, stdout_path = launch_model_process(
model_name,
run_timestamp,
run_log_root,
run_ckpt_root,
)
processes[model_name] = process
processes[model_name] = {
"process": process,
"stdout_path": stdout_path,
"stdout_handle": stdout_handle,
"log_dir": os.path.join(run_log_root, model_name),
"checkpoint_dir": os.path.join(run_ckpt_root, model_name),
"started_at": datetime.now().isoformat(timespec="seconds"),
}
handles.append(stdout_handle)
print(f" PID: {process.pid}")
plotter_process = None
plotter_info = None
if not args.no_plotter:
plotter_process, plotter_stdout = launch_plotter(
plotter_process, plotter_stdout, plotter_stdout_path = launch_plotter(
run_timestamp,
run_log_root,
args.interval,
)
handles.append(plotter_stdout)
plotter_info = {
"process": plotter_process,
"stdout_path": plotter_stdout_path,
"stdout_handle": plotter_stdout,
"started_at": datetime.now().isoformat(timespec="seconds"),
}
print(f"[{datetime.now().strftime('%H:%M:%S')}] Launching live plotter...")
print(f" PID: {plotter_process.pid}")
@ -135,17 +263,33 @@ def main():
print(f"Models: {', '.join(models)}")
print()
for model_name, process in processes.items():
process.wait()
status = "completed" if process.returncode == 0 else f"failed (code={process.returncode})"
print(f"[{model_name.upper()}] {status}")
model_states, plotter_state = monitor_processes(
processes,
run_root,
run_timestamp,
plotter_info=plotter_info,
)
failed_models = [
model_name for model_name, state in model_states.items()
if state["status"] != "completed"
]
if failed_models:
print()
print(f"[{now_str()}] Failed models: {', '.join(failed_models)}")
print(f"[{now_str()}] See run status: {os.path.join(run_root, 'run_status.json')}")
sys.exit(1)
finally:
if "plotter_process" in locals() and plotter_process is not None and plotter_process.poll() is None:
plotter_process.terminate()
try:
plotter_process.wait(timeout=10)
except subprocess.TimeoutExpired:
plotter_process.kill()
if "plotter_process" in locals():
terminate_process(plotter_process)
if plotter_state is not None and plotter_state.get("status") == "running":
plotter_state["status"] = "terminated"
plotter_state["finished_at"] = datetime.now().isoformat(timespec="seconds")
for info in processes.values():
terminate_process(info["process"])
if model_states is not None:
write_run_status(run_root, run_timestamp, model_states, plotter_state)
for handle in handles:
handle.close()

View File

@ -32,6 +32,7 @@ from agents.tcamappo_agent import TCAMAPPOAgent
from agents.td3_agent import TD3Agent
from envs.edge_vsl_env import SUMOEdgeVSLEnvironment
from utils.config import get_agent_config
from utils.run_dirs import find_shared_config_path, resolve_checkpoint_root
MODEL_ORDER = ["ppo", "gpro", "appo", "mappo", "tcamappo", "dcmappo", "dqn", "ddpg", "sac", "td3", "sctd3"]
@ -59,7 +60,7 @@ def parse_args():
"--checkpoint-root",
type=str,
default=None,
help="Checkpoint root directory. Default: latest checkpoints/multi-model/<timestamp>.",
help="Checkpoint root or run root. Default: latest under runs/<timestamp>.",
)
parser.add_argument(
"--output-dir",
@ -71,7 +72,7 @@ def parse_args():
"--config",
type=str,
default="config_sumo_vsl.yaml",
help="Fallback config path when checkpoint config.yaml is missing.",
help="Fallback config path when the shared run config is unavailable.",
)
parser.add_argument(
"--models",
@ -119,21 +120,6 @@ def normalize_model_name(name: str) -> str:
return lowered
def find_latest_multi_model_root() -> str:
    """Return the greatest-named run directory under ``checkpoints/multi-model``.

    Raises:
        FileNotFoundError: If the base directory is missing or has no
            subdirectories.
    """
    base_dir = os.path.join("checkpoints", "multi-model")
    if not os.path.isdir(base_dir):
        raise FileNotFoundError(f"Multi-model checkpoint dir not found: {base_dir}")

    run_dirs = []
    for entry in os.listdir(base_dir):
        full_path = os.path.join(base_dir, entry)
        if os.path.isdir(full_path):
            run_dirs.append(full_path)

    if not run_dirs:
        raise FileNotFoundError(f"No run directories found under: {base_dir}")
    # Paths are compared as strings, so the lexicographically largest wins.
    return max(run_dirs)
def discover_model_dirs(checkpoint_root: str, requested_models: List[str] = None) -> Dict[str, str]:
checkpoint_root = os.path.abspath(checkpoint_root)
requested = [normalize_model_name(m) for m in requested_models] if requested_models else None
@ -149,11 +135,23 @@ def discover_model_dirs(checkpoint_root: str, requested_models: List[str] = None
return discovered
return {k: v for k, v in discovered.items() if k in requested}
base_name = os.path.basename(checkpoint_root).lower()
parent_name = os.path.basename(os.path.dirname(checkpoint_root)).lower()
grandparent_name = os.path.basename(os.path.dirname(os.path.dirname(checkpoint_root))).lower()
if base_name in MODEL_LABELS and (
parent_name == "checkpoints" or grandparent_name in {"checkpoints", "multi-model"}
):
model_name = base_name
if requested is not None and model_name not in requested:
return {}
return {model_name: checkpoint_root}
if os.path.isfile(os.path.join(checkpoint_root, "config.yaml")):
model_name = None
parent_name = os.path.basename(os.path.dirname(checkpoint_root)).lower()
if parent_name in MODEL_LABELS:
model_name = parent_name
elif base_name in MODEL_LABELS:
model_name = base_name
elif requested and len(requested) == 1:
model_name = requested[0]
if model_name is None:
@ -166,14 +164,37 @@ def discover_model_dirs(checkpoint_root: str, requested_models: List[str] = None
raise FileNotFoundError(f"No model checkpoint directories found in: {checkpoint_root}")
def infer_eval_run_name(checkpoint_root: str) -> str:
    """Derive an evaluation run name from the location of a checkpoint root.

    The last three components of the absolute path are inspected, and the
    first matching rule below decides the name.

    Args:
        checkpoint_root: Checkpoint directory path (relative or absolute).

    Returns:
        The run name derived from the path components.
    """
    root = os.path.abspath(checkpoint_root)
    parent_dir, leaf = os.path.split(root)
    grandparent_dir, parent = os.path.split(parent_dir)
    grandparent = os.path.basename(grandparent_dir)

    # .../<run>/checkpoints
    if leaf == "checkpoints":
        return parent
    # .../<run>/checkpoints/<model>
    if parent == "checkpoints":
        return f"{leaf}_{grandparent}"
    # .../multi-model/<run>/<model>
    if grandparent == "multi-model":
        return f"{leaf}_{parent}"
    # .../<model>/<run>; a "multi-model" parent keeps the bare leaf name.
    if parent != "multi-model" and parent in MODEL_ORDER:
        return f"{parent}_{leaf}"
    return leaf
def resolve_eval_output_dir(output_dir: str, checkpoint_root: str) -> str:
if output_dir:
return output_dir
run_name = os.path.basename(os.path.abspath(checkpoint_root))
parent_name = os.path.basename(os.path.dirname(os.path.abspath(checkpoint_root)))
if parent_name in ("multi-model", *MODEL_ORDER):
run_name = f"{parent_name}_{run_name}"
run_name = infer_eval_run_name(checkpoint_root)
return os.path.join("results", "evaluations", run_name)
@ -181,7 +202,7 @@ def load_config_for_checkpoint(checkpoint_dir: Optional[str], fallback_config_pa
with open(fallback_config_path, "r", encoding="utf-8") as f:
base_config = yaml.safe_load(f)
checkpoint_config = os.path.join(checkpoint_dir, "config.yaml") if checkpoint_dir else ""
checkpoint_config = find_shared_config_path(checkpoint_dir, fallback_config_path)
if checkpoint_config and os.path.isfile(checkpoint_config):
with open(checkpoint_config, "r", encoding="utf-8") as f:
checkpoint_loaded = yaml.safe_load(f)
@ -852,7 +873,7 @@ def print_summary(summary_df: pd.DataFrame, output_dir: str):
def main():
args = parse_args()
checkpoint_root = args.checkpoint_root or find_latest_multi_model_root()
checkpoint_root = resolve_checkpoint_root(args.checkpoint_root)
model_dirs = discover_model_dirs(checkpoint_root, args.models)
if not model_dirs:
raise FileNotFoundError("No models matched the requested selection.")

View File

@ -16,6 +16,8 @@ import pandas as pd
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from utils.run_dirs import find_latest_run_root, find_run_root_by_timestamp
MODEL_ORDER = ["ppo", "gpro", "appo", "mappo", "tcamappo", "dcmappo", "dqn", "ddpg", "sac", "td3", "sctd3"]
MODEL_LABELS = {
@ -47,7 +49,7 @@ MODEL_COLORS = {
def parse_args():
parser = argparse.ArgumentParser(description="Plot live training progress from multi-model logs.")
parser = argparse.ArgumentParser(description="Plot live training progress from run logs.")
parser.add_argument("--model", default=None, help="Model name, e.g. ppo/gpro/appo/mappo/tcamappo/dcmappo/dqn/ddpg/sac/td3/sctd3")
parser.add_argument(
"--all-models",
@ -57,12 +59,12 @@ def parse_args():
parser.add_argument(
"--run",
default=None,
help="Multi-model run timestamp. Default: latest under logs/multi-model/",
help="Run timestamp. Default: latest under runs/<timestamp>/logs.",
)
parser.add_argument(
"--log-root",
default=os.path.join("logs", "multi-model"),
help="Multi-model log root directory.",
default=None,
help="Log root directory or run root. Default: latest runs/<timestamp>/logs.",
)
parser.add_argument(
"--csv-path",
@ -131,15 +133,41 @@ def find_latest_run(log_root: str) -> str:
return max(candidates)
def resolve_paths(args, model_name: str):
def resolve_run_logs_dir(log_root: Optional[str], run_name: Optional[str]) -> tuple[str, str]:
    """Resolve the run name and the directory holding its per-model logs.

    Args:
        log_root: Explicit log location, or ``None`` to discover the run via
            ``find_run_root_by_timestamp`` / ``find_latest_run_root``. An
            explicit value may be a ``logs`` directory, a run root that
            contains ``logs``, or a parent directory of runs.
        run_name: Run to select; when empty the latest run is used.

    Returns:
        A ``(run_name, run_logs_dir)`` tuple.

    Raises:
        FileNotFoundError: If an explicit ``log_root`` is a parent of runs
            and the chosen run directory does not exist under it.
    """
    if log_root is None:
        run_root = find_run_root_by_timestamp(run_name) if run_name else find_latest_run_root()
        logs_dir = os.path.join(run_root, "logs")
        if not os.path.isdir(logs_dir):
            # No ``logs`` subdirectory: the run root itself is used as the
            # logs directory.
            logs_dir = run_root
        # Return here: the run root is already resolved, and the heuristics
        # below would re-interpret a root without ``logs`` as a parent of
        # runs, picking a wrong directory or raising FileNotFoundError.
        return os.path.basename(os.path.normpath(run_root)), logs_dir

    resolved_log_root = os.path.abspath(log_root)

    # Case 1: the caller pointed directly at ``<run>/logs``.
    if os.path.basename(resolved_log_root).lower() == "logs":
        inferred_run_name = os.path.basename(os.path.dirname(resolved_log_root))
        return inferred_run_name, resolved_log_root

    # Case 2: the caller pointed at a run root that contains ``logs``.
    if os.path.isdir(os.path.join(resolved_log_root, "logs")):
        inferred_run_name = os.path.basename(resolved_log_root)
        return inferred_run_name, os.path.join(resolved_log_root, "logs")

    # Case 3: the caller pointed at a parent directory of runs.
    chosen_run_name = run_name or find_latest_run(resolved_log_root)
    candidate_run_dir = os.path.join(resolved_log_root, chosen_run_name)
    candidate_logs_dir = os.path.join(candidate_run_dir, "logs")
    if os.path.isdir(candidate_logs_dir):
        return chosen_run_name, candidate_logs_dir
    if os.path.isdir(candidate_run_dir):
        return chosen_run_name, candidate_run_dir
    raise FileNotFoundError(f"Run directory not found: {candidate_run_dir}")
def resolve_paths(args, model_name: str, run_name: str, run_logs_dir: str):
if args.csv_path:
csv_path = args.csv_path
model_dir = os.path.dirname(os.path.abspath(csv_path))
run_dir = os.path.dirname(model_dir)
run_name = os.path.basename(run_dir)
else:
run_name = args.run or find_latest_run(args.log_root)
model_dir = os.path.join(args.log_root, run_name, model_name)
model_dir = os.path.join(run_logs_dir, model_name)
csv_path = os.path.join(model_dir, f"{model_name}_training_log.csv")
if not os.path.isfile(csv_path):
@ -150,11 +178,12 @@ def resolve_paths(args, model_name: str):
return run_name, model_dir, csv_path, output_path, compare_output
def resolve_run_dir(args):
run_name = args.run or find_latest_run(args.log_root)
run_dir = os.path.join(args.log_root, run_name)
if not os.path.isdir(run_dir):
raise FileNotFoundError(f"Run directory not found: {run_dir}")
def resolve_run_dir(args, run_name: str, run_logs_dir: str):
    """Return the run directory and overview image path for a resolved run.

    Args:
        args: Parsed CLI namespace; only ``overview_output`` is read.
        run_name: Run name, passed through unchanged.
        run_logs_dir: Directory holding the per-model logs.

    Returns:
        A ``(run_name, run_dir, overview_output)`` tuple.
    """
    run_dir = run_logs_dir
    if os.path.basename(run_logs_dir).lower() == "logs":
        # Step up from ``<run>/logs`` to ``<run>``.
        run_dir = os.path.dirname(run_logs_dir)

    overview_output = args.overview_output
    if not overview_output:
        overview_output = os.path.join(run_dir, "all_models_training_snapshot.png")
    return run_name, run_dir, overview_output
@ -315,10 +344,10 @@ def plot_detailed_snapshot(
plt.close()
def load_run_model_logs(log_root: str, run_name: str) -> Dict[str, pd.DataFrame]:
def load_run_model_logs(run_logs_dir: str) -> Dict[str, pd.DataFrame]:
run_logs = {}
for model_name in MODEL_ORDER:
csv_path = os.path.join(log_root, run_name, model_name, f"{model_name}_training_log.csv")
csv_path = os.path.join(run_logs_dir, model_name, f"{model_name}_training_log.csv")
if os.path.isfile(csv_path):
try:
run_logs[model_name] = safe_read_csv(csv_path)
@ -506,7 +535,7 @@ def plot_all_models_overview(
for ax, (column, title, ylabel) in zip(axes, metrics):
plot_metric_overlay(ax, run_logs, column, title, ylabel, window, show_ma)
fig.suptitle(f"Multi-Model Training Overview | {run_name}", fontsize=16, y=0.995)
fig.suptitle(f"Training Overview | {run_name}", fontsize=16, y=0.995)
plt.tight_layout(rect=[0, 0, 1, 0.98])
plt.savefig(output_path, dpi=170, bbox_inches="tight")
plt.close()
@ -517,9 +546,10 @@ def plot_all_models_overview(
def run_once(args):
run_name, run_logs_dir = resolve_run_logs_dir(args.log_root, args.run)
if args.all_models:
run_name, _, overview_output = resolve_run_dir(args)
run_logs = load_run_model_logs(args.log_root, run_name)
_, _, overview_output = resolve_run_dir(args, run_name, run_logs_dir)
run_logs = load_run_model_logs(run_logs_dir)
if not run_logs:
plot_waiting_placeholder(
overview_output,
@ -539,7 +569,9 @@ def run_once(args):
return
model_name = normalize_model_name(args.model)
run_name, _, csv_path, output_path, compare_output = resolve_paths(args, model_name)
run_name, _, csv_path, output_path, compare_output = resolve_paths(
args, model_name, run_name, run_logs_dir
)
df = safe_read_csv(csv_path)
plot_detailed_snapshot(
model_name=model_name,
@ -551,7 +583,7 @@ def run_once(args):
show_ma=args.show_ma,
)
run_logs = load_run_model_logs(args.log_root, run_name)
run_logs = load_run_model_logs(run_logs_dir)
plot_run_comparison(
target_model=model_name,
run_logs=run_logs,

View File

@ -1,10 +1,37 @@
"""Unified worker entrypoint for launching one model training job."""
import argparse
import json
import os
import sys
import traceback
from datetime import datetime
from training.registry import TRAINERS, normalize_model_name
from utils.run_dirs import add_run_dir_args
def persist_training_error(model_name: str, log_dir: str, exc: Exception) -> None:
    """Write a failed training run's traceback and a JSON summary to ``log_dir``.

    Two files are written: ``error_traceback.txt`` with the formatted
    traceback and ``error_summary.json`` with the model name, failure time,
    exception type, message and the absolute traceback path.

    Args:
        model_name: Name of the model whose training failed.
        log_dir: Target directory; nothing is written when it is empty.
        exc: The exception that ended training.
    """
    if not log_dir:
        return
    os.makedirs(log_dir, exist_ok=True)

    # Format the exception that was passed in rather than the ambient one.
    # ``traceback.format_exc()`` yields "NoneType: None" when this function
    # is called outside an active ``except`` block.
    tb_text = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
    error_txt_path = os.path.join(log_dir, "error_traceback.txt")
    with open(error_txt_path, "w", encoding="utf-8") as f:
        f.write(tb_text)

    error_json_path = os.path.join(log_dir, "error_summary.json")
    payload = {
        "model": model_name,
        "failed_at": datetime.now().isoformat(timespec="seconds"),
        "exception_type": type(exc).__name__,
        "message": str(exc),
        "traceback_file": os.path.abspath(error_txt_path),
    }
    with open(error_json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)
def main():
parser = add_run_dir_args(argparse.ArgumentParser())
parser.add_argument("--model", required=True, help="Model name to train.")
@ -12,11 +39,20 @@ def main():
model_name = normalize_model_name(args.model)
trainer = TRAINERS[model_name]
trainer(
log_dir=args.log_dir,
checkpoint_dir=args.checkpoint_dir,
run_timestamp=args.run_timestamp,
)
try:
trainer(
log_dir=args.log_dir,
checkpoint_dir=args.checkpoint_dir,
run_timestamp=args.run_timestamp,
)
except Exception as exc:
persist_training_error(model_name, args.log_dir, exc)
print(
f"[{datetime.now().strftime('%H:%M:%S')}] {model_name.upper()} training failed: {type(exc).__name__}: {exc}",
file=sys.stderr,
)
traceback.print_exc()
raise
if __name__ == "__main__":

View File

@ -20,7 +20,7 @@ from utils.config import get_agent_config, get_training_config
from utils.episode_artifacts import save_training_episode_artifacts
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves
from utils.run_dirs import resolve_run_dirs
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
def train_sumo_appo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
@ -42,8 +42,12 @@ def train_sumo_appo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
os.makedirs(log_dir, exist_ok=True)
runtime_config = copy.deepcopy(config)
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
yaml.dump(runtime_config, f)
write_shared_run_config(
runtime_config,
log_dir=log_dir,
checkpoint_dir=checkpoint_dir,
run_timestamp=run_timestamp,
)
logger = TrainingLogger(log_dir, "appo")
env = SUMOEdgeVSLEnvironment(runtime_config)

View File

@ -16,7 +16,7 @@ from utils.config import get_agent_config, get_training_config
from utils.episode_artifacts import save_training_episode_artifacts
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves
from utils.run_dirs import resolve_run_dirs
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
def train_sumo_dcmappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
@ -36,8 +36,12 @@ def train_sumo_dcmappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
os.makedirs(log_dir, exist_ok=True)
runtime_config = copy.deepcopy(config)
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
yaml.dump(runtime_config, f)
write_shared_run_config(
runtime_config,
log_dir=log_dir,
checkpoint_dir=checkpoint_dir,
run_timestamp=run_timestamp,
)
logger = TrainingLogger(log_dir, "dcmappo")
env = SUMOEdgeVSLEnvironment(runtime_config)

View File

@ -17,7 +17,7 @@ from utils.config import get_agent_config, get_training_config
from utils.episode_artifacts import save_training_episode_artifacts
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves
from utils.run_dirs import resolve_run_dirs
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
def train_sumo_ddpg(log_dir=None, checkpoint_dir=None, run_timestamp=None):
@ -38,8 +38,12 @@ def train_sumo_ddpg(log_dir=None, checkpoint_dir=None, run_timestamp=None):
os.makedirs(log_dir, exist_ok=True)
runtime_config = copy.deepcopy(config)
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
yaml.dump(runtime_config, f)
write_shared_run_config(
runtime_config,
log_dir=log_dir,
checkpoint_dir=checkpoint_dir,
run_timestamp=run_timestamp,
)
logger = TrainingLogger(log_dir, "ddpg")
env = SUMOEdgeVSLEnvironment(runtime_config)

View File

@ -18,7 +18,7 @@ from utils.config import get_agent_config, get_training_config
from utils.episode_artifacts import save_training_episode_artifacts
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves
from utils.run_dirs import resolve_run_dirs
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
def train_sumo_dqn(log_dir=None, checkpoint_dir=None, run_timestamp=None):
@ -39,8 +39,12 @@ def train_sumo_dqn(log_dir=None, checkpoint_dir=None, run_timestamp=None):
os.makedirs(log_dir, exist_ok=True)
runtime_config = copy.deepcopy(config)
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
yaml.dump(runtime_config, f)
write_shared_run_config(
runtime_config,
log_dir=log_dir,
checkpoint_dir=checkpoint_dir,
run_timestamp=run_timestamp,
)
logger = TrainingLogger(log_dir, "dqn")
env = SUMOEdgeVSLEnvironment(runtime_config)

View File

@ -14,7 +14,7 @@ from utils.config import get_agent_config, get_training_config
from utils.episode_artifacts import save_training_episode_artifacts
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves
from utils.run_dirs import resolve_run_dirs
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
def train_sumo_gpro(log_dir=None, checkpoint_dir=None, run_timestamp=None):
@ -35,8 +35,12 @@ def train_sumo_gpro(log_dir=None, checkpoint_dir=None, run_timestamp=None):
os.makedirs(log_dir, exist_ok=True)
runtime_config = copy.deepcopy(config)
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
yaml.dump(runtime_config, f)
write_shared_run_config(
runtime_config,
log_dir=log_dir,
checkpoint_dir=checkpoint_dir,
run_timestamp=run_timestamp,
)
logger = TrainingLogger(log_dir, "gpro")
env = SUMOEdgeVSLEnvironment(runtime_config)

View File

@ -17,7 +17,7 @@ from utils.config import get_agent_config, get_training_config
from utils.episode_artifacts import save_training_episode_artifacts
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves
from utils.run_dirs import resolve_run_dirs
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
def train_sumo_mappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
@ -37,8 +37,12 @@ def train_sumo_mappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
os.makedirs(log_dir, exist_ok=True)
runtime_config = copy.deepcopy(config)
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
yaml.dump(runtime_config, f)
write_shared_run_config(
runtime_config,
log_dir=log_dir,
checkpoint_dir=checkpoint_dir,
run_timestamp=run_timestamp,
)
logger = TrainingLogger(log_dir, "mappo")
env = SUMOEdgeVSLEnvironment(runtime_config)

View File

@ -20,7 +20,7 @@ from utils.config import get_agent_config, get_training_config
from utils.episode_artifacts import save_training_episode_artifacts
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves
from utils.run_dirs import resolve_run_dirs
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
def train_sumo_ppo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
@ -43,8 +43,12 @@ def train_sumo_ppo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
os.makedirs(log_dir, exist_ok=True)
runtime_config = copy.deepcopy(config)
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
yaml.dump(runtime_config, f)
write_shared_run_config(
runtime_config,
log_dir=log_dir,
checkpoint_dir=checkpoint_dir,
run_timestamp=run_timestamp,
)
logger = TrainingLogger(log_dir, "ppo")

View File

@ -14,7 +14,7 @@ from utils.config import get_agent_config, get_training_config
from utils.episode_artifacts import save_training_episode_artifacts
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves
from utils.run_dirs import resolve_run_dirs
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
def train_sumo_sac(log_dir=None, checkpoint_dir=None, run_timestamp=None):
@ -35,8 +35,12 @@ def train_sumo_sac(log_dir=None, checkpoint_dir=None, run_timestamp=None):
os.makedirs(log_dir, exist_ok=True)
runtime_config = copy.deepcopy(config)
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
yaml.dump(runtime_config, f)
write_shared_run_config(
runtime_config,
log_dir=log_dir,
checkpoint_dir=checkpoint_dir,
run_timestamp=run_timestamp,
)
logger = TrainingLogger(log_dir, "sac")
env = SUMOEdgeVSLEnvironment(runtime_config)

View File

@ -17,7 +17,7 @@ from utils.config import get_agent_config, get_training_config
from utils.episode_artifacts import save_training_episode_artifacts
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves
from utils.run_dirs import resolve_run_dirs
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
def train_sumo_tcamappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
@ -37,8 +37,12 @@ def train_sumo_tcamappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
os.makedirs(log_dir, exist_ok=True)
runtime_config = copy.deepcopy(config)
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
yaml.dump(runtime_config, f)
write_shared_run_config(
runtime_config,
log_dir=log_dir,
checkpoint_dir=checkpoint_dir,
run_timestamp=run_timestamp,
)
logger = TrainingLogger(log_dir, "tcamappo")
env = SUMOEdgeVSLEnvironment(runtime_config)

View File

@ -18,7 +18,7 @@ from utils.config import get_agent_config, get_training_config
from utils.episode_artifacts import save_training_episode_artifacts
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves
from utils.run_dirs import resolve_run_dirs
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
def train_sumo_td3(
@ -47,8 +47,12 @@ def train_sumo_td3(
os.makedirs(log_dir, exist_ok=True)
runtime_config = copy.deepcopy(config)
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
yaml.dump(runtime_config, f)
write_shared_run_config(
runtime_config,
log_dir=log_dir,
checkpoint_dir=checkpoint_dir,
run_timestamp=run_timestamp,
)
logger = TrainingLogger(log_dir, model_name)
env = SUMOEdgeVSLEnvironment(runtime_config)

View File

@ -3,6 +3,13 @@ import os
from datetime import datetime
from typing import Optional, Tuple
import yaml
RUNS_ROOT = "runs"
LEGACY_LOGS_ROOT = os.path.join("logs", "multi-model")
LEGACY_CHECKPOINTS_ROOT = os.path.join("checkpoints", "multi-model")
def add_run_dir_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument("--log-dir", type=str, default=None, help="Training log output directory.")
@ -11,18 +18,162 @@ def add_run_dir_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser
return parser
def default_run_timestamp() -> str:
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
    current = datetime.now()
    return current.strftime("%Y%m%d_%H%M%S")
def resolve_run_root(run_timestamp: Optional[str] = None) -> Tuple[str, str]:
    """Return ``(timestamp, run_root)`` for a run under ``RUNS_ROOT``.

    Args:
        run_timestamp: Run identifier; a fresh one from
            ``default_run_timestamp`` is used when it is empty.

    Returns:
        The timestamp actually used and ``RUNS_ROOT`` joined with it.
    """
    if not run_timestamp:
        run_timestamp = default_run_timestamp()
    return run_timestamp, os.path.join(RUNS_ROOT, run_timestamp)
def resolve_run_dirs(
model_name: str,
log_dir: Optional[str] = None,
checkpoint_dir: Optional[str] = None,
run_timestamp: Optional[str] = None,
) -> Tuple[str, str, str]:
"""Resolve output directories from runtime args, falling back to per-model defaults."""
timestamp = run_timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")
"""Resolve output directories from runtime args, defaulting to runs/<timestamp>/..."""
timestamp = run_timestamp or default_run_timestamp()
if checkpoint_dir is None:
checkpoint_dir = os.path.join("checkpoints", model_name, timestamp)
if log_dir is None:
log_dir = os.path.join("logs", model_name, timestamp)
if checkpoint_dir is None or log_dir is None:
_, run_root = resolve_run_root(timestamp)
if checkpoint_dir is None:
checkpoint_dir = os.path.join(run_root, "checkpoints", model_name)
if log_dir is None:
log_dir = os.path.join(run_root, "logs", model_name)
return timestamp, checkpoint_dir, log_dir
def infer_run_root_from_paths(
    log_dir: Optional[str] = None,
    checkpoint_dir: Optional[str] = None,
    run_timestamp: Optional[str] = None,
) -> Optional[str]:
    """Infer the run root from a model's log or checkpoint directory.

    ``log_dir`` is examined before ``checkpoint_dir``. A path matches when
    its parent or its grandparent directory is named ``logs`` or
    ``checkpoints`` (case-insensitive); the directory containing that
    ``logs``/``checkpoints`` directory is returned.

    Args:
        log_dir: Per-model log directory, if known.
        checkpoint_dir: Per-model checkpoint directory, if known.
        run_timestamp: Fallback used to build ``RUNS_ROOT/<timestamp>`` when
            neither path matches.

    Returns:
        The inferred run root, or ``None`` when nothing can be inferred.
    """
    layout_dirs = {"logs", "checkpoints"}
    for raw_path in (log_dir, checkpoint_dir):
        if not raw_path:
            continue
        one_up = os.path.dirname(os.path.normpath(os.path.abspath(raw_path)))
        two_up = os.path.dirname(one_up)
        # <root>/<logs|checkpoints>/<leaf>
        if os.path.basename(one_up).lower() in layout_dirs:
            return two_up
        # <root>/<logs|checkpoints>/<middle>/<leaf>
        if os.path.basename(two_up).lower() in layout_dirs:
            return os.path.dirname(two_up)

    return os.path.join(RUNS_ROOT, run_timestamp) if run_timestamp else None
def write_shared_run_config(
    runtime_config: dict,
    *,
    run_root: Optional[str] = None,
    log_dir: Optional[str] = None,
    checkpoint_dir: Optional[str] = None,
    run_timestamp: Optional[str] = None,
) -> Optional[str]:
    """Write the runtime config to ``<run_root>/config.yaml``.

    Args:
        runtime_config: Configuration mapping to dump as YAML.
        run_root: Explicit run root. When omitted it is inferred with
            ``infer_run_root_from_paths`` from the remaining arguments.
        log_dir: Model log directory used for inference.
        checkpoint_dir: Model checkpoint directory used for inference.
        run_timestamp: Run identifier used as the inference fallback.

    Returns:
        The path of the written config file, or ``None`` when no run root
        could be determined.
    """
    resolved_run_root = run_root or infer_run_root_from_paths(
        log_dir=log_dir,
        checkpoint_dir=checkpoint_dir,
        run_timestamp=run_timestamp,
    )
    if not resolved_run_root:
        return None
    os.makedirs(resolved_run_root, exist_ok=True)
    config_path = os.path.join(resolved_run_root, "config.yaml")
    # Trainers of the same run all resolve to this one path and may run as
    # separate processes at the same time. Each writer dumps to its own
    # temp file and renames it into place, so the shared file is never left
    # interleaved or half-written.
    tmp_path = f"{config_path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            yaml.dump(runtime_config, f)
        os.replace(tmp_path, config_path)
    finally:
        # Only present if the dump or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return config_path
def find_latest_run_root() -> str:
    """Return the greatest-named run directory, preferring the new layout.

    ``RUNS_ROOT`` is searched first; ``LEGACY_CHECKPOINTS_ROOT`` is consulted
    only when it yields no subdirectories.

    Raises:
        FileNotFoundError: If neither location contains a subdirectory.
    """

    def _subdirectories(base):
        if not os.path.isdir(base):
            return []
        return [
            os.path.join(base, entry)
            for entry in os.listdir(base)
            if os.path.isdir(os.path.join(base, entry))
        ]

    for base in (RUNS_ROOT, LEGACY_CHECKPOINTS_ROOT):
        found = _subdirectories(base)
        if found:
            # Paths are compared as strings, so the lexicographically
            # largest wins.
            return max(found)

    raise FileNotFoundError(
        f"No run directories found under '{RUNS_ROOT}' or legacy '{LEGACY_CHECKPOINTS_ROOT}'."
    )
def find_run_root_by_timestamp(run_timestamp: str) -> str:
    """Locate the directory of a run given its timestamp.

    Candidates are tried in order: ``RUNS_ROOT``, then
    ``LEGACY_CHECKPOINTS_ROOT``, then ``LEGACY_LOGS_ROOT``.

    Args:
        run_timestamp: Name of the run directory to look for.

    Returns:
        The first candidate path that exists as a directory.

    Raises:
        FileNotFoundError: If no candidate directory exists.
    """
    for base in (RUNS_ROOT, LEGACY_CHECKPOINTS_ROOT, LEGACY_LOGS_ROOT):
        candidate = os.path.join(base, run_timestamp)
        if os.path.isdir(candidate):
            return candidate

    raise FileNotFoundError(
        f"Run '{run_timestamp}' not found under '{RUNS_ROOT}' or legacy multi-model directories."
    )
def resolve_checkpoint_root(input_path: Optional[str]) -> str:
    """Resolve a checkpoint root from a user path or the latest run.

    Args:
        input_path: Checkpoint root or run root; the latest run root from
            ``find_latest_run_root`` is used when it is empty.

    Returns:
        The absolute ``checkpoints`` subdirectory when one exists under the
        chosen path, otherwise the absolute path itself.
    """
    base = os.path.abspath(input_path if input_path else find_latest_run_root())
    nested = os.path.join(base, "checkpoints")
    return nested if os.path.isdir(nested) else base
def resolve_log_root(input_path: Optional[str]) -> str:
    """Resolve a log root from a user path or the latest run.

    Args:
        input_path: Log root or run root; the latest run root from
            ``find_latest_run_root`` is used when it is empty.

    Returns:
        The absolute ``logs`` subdirectory when one exists under the chosen
        path, otherwise the absolute path itself.
    """
    base = os.path.abspath(input_path if input_path else find_latest_run_root())
    nested = os.path.join(base, "logs")
    return nested if os.path.isdir(nested) else base
def find_shared_config_path(
    checkpoint_dir: Optional[str] = None,
    fallback_config_path: Optional[str] = None,
) -> Optional[str]:
    """Locate the config file that belongs to a checkpoint directory.

    Lookup order: ``config.yaml`` inside the checkpoint directory, then
    ``config.yaml`` in the run root (only when the checkpoint directory's
    parent is named ``checkpoints``), then the fallback path.

    Args:
        checkpoint_dir: Model checkpoint directory, if known.
        fallback_config_path: Config path used when nothing is found near
            the checkpoint.

    Returns:
        The first existing config path, or ``None`` when none exists.
    """
    if checkpoint_dir:
        ckpt = os.path.normpath(os.path.abspath(checkpoint_dir))
        candidates = [os.path.join(ckpt, "config.yaml")]
        container = os.path.dirname(ckpt)
        if os.path.basename(container).lower() == "checkpoints":
            candidates.append(os.path.join(os.path.dirname(container), "config.yaml"))
        for candidate in candidates:
            if os.path.isfile(candidate):
                return candidate

    if fallback_config_path and os.path.isfile(fallback_config_path):
        return fallback_config_path
    return None