恢复训练时同步恢复历史记录以便绘图

This commit is contained in:
Zihan Ye 2026-03-25 13:54:47 +08:00
parent 8377658995
commit 999430422d
3 changed files with 52 additions and 0 deletions

View File

@ -100,6 +100,24 @@ def train_sumo_appo(resume_checkpoint=None):
entropies = []
best_reward = -float("inf")
# 加载历史训练数据
if resume_checkpoint:
import csv
log_file = os.path.join(log_dir, "appo_training_log.csv")
if os.path.exists(log_file):
with open(log_file, "r", encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
episode_rewards.append(float(row["reward"]))
episode_throughputs.append(float(row["throughput"]))
episode_mean_speeds.append(float(row["mean_speed"]))
if row["policy_loss"]:
policy_losses.append(float(row["policy_loss"]))
if row["value_loss"]:
value_losses.append(float(row["value_loss"]))
best_reward = max(episode_rewards) if episode_rewards else -float("inf")
print(f"已加载 {len(episode_rewards)} 条历史记录")
print("开始训练...\n")
try:

View File

@ -95,6 +95,22 @@ def train_sumo_dqn(resume_checkpoint=None):
losses = []
best_reward = -float("inf")
# 加载历史训练数据
if resume_checkpoint:
import csv
log_file = os.path.join(log_dir, "dqn_training_log.csv")
if os.path.exists(log_file):
with open(log_file, "r", encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
episode_rewards.append(float(row["reward"]))
episode_throughputs.append(float(row["throughput"]))
episode_mean_speeds.append(float(row["mean_speed"]))
if row["value_loss"]:
losses.append(float(row["value_loss"]))
best_reward = max(episode_rewards) if episode_rewards else -float("inf")
print(f"已加载 {len(episode_rewards)} 条历史记录")
print("开始训练...\n")
try:

View File

@ -104,6 +104,24 @@ def train_sumo_ppo(resume_checkpoint=None):
entropies = []
best_reward = -float("inf")
# 加载历史训练数据
if resume_checkpoint:
import csv
log_file = os.path.join(log_dir, "ppo_training_log.csv")
if os.path.exists(log_file):
with open(log_file, "r", encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
episode_rewards.append(float(row["reward"]))
episode_throughputs.append(float(row["throughput"]))
episode_mean_speeds.append(float(row["mean_speed"]))
if row["policy_loss"]:
policy_losses.append(float(row["policy_loss"]))
if row["value_loss"]:
value_losses.append(float(row["value_loss"]))
best_reward = max(episode_rewards) if episode_rewards else -float("inf")
print(f"已加载 {len(episode_rewards)} 条历史记录")
print("开始训练...\n")
try: