添加运行时绘制全部对比图功能
This commit is contained in:
parent
27502241ad
commit
1b9f90b5bb
|
|
@ -1,20 +1,25 @@
|
|||
"""一键异步启动全部训练。"""
|
||||
"""Launch all model training jobs plus a live overview plotter."""
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
AGENTS = ["ppo", "appo", "mappo", "dqn", "ddpg", "td3"]
|
||||
PLOT_REFRESH_INTERVAL = 30
|
||||
|
||||
|
||||
def main():
|
||||
run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
processes = {}
|
||||
run_log_root = os.path.join("logs", "multi-model", run_timestamp)
|
||||
run_ckpt_root = os.path.join("checkpoints", "multi-model", run_timestamp)
|
||||
os.makedirs(run_log_root, exist_ok=True)
|
||||
os.makedirs(run_ckpt_root, exist_ok=True)
|
||||
|
||||
processes = {}
|
||||
for agent in AGENTS:
|
||||
print(f"[{datetime.now().strftime('%H:%M:%S')}] 启动 {agent.upper()} 训练...")
|
||||
log_dir = os.path.join("logs", "multi-model", run_timestamp, agent)
|
||||
checkpoint_dir = os.path.join("checkpoints", "multi-model", run_timestamp, agent)
|
||||
log_dir = os.path.join(run_log_root, agent)
|
||||
checkpoint_dir = os.path.join(run_ckpt_root, agent)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
stdout_path = os.path.join(log_dir, "stdout.txt")
|
||||
|
|
@ -37,13 +42,41 @@ def main():
|
|||
processes[agent] = process
|
||||
print(f" PID: {process.pid}")
|
||||
|
||||
plotter_stdout_path = os.path.join(run_log_root, "plotter_stdout.txt")
|
||||
plotter_process = subprocess.Popen(
|
||||
[
|
||||
sys.executable,
|
||||
"-m",
|
||||
"scripts.plot_live_training",
|
||||
"--all-models",
|
||||
"--run",
|
||||
run_timestamp,
|
||||
"--watch",
|
||||
"--interval",
|
||||
str(PLOT_REFRESH_INTERVAL),
|
||||
],
|
||||
stdout=open(plotter_stdout_path, "w", encoding="utf-8"),
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
print(f"[{datetime.now().strftime('%H:%M:%S')}] 启动联训总览绘图...")
|
||||
print(f" PID: {plotter_process.pid}")
|
||||
|
||||
print(f"\n全部 {len(AGENTS)} 个训练已启动,等待完成...")
|
||||
print(f"本次多模型时间戳: {run_timestamp}\n")
|
||||
|
||||
for agent, process in processes.items():
|
||||
process.wait()
|
||||
status = "完成" if process.returncode == 0 else f"失败(code={process.returncode})"
|
||||
print(f"[{agent.upper()}] {status}")
|
||||
try:
|
||||
for agent, process in processes.items():
|
||||
process.wait()
|
||||
status = "完成" if process.returncode == 0 else f"失败(code={process.returncode})"
|
||||
print(f"[{agent.upper()}] {status}")
|
||||
finally:
|
||||
if plotter_process.poll() is None:
|
||||
plotter_process.terminate()
|
||||
try:
|
||||
plotter_process.wait(timeout=10)
|
||||
except subprocess.TimeoutExpired:
|
||||
plotter_process.kill()
|
||||
print("[PLOTTER] 已停止自动绘图进程")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -0,0 +1,566 @@
|
|||
"""Plot live training snapshots from CSV logs."""
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
# Model identifiers in the order they are iterated when loading and plotting.
MODEL_ORDER = ["ppo", "appo", "mappo", "dqn", "ddpg", "td3"]

# Legend/summary label for each model: the identifier in upper case.
MODEL_LABELS = dict(zip(MODEL_ORDER, map(str.upper, MODEL_ORDER)))

# Fixed line colour per model so a model looks the same in every figure.
MODEL_COLORS = dict(
    ppo="#1f77b4",
    appo="#ff7f0e",
    mappo="#2ca02c",
    dqn="#d62728",
    ddpg="#9467bd",
    td3="#17becf",
)
|
||||
|
||||
|
||||
def parse_args():
    """Parse the plotter's command-line options from ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace``.

    Exits through ``parser.error`` when neither ``--model`` nor
    ``--all-models`` is supplied.
    """
    # (flag, add_argument keyword arguments), registered in this order so the
    # generated --help output keeps the same layout.
    option_table = [
        ("--model", {"default": None, "help": "Model name, e.g. ppo/appo/mappo/dqn/ddpg/td3"}),
        ("--all-models", {"action": "store_true", "help": "Plot all models into one overview figure."}),
        (
            "--run",
            {"default": None, "help": "Multi-model run timestamp. Default: latest under logs/multi-model/"},
        ),
        (
            "--log-root",
            {"default": os.path.join("logs", "multi-model"), "help": "Multi-model log root directory."},
        ),
        ("--csv-path", {"default": None, "help": "Direct training CSV path override."}),
        (
            "--output",
            {"default": None, "help": "Detailed snapshot output path. Default: <model_dir>/training_snapshot.png"},
        ),
        (
            "--compare-output",
            {
                "default": None,
                "help": "Run comparison output path. Default: <model_dir>/run_comparison_snapshot.png",
            },
        ),
        (
            "--overview-output",
            {
                "default": None,
                "help": "All-model overview output path. Default: <run_dir>/all_models_training_snapshot.png",
            },
        ),
        ("--window", {"type": int, "default": 20, "help": "Moving average window."}),
        ("--show-ma", {"action": "store_true", "help": "Overlay moving-average curves."}),
        ("--watch", {"action": "store_true", "help": "Continuously refresh plots."}),
        (
            "--interval",
            {"type": int, "default": 30, "help": "Refresh interval in seconds when --watch is enabled."},
        ),
    ]

    cli = argparse.ArgumentParser(description="Plot live training progress from multi-model logs.")
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    parsed = cli.parse_args()
    # One of the two plotting modes has to be selected explicitly.
    if not (parsed.all_models or parsed.model):
        cli.error("Either --model or --all-models must be specified.")
    return parsed
|
||||
|
||||
|
||||
def normalize_model_name(name: str) -> str:
    """Return the canonical (stripped, lower-case) form of a model name.

    Raises:
        ValueError: if the canonical name is not a key of ``MODEL_LABELS``.
    """
    canonical = name.strip().lower()
    if canonical in MODEL_LABELS:
        return canonical
    raise ValueError(f"Unsupported model name: {name}")
|
||||
|
||||
|
||||
def find_latest_run(log_root: str) -> str:
    """Return the name of the newest run directory under ``log_root``.

    "Newest" means the lexicographically greatest sub-directory name.

    Raises:
        FileNotFoundError: if ``log_root`` is not a directory or contains no
            sub-directories.
    """
    if not os.path.isdir(log_root):
        raise FileNotFoundError(f"Log root not found: {log_root}")

    # Plain files are ignored; only directories count as runs.
    run_names = sorted(
        entry
        for entry in os.listdir(log_root)
        if os.path.isdir(os.path.join(log_root, entry))
    )
    if run_names:
        return run_names[-1]
    raise FileNotFoundError(f"No run directories found under: {log_root}")
|
||||
|
||||
|
||||
def resolve_paths(args, model_name: str):
    """Work out the input CSV and output image paths for a single model.

    If ``args.csv_path`` is set it is used as-is: the model directory is the
    CSV's parent directory and the run name is the basename of the directory
    above that. Otherwise the paths are built from ``args.log_root``, the run
    (``args.run`` or the latest run) and ``model_name``.

    Returns:
        ``(run_name, model_dir, csv_path, output_path, compare_output)``.

    Raises:
        FileNotFoundError: if the resolved CSV file does not exist.
    """
    csv_path = args.csv_path
    if csv_path:
        model_dir = os.path.dirname(os.path.abspath(csv_path))
        run_name = os.path.basename(os.path.dirname(model_dir))
    else:
        run_name = args.run if args.run else find_latest_run(args.log_root)
        model_dir = os.path.join(args.log_root, run_name, model_name)
        csv_path = os.path.join(model_dir, f"{model_name}_training_log.csv")

    if not os.path.isfile(csv_path):
        raise FileNotFoundError(f"Training CSV not found: {csv_path}")

    # Explicit output paths win; otherwise images land next to the CSV.
    output_path = args.output
    if not output_path:
        output_path = os.path.join(model_dir, "training_snapshot.png")
    compare_output = args.compare_output
    if not compare_output:
        compare_output = os.path.join(model_dir, "run_comparison_snapshot.png")
    return run_name, model_dir, csv_path, output_path, compare_output
|
||||
|
||||
|
||||
def resolve_run_dir(args):
    """Resolve the run directory and overview image path for all-model mode.

    The run is ``args.run`` when given, otherwise the latest run found under
    ``args.log_root``.

    Returns:
        ``(run_name, run_dir, overview_output)``.

    Raises:
        FileNotFoundError: if the run directory does not exist.
    """
    run_name = args.run if args.run else find_latest_run(args.log_root)
    run_dir = os.path.join(args.log_root, run_name)
    if os.path.isdir(run_dir):
        overview_output = args.overview_output
        if not overview_output:
            overview_output = os.path.join(run_dir, "all_models_training_snapshot.png")
        return run_name, run_dir, overview_output
    raise FileNotFoundError(f"Run directory not found: {run_dir}")
|
||||
|
||||
|
||||
def safe_read_csv(csv_path: str) -> pd.DataFrame:
    """Load a training-log CSV and return it cleaned, one row per episode.

    Malformed lines are skipped while parsing. Known metric columns are
    coerced to numeric (unparseable cells become NaN), rows without a valid
    ``episode`` value are dropped, and when an episode appears more than once
    the row that appears last in the file is kept.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        A non-empty DataFrame sorted by ``episode`` with a fresh 0..n-1 index.

    Raises:
        ValueError: if the CSV has no rows, has no ``episode`` column, or has
            no row with a valid episode number.
    """
    df = pd.read_csv(csv_path, on_bad_lines="skip")
    if df.empty:
        raise ValueError(f"CSV exists but contains no rows: {csv_path}")
    if "episode" not in df.columns:
        # Fail with a readable message instead of a bare KeyError from dropna.
        raise ValueError(f"CSV has no 'episode' column: {csv_path}")

    numeric_columns = [
        "episode", "reward", "throughput", "mean_speed", "speed_std",
        "r_flow", "r_var", "r_brake", "r_penalty", "hard_brakes",
        "policy_loss", "value_loss", "entropy",
    ]
    for column in numeric_columns:
        if column in df.columns:
            df[column] = pd.to_numeric(df[column], errors="coerce")

    df = df.dropna(subset=["episode"])
    if df.empty:
        # Callers index the last row; never hand them an empty frame.
        raise ValueError(f"CSV contains no rows with a valid episode number: {csv_path}")

    # A stable sort keeps file order among equal episodes, so keep="last"
    # really selects the most recently written row for that episode.
    df = df.sort_values("episode", kind="stable").drop_duplicates(subset=["episode"], keep="last")
    return df.reset_index(drop=True)
|
||||
|
||||
|
||||
def moving_average(values: pd.Series, window: int) -> pd.Series:
    """Return the trailing moving average of ``values``.

    The effective window is ``window`` clamped to the range
    ``[1, max(len(values), 1)]``. A window of zero or less therefore behaves
    like a window of 1 instead of making ``rolling`` raise. With
    ``min_periods=1`` the first points average over however many samples are
    available so far.

    Args:
        values: Series to smooth.
        window: Requested window size in samples.

    Returns:
        A Series of the same length and index as ``values``.
    """
    effective_window = max(1, min(window, max(len(values), 1)))
    return values.rolling(window=effective_window, min_periods=1).mean()
|
||||
|
||||
|
||||
def plot_series(
    ax,
    df: pd.DataFrame,
    column: str,
    title: str,
    ylabel: str,
    color: str,
    window: int,
    show_ma: bool,
):
    """Draw one metric column against ``episode`` on ``ax``.

    The axes is switched off when the column is missing or has no numeric
    values. With ``show_ma`` a dashed moving-average curve and a legend are
    added.
    """
    values = pd.to_numeric(df[column], errors="coerce") if column in df.columns else None
    if values is None or values.notna().sum() == 0:
        # Nothing to draw: hide the panel instead of showing empty axes.
        ax.axis("off")
        return

    episodes = df["episode"]
    ax.plot(
        episodes,
        values,
        color=color,
        alpha=0.9,
        linewidth=1.8,
        # The raw curve only needs a legend entry when the MA curve joins it.
        label="Series" if show_ma else None,
    )
    if show_ma:
        smoothed = moving_average(values, window)
        ax.plot(
            episodes,
            smoothed,
            color=color,
            alpha=0.35,
            linewidth=1.0,
            linestyle="--",
            label=f"MA{window}",
        )

    ax.set_title(title)
    ax.set_xlabel("Episode")
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
    if show_ma:
        ax.legend(loc="best")
|
||||
|
||||
|
||||
def build_summary_text(model_name: str, df: pd.DataFrame, run_name: str, csv_path: str) -> str:
    """Build the multi-line text summary shown in the detailed snapshot.

    Every metric is optional: a statistic whose column is absent from ``df``
    is rendered as ``nan`` instead of raising.

    Args:
        model_name: Key into ``MODEL_LABELS``.
        df: Training log; must contain at least one row and an ``episode``
            column.
        run_name: Run identifier printed in the summary.
        csv_path: Source CSV path; printed as an absolute path.

    Returns:
        The summary as a newline-separated string.
    """

    def last20_mean(column: str):
        # Mean over the most recent (up to) 20 logged rows.
        return df[column].tail(20).mean() if column in df else np.nan

    last_row = df.iloc[-1]
    # Guarded like the other metrics so a log without "reward" still renders.
    best_reward = df["reward"].max() if "reward" in df else np.nan
    reward_last20 = last20_mean("reward")
    tp_last20 = last20_mean("throughput")
    speed_last20 = last20_mean("mean_speed")
    brake_last20 = last20_mean("hard_brakes")

    return (
        f"Run: {run_name}\n"
        f"Model: {MODEL_LABELS[model_name]}\n"
        f"Episodes logged: {len(df)}\n"
        f"Latest episode: {int(last_row['episode'])}\n"
        f"Best reward: {best_reward:.2f}\n"
        f"Latest reward: {last_row.get('reward', np.nan):.2f}\n"
        f"Reward last20: {reward_last20:.2f}\n"
        f"Throughput last20: {tp_last20:.1f}\n"
        f"Mean speed last20: {speed_last20:.1f}\n"
        f"Hard brakes last20: {brake_last20:.1f}\n"
        f"\nCSV:\n{os.path.abspath(csv_path)}"
    )
|
||||
|
||||
|
||||
def plot_detailed_snapshot(
    model_name: str,
    df: pd.DataFrame,
    run_name: str,
    csv_path: str,
    output_path: str,
    window: int,
    show_ma: bool,
):
    """Render the 3x3 detailed training snapshot of one model to a file.

    Panels 0-4 show individual metrics, panel 5 overlays the reward
    components that are present, panels 6-7 show the losses and panel 8
    holds the text summary.

    Args:
        model_name: Key into ``MODEL_LABELS``.
        df: Cleaned training log of the model.
        run_name: Run identifier shown in the summary panel.
        csv_path: Source CSV path shown in the summary panel.
        output_path: Image file to write.
        window: Moving-average window used when ``show_ma`` is set.
        show_ma: Whether to overlay moving-average curves.
    """
    fig, axes = plt.subplots(3, 3, figsize=(18, 12))
    # Always release the figure, even if plotting or saving raises; the watch
    # loop would otherwise accumulate one open figure per failed refresh.
    try:
        axes = axes.flatten()

        plot_series(axes[0], df, "reward", "Reward", "Reward", "tab:blue", window, show_ma)
        plot_series(axes[1], df, "throughput", "Throughput", "veh/h", "tab:green", window, show_ma)
        plot_series(axes[2], df, "mean_speed", "Mean Speed", "km/h", "tab:orange", window, show_ma)
        plot_series(axes[3], df, "speed_std", "Speed Std", "km/h", "tab:purple", window, show_ma)
        plot_series(axes[4], df, "hard_brakes", "Hard Brakes", "count", "tab:red", window, show_ma)

        component_colors = {
            "r_flow": "tab:green",
            "r_var": "tab:purple",
            "r_brake": "tab:red",
            "r_penalty": "tab:brown",
        }
        # Keep only components that exist in this log and carry data; a log
        # with just some of the component columns must not raise KeyError.
        components = []
        for col, color in component_colors.items():
            if col not in df.columns:
                continue
            series = pd.to_numeric(df[col], errors="coerce")
            if series.notna().sum() > 0:
                components.append((col, color, series))

        if components:
            for col, color, series in components:
                axes[5].plot(
                    df["episode"],
                    series,
                    linewidth=1.6,
                    alpha=0.75,
                    color=color,
                    label=col,
                )
                if show_ma:
                    axes[5].plot(
                        df["episode"],
                        moving_average(series, window),
                        linewidth=1.0,
                        alpha=0.35,
                        linestyle="--",
                        color=color,
                        label="_nolegend_",
                    )
            axes[5].set_title("Reward Components")
            axes[5].set_xlabel("Episode")
            axes[5].grid(True, alpha=0.3)
            axes[5].legend(loc="best")
        else:
            axes[5].axis("off")

        plot_series(axes[6], df, "policy_loss", "Policy Loss", "loss", "tab:cyan", window, show_ma)
        plot_series(axes[7], df, "value_loss", "Value Loss", "loss", "tab:pink", window, show_ma)

        axes[8].axis("off")
        axes[8].text(
            0.0,
            1.0,
            build_summary_text(model_name, df, run_name, csv_path),
            va="top",
            ha="left",
            family="monospace",
            fontsize=11,
            transform=axes[8].transAxes,
        )

        fig.tight_layout()
        fig.savefig(output_path, dpi=160, bbox_inches="tight")
    finally:
        plt.close(fig)
|
||||
|
||||
|
||||
def load_run_model_logs(log_root: str, run_name: str) -> Dict[str, pd.DataFrame]:
    """Load every readable model log of a run.

    Looks for ``<log_root>/<run_name>/<model>/<model>_training_log.csv`` for
    each model in ``MODEL_ORDER``. Models whose CSV is missing or cannot be
    read are left out of the result.

    Returns:
        Mapping of model name to its cleaned DataFrame.
    """
    loaded = {}
    run_dir = os.path.join(log_root, run_name)
    for name in MODEL_ORDER:
        candidate = os.path.join(run_dir, name, f"{name}_training_log.csv")
        if not os.path.isfile(candidate):
            continue
        try:
            loaded[name] = safe_read_csv(candidate)
        except Exception:
            # Best effort: an unreadable log is skipped rather than aborting
            # the load of the remaining models.
            continue
    return loaded
|
||||
|
||||
|
||||
def plot_waiting_placeholder(output_path: str, run_name: str, message: str):
    """Write a text-only placeholder image for a run that has no data yet.

    Args:
        output_path: Image file to write.
        run_name: Run identifier shown on the first text line.
        message: Status message shown below the run name.
    """
    fig, ax = plt.subplots(figsize=(10, 4))
    # Close the figure even when saving fails (e.g. bad output path), so a
    # repeatedly failing refresh does not pile up open figures.
    try:
        ax.axis("off")
        ax.text(
            0.5,
            0.6,
            f"Run: {run_name}",
            ha="center",
            va="center",
            fontsize=16,
            family="monospace",
            transform=ax.transAxes,
        )
        ax.text(
            0.5,
            0.4,
            message,
            ha="center",
            va="center",
            fontsize=12,
            transform=ax.transAxes,
        )
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        plt.close(fig)
|
||||
|
||||
|
||||
def plot_run_comparison(
    target_model: str,
    run_logs: Dict[str, pd.DataFrame],
    compare_output: str,
    window: int,
    show_ma: bool,
):
    """Render a 2x2 comparison of all models in a run, emphasising one model.

    Does nothing when ``run_logs`` holds one model or fewer, since there is
    nothing to compare against.

    Args:
        target_model: Model drawn with a thicker, fully opaque line.
        run_logs: Mapping of model name to its training log.
        compare_output: Image file to write.
        window: Moving-average window used when ``show_ma`` is set.
        show_ma: Whether to overlay moving-average curves.
    """
    if len(run_logs) <= 1:
        return

    metrics = [
        ("reward", "Reward"),
        ("throughput", "Throughput"),
        ("mean_speed", "Mean Speed"),
        ("speed_std", "Speed Std"),
    ]
    fig, axes = plt.subplots(2, 2, figsize=(16, 10))
    # Release the figure on every exit path; the watch loop swallows
    # exceptions and would otherwise leak one figure per failed refresh.
    try:
        axes = axes.flatten()

        for ax, (column, title) in zip(axes, metrics):
            for model_name in MODEL_ORDER:
                df = run_logs.get(model_name)
                if df is None or column not in df.columns:
                    continue
                series = pd.to_numeric(df[column], errors="coerce")
                if series.notna().sum() == 0:
                    continue
                # The target model stands out; the others are thinner and faded.
                is_target = model_name == target_model
                linewidth = 2.8 if is_target else 1.6
                alpha = 1.0 if is_target else 0.5
                ax.plot(
                    df["episode"],
                    series,
                    linewidth=linewidth + 0.4,
                    alpha=alpha,
                    color=MODEL_COLORS[model_name],
                    label=MODEL_LABELS[model_name],
                )
                if show_ma:
                    ax.plot(
                        df["episode"],
                        moving_average(series, window),
                        linewidth=1.0,
                        linestyle="--",
                        alpha=min(alpha, 0.4),
                        color=MODEL_COLORS[model_name],
                        label="_nolegend_",
                    )
            ax.set_title(title)
            ax.set_xlabel("Episode")
            ax.grid(True, alpha=0.3)

        # A single legend on the first panel covers all four panels.
        axes[0].legend(loc="best")
        fig.tight_layout()
        fig.savefig(compare_output, dpi=160, bbox_inches="tight")
    finally:
        plt.close(fig)
|
||||
|
||||
|
||||
def plot_metric_overlay(
    ax,
    run_logs: Dict[str, pd.DataFrame],
    column: str,
    title: str,
    ylabel: str,
    window: int,
    show_ma: bool,
):
    """Overlay one metric of every model on ``ax``.

    Models that lack the column, or have no numeric values in it, are
    skipped. When no model contributes a curve the axes is switched off.
    """
    drawn = 0
    for name in MODEL_ORDER:
        frame = run_logs.get(name)
        if frame is None or column not in frame.columns:
            continue
        values = pd.to_numeric(frame[column], errors="coerce")
        if values.notna().sum() == 0:
            continue

        line_color = MODEL_COLORS[name]
        episodes = frame["episode"]
        ax.plot(
            episodes,
            values,
            alpha=0.85,
            linewidth=1.8,
            color=line_color,
            label=MODEL_LABELS[name],
        )
        if show_ma:
            # The smoothed curve shares the model colour and stays out of the legend.
            ax.plot(
                episodes,
                moving_average(values, window),
                linewidth=1.0,
                linestyle="--",
                alpha=0.35,
                color=line_color,
                label="_nolegend_",
            )
        drawn += 1

    if drawn == 0:
        ax.axis("off")
        return

    ax.set_title(title)
    ax.set_xlabel("Episode")
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
    ax.legend(loc="best")
|
||||
|
||||
|
||||
def build_overview_summary(run_name: str, run_logs: Dict[str, pd.DataFrame]) -> str:
    """Build the plain-text summary that accompanies the overview figure.

    One line is emitted per model (in ``MODEL_ORDER``) that has a non-empty
    log: the latest episode plus the mean reward and throughput over the last
    (up to) 20 rows. A metric whose column is missing or holds no numeric
    value is reported as ``nan``.

    Args:
        run_name: Run identifier printed in the header.
        run_logs: Mapping of model name to its training log.

    Returns:
        The summary as a newline-separated string.
    """

    def last20_mean(df: pd.DataFrame, column: str):
        # A missing column must yield nan, not crash: pd.to_numeric on the
        # None returned by df.get() gives a scalar without .notna().
        if column not in df.columns:
            return np.nan
        series = pd.to_numeric(df[column], errors="coerce")
        return series.tail(20).mean() if series.notna().any() else np.nan

    lines = [f"Run: {run_name}", f"Models: {len(run_logs)}", ""]
    for model_name in MODEL_ORDER:
        df = run_logs.get(model_name)
        if df is None or df.empty:
            continue
        latest_episode = int(df["episode"].iloc[-1])
        reward_last20 = last20_mean(df, "reward")
        tp_last20 = last20_mean(df, "throughput")
        lines.append(
            f"{MODEL_LABELS[model_name]}: ep={latest_episode}, "
            f"reward20={reward_last20:.1f}, tp20={tp_last20:.0f}"
        )
    return "\n".join(lines)
|
||||
|
||||
|
||||
def plot_all_models_overview(
    run_name: str,
    run_logs: Dict[str, pd.DataFrame],
    output_path: str,
    window: int,
    show_ma: bool,
):
    """Render the 4x3 all-model overview figure and its text summary.

    Each panel overlays one metric for every model. After the image is saved,
    a summary is written next to it as ``<output stem>_summary.txt``.

    Args:
        run_name: Run identifier used in the figure title and summary.
        run_logs: Mapping of model name to its training log.
        output_path: Image file to write.
        window: Moving-average window used when ``show_ma`` is set.
        show_ma: Whether to overlay moving-average curves.

    Raises:
        ValueError: if ``run_logs`` is empty.
    """
    if not run_logs:
        raise ValueError("No readable model logs found for overview plotting.")

    metrics = [
        ("reward", "Reward", "Reward"),
        ("throughput", "Throughput", "veh/h"),
        ("mean_speed", "Mean Speed", "km/h"),
        ("speed_std", "Speed Std", "km/h"),
        ("hard_brakes", "Hard Brakes", "count"),
        ("r_flow", "R_flow", "value"),
        ("r_var", "R_var", "value"),
        ("r_brake", "R_brake", "value"),
        ("r_penalty", "R_penalty", "value"),
        ("policy_loss", "Policy Loss", "loss"),
        ("value_loss", "Value Loss", "loss"),
        ("entropy", "Entropy", "value"),
    ]

    fig, axes = plt.subplots(4, 3, figsize=(22, 16))
    # Release the figure on every exit path; the watch loop swallows
    # exceptions and would otherwise leak one figure per failed refresh.
    try:
        axes = axes.flatten()
        for ax, (column, title, ylabel) in zip(axes, metrics):
            plot_metric_overlay(ax, run_logs, column, title, ylabel, window, show_ma)

        fig.suptitle(f"Multi-Model Training Overview | {run_name}", fontsize=16, y=0.995)
        # Leave a strip at the top so the suptitle does not overlap the panels.
        fig.tight_layout(rect=[0, 0, 1, 0.98])
        fig.savefig(output_path, dpi=170, bbox_inches="tight")
    finally:
        plt.close(fig)

    summary_path = os.path.splitext(output_path)[0] + "_summary.txt"
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write(build_overview_summary(run_name, run_logs))
|
||||
|
||||
|
||||
def run_once(args):
    """Perform a single plot refresh according to the parsed options.

    In all-model mode the overview figure is written, or a placeholder image
    when no model log is readable yet. Otherwise the detailed snapshot of
    ``args.model`` is written, followed by the run comparison. A progress
    line is printed for each output.
    """

    def stamp() -> str:
        # Wall-clock prefix for progress lines.
        return time.strftime('%H:%M:%S')

    if args.all_models:
        run_name, _, overview_output = resolve_run_dir(args)
        run_logs = load_run_model_logs(args.log_root, run_name)
        if run_logs:
            plot_all_models_overview(run_name, run_logs, overview_output, args.window, args.show_ma)
            print(
                f"[{stamp()}] Updated all-model overview | "
                f"models={len(run_logs)} | output={overview_output}"
            )
        else:
            # Training has not produced any readable CSV yet.
            plot_waiting_placeholder(
                overview_output,
                run_name,
                "Waiting for training logs to accumulate...",
            )
            print(
                f"[{stamp()}] Waiting for model logs | "
                f"output={overview_output}"
            )
        return

    model_name = normalize_model_name(args.model)
    run_name, _, csv_path, output_path, compare_output = resolve_paths(args, model_name)
    df = safe_read_csv(csv_path)
    plot_detailed_snapshot(
        model_name=model_name,
        df=df,
        run_name=run_name,
        csv_path=csv_path,
        output_path=output_path,
        window=args.window,
        show_ma=args.show_ma,
    )

    run_logs = load_run_model_logs(args.log_root, run_name)
    plot_run_comparison(
        target_model=model_name,
        run_logs=run_logs,
        compare_output=compare_output,
        window=args.window,
        show_ma=args.show_ma,
    )

    print(
        f"[{stamp()}] Updated {MODEL_LABELS[model_name]} snapshot | "
        f"episodes={len(df)} | output={output_path}"
    )
    # The comparison is only drawn when more than one model log is available.
    if len(run_logs) > 1:
        print(f"[{stamp()}] Updated run comparison | output={compare_output}")
|
||||
|
||||
|
||||
def main():
    """Entry point: plot once, or keep refreshing when ``--watch`` is set.

    In watch mode a failed refresh is reported and skipped so the loop keeps
    running; Ctrl-C ends the loop.
    """
    args = parse_args()
    if not args.watch:
        run_once(args)
        return

    # Never sleep for less than one second between refreshes.
    pause_seconds = max(args.interval, 1)
    try:
        while True:
            try:
                run_once(args)
            except Exception as exc:
                print(f"[{time.strftime('%H:%M:%S')}] Plot update skipped: {exc}")
            time.sleep(pause_seconds)
    except KeyboardInterrupt:
        print("Stopped live plotting.")
|
||||
|
||||
|
||||
# Run the plotter only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
|
||||
Loading…
Reference in New Issue