更新绘图

This commit is contained in:
Zihan Ye 2026-04-17 07:17:21 +08:00
parent 288f67ae50
commit ebfd2762cf
1 changed files with 43 additions and 17 deletions

View File

@ -55,6 +55,8 @@ MODEL_COLORS = {
"td3": "#17becf",
"sctd3": "#bcbd22",
}
EFFICIENCY_COLUMN = "r_efficiency"
EFFICIENCY_LABEL = REWARD_COMPONENT_LABELS.get(EFFICIENCY_COLUMN, "Running Efficiency")
def parse_args():
@ -265,6 +267,7 @@ def build_summary_text(model_name: str, df: pd.DataFrame, run_name: str, csv_pat
reward_last20 = df["reward"].tail(min(20, len(df))).mean() if "reward" in df else np.nan
tp_last20 = df["throughput"].tail(min(20, len(df))).mean() if "throughput" in df else np.nan
speed_last20 = df["mean_speed"].tail(min(20, len(df))).mean() if "mean_speed" in df else np.nan
efficiency_last20 = df[EFFICIENCY_COLUMN].tail(min(20, len(df))).mean() if EFFICIENCY_COLUMN in df else np.nan
stop_last20 = df["stops"].tail(min(20, len(df))).mean() if "stops" in df else np.nan
return (
@ -277,6 +280,7 @@ def build_summary_text(model_name: str, df: pd.DataFrame, run_name: str, csv_pat
f"Reward last20: {reward_last20:.2f}\n"
f"Throughput last20: {tp_last20:.1f}\n"
f"Mean speed last20: {speed_last20:.1f}\n"
f"Efficiency last20: {efficiency_last20:.3f}\n"
f"Stops last20: {stop_last20:.1f}\n"
f"\nCSV:\n{os.path.abspath(csv_path)}"
)
@ -291,7 +295,7 @@ def plot_detailed_snapshot(
window: int,
show_ma: bool,
):
fig, axes = plt.subplots(3, 3, figsize=(18, 12))
fig, axes = plt.subplots(4, 3, figsize=(18, 15))
axes = axes.flatten()
plot_series(axes[0], df, "reward", "Reward", "Reward", "tab:blue", window, show_ma)
@ -300,6 +304,16 @@ def plot_detailed_snapshot(
plot_series(
axes[3],
df,
EFFICIENCY_COLUMN,
EFFICIENCY_LABEL,
"value",
"tab:olive",
window,
show_ma,
)
plot_series(
axes[4],
df,
"speed_variance_norm",
"Normalized Speed Variance",
"norm",
@ -307,9 +321,9 @@ def plot_detailed_snapshot(
window,
show_ma,
)
plot_series(axes[4], df, "stops", "Stops", "count", "tab:red", window, show_ma)
plot_series(axes[5], df, "stops", "Stops", "count", "tab:red", window, show_ma)
reward_components = list(REWARD_COMPONENT_COLUMNS)
reward_components = [column for column in REWARD_COMPONENT_COLUMNS if column != EFFICIENCY_COLUMN]
has_components = any(
col in df.columns and pd.to_numeric(df[col], errors="coerce").notna().sum() > 0
for col in reward_components
@ -327,7 +341,7 @@ def plot_detailed_snapshot(
for col, color in zip(reward_components, component_colors):
series = pd.to_numeric(df[col], errors="coerce")
if series.notna().sum() > 0:
axes[5].plot(
axes[6].plot(
df["episode"],
series,
linewidth=1.6,
@ -336,7 +350,7 @@ def plot_detailed_snapshot(
label=col,
)
if show_ma:
axes[5].plot(
axes[6].plot(
df["episode"],
moving_average(series, window),
linewidth=1.0,
@ -345,18 +359,20 @@ def plot_detailed_snapshot(
color=color,
label="_nolegend_",
)
axes[5].set_title("Reward Components")
axes[5].set_xlabel("Episode")
axes[5].grid(True, alpha=0.3)
axes[5].legend(loc="best")
axes[6].set_title("Reward Components")
axes[6].set_xlabel("Episode")
axes[6].grid(True, alpha=0.3)
axes[6].legend(loc="best")
else:
axes[5].axis("off")
axes[6].axis("off")
plot_series(axes[6], df, "policy_loss", "Policy Loss", "loss", "tab:cyan", window, show_ma)
plot_series(axes[7], df, "value_loss", "Value Loss", "loss", "tab:pink", window, show_ma)
plot_series(axes[7], df, "policy_loss", "Policy Loss", "loss", "tab:cyan", window, show_ma)
plot_series(axes[8], df, "value_loss", "Value Loss", "loss", "tab:pink", window, show_ma)
plot_series(axes[9], df, "entropy", "Entropy", "value", "tab:brown", window, show_ma)
axes[8].axis("off")
axes[8].text(
axes[10].axis("off")
axes[11].axis("off")
axes[11].text(
0.0,
1.0,
build_summary_text(model_name, df, run_name, csv_path),
@ -364,7 +380,7 @@ def plot_detailed_snapshot(
ha="left",
family="monospace",
fontsize=11,
transform=axes[8].transAxes,
transform=axes[11].transAxes,
)
plt.tight_layout()
@ -425,9 +441,10 @@ def plot_run_comparison(
("reward", "Reward"),
("throughput", "Throughput"),
("mean_speed", "Mean Speed"),
(EFFICIENCY_COLUMN, EFFICIENCY_LABEL),
("speed_variance_norm", "Normalized Speed Variance"),
]
fig, axes = plt.subplots(2, 2, figsize=(16, 10))
fig, axes = plt.subplots(3, 2, figsize=(16, 14))
axes = axes.flatten()
for ax, (column, title) in zip(axes, metrics):
@ -462,6 +479,8 @@ def plot_run_comparison(
ax.set_xlabel("Episode")
ax.grid(True, alpha=0.3)
axes[0].legend(loc="best")
for ax in axes[len(metrics):]:
ax.axis("off")
plt.tight_layout()
plt.savefig(compare_output, dpi=160, bbox_inches="tight")
plt.close()
@ -523,12 +542,14 @@ def build_overview_summary(run_name: str, run_logs: Dict[str, pd.DataFrame]) ->
continue
reward = pd.to_numeric(df.get("reward"), errors="coerce")
throughput = pd.to_numeric(df.get("throughput"), errors="coerce")
efficiency = pd.to_numeric(df.get(EFFICIENCY_COLUMN), errors="coerce")
latest_episode = int(df["episode"].iloc[-1])
reward_last20 = reward.tail(min(20, len(df))).mean() if reward.notna().sum() > 0 else np.nan
tp_last20 = throughput.tail(min(20, len(df))).mean() if throughput.notna().sum() > 0 else np.nan
eff_last20 = efficiency.tail(min(20, len(df))).mean() if efficiency.notna().sum() > 0 else np.nan
lines.append(
f"{MODEL_LABELS[model_name]}: ep={latest_episode}, "
f"reward20={reward_last20:.1f}, tp20={tp_last20:.0f}"
f"reward20={reward_last20:.1f}, tp20={tp_last20:.0f}, eff20={eff_last20:.3f}"
)
return "\n".join(lines)
@ -545,6 +566,8 @@ def plot_all_models_overview(
available_reward_components = []
for column in REWARD_COMPONENT_COLUMNS:
if column == EFFICIENCY_COLUMN:
continue
if any(
df is not None
and column in df.columns
@ -560,6 +583,7 @@ def plot_all_models_overview(
("reward", "Reward", "Reward"),
("throughput", "Throughput", "veh/h"),
("mean_speed", "Mean Speed", "km/h"),
(EFFICIENCY_COLUMN, EFFICIENCY_LABEL, "value"),
("speed_variance_norm", "Normalized Speed Variance", "norm"),
("stops", "Stops", "count"),
("policy_loss", "Policy Loss", "loss"),
@ -574,6 +598,8 @@ def plot_all_models_overview(
for ax, (column, title, ylabel) in zip(axes, metrics):
plot_metric_overlay(ax, run_logs, column, title, ylabel, window, show_ma)
for ax in axes[len(metrics):]:
ax.axis("off")
fig.suptitle(f"Training Overview | {run_name}", fontsize=16, y=0.995)
plt.tight_layout(rect=[0, 0, 1, 0.98])