恢复训练时同步恢复历史记录以便绘图
This commit is contained in:
parent
8377658995
commit
999430422d
|
|
@ -100,6 +100,24 @@ def train_sumo_appo(resume_checkpoint=None):
|
|||
entropies = []
|
||||
best_reward = -float("inf")
|
||||
|
||||
# 加载历史训练数据
|
||||
if resume_checkpoint:
|
||||
import csv
|
||||
log_file = os.path.join(log_dir, "appo_training_log.csv")
|
||||
if os.path.exists(log_file):
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
episode_rewards.append(float(row["reward"]))
|
||||
episode_throughputs.append(float(row["throughput"]))
|
||||
episode_mean_speeds.append(float(row["mean_speed"]))
|
||||
if row["policy_loss"]:
|
||||
policy_losses.append(float(row["policy_loss"]))
|
||||
if row["value_loss"]:
|
||||
value_losses.append(float(row["value_loss"]))
|
||||
best_reward = max(episode_rewards) if episode_rewards else -float("inf")
|
||||
print(f"已加载 {len(episode_rewards)} 条历史记录")
|
||||
|
||||
print("开始训练...\n")
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -95,6 +95,22 @@ def train_sumo_dqn(resume_checkpoint=None):
|
|||
losses = []
|
||||
best_reward = -float("inf")
|
||||
|
||||
# 加载历史训练数据
|
||||
if resume_checkpoint:
|
||||
import csv
|
||||
log_file = os.path.join(log_dir, "dqn_training_log.csv")
|
||||
if os.path.exists(log_file):
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
episode_rewards.append(float(row["reward"]))
|
||||
episode_throughputs.append(float(row["throughput"]))
|
||||
episode_mean_speeds.append(float(row["mean_speed"]))
|
||||
if row["value_loss"]:
|
||||
losses.append(float(row["value_loss"]))
|
||||
best_reward = max(episode_rewards) if episode_rewards else -float("inf")
|
||||
print(f"已加载 {len(episode_rewards)} 条历史记录")
|
||||
|
||||
print("开始训练...\n")
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -104,6 +104,24 @@ def train_sumo_ppo(resume_checkpoint=None):
|
|||
entropies = []
|
||||
best_reward = -float("inf")
|
||||
|
||||
# 加载历史训练数据
|
||||
if resume_checkpoint:
|
||||
import csv
|
||||
log_file = os.path.join(log_dir, "ppo_training_log.csv")
|
||||
if os.path.exists(log_file):
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
episode_rewards.append(float(row["reward"]))
|
||||
episode_throughputs.append(float(row["throughput"]))
|
||||
episode_mean_speeds.append(float(row["mean_speed"]))
|
||||
if row["policy_loss"]:
|
||||
policy_losses.append(float(row["policy_loss"]))
|
||||
if row["value_loss"]:
|
||||
value_losses.append(float(row["value_loss"]))
|
||||
best_reward = max(episode_rewards) if episode_rewards else -float("inf")
|
||||
print(f"已加载 {len(episode_rewards)} 条历史记录")
|
||||
|
||||
print("开始训练...\n")
|
||||
|
||||
try:
|
||||
|
|
|
|||
Loading…
Reference in New Issue