恢复训练时同步恢复历史记录以便绘图
This commit is contained in:
parent
8377658995
commit
999430422d
|
|
@ -100,6 +100,24 @@ def train_sumo_appo(resume_checkpoint=None):
|
||||||
entropies = []
|
entropies = []
|
||||||
best_reward = -float("inf")
|
best_reward = -float("inf")
|
||||||
|
|
||||||
|
# 加载历史训练数据
|
||||||
|
if resume_checkpoint:
|
||||||
|
import csv
|
||||||
|
log_file = os.path.join(log_dir, "appo_training_log.csv")
|
||||||
|
if os.path.exists(log_file):
|
||||||
|
with open(log_file, "r", encoding="utf-8") as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
for row in reader:
|
||||||
|
episode_rewards.append(float(row["reward"]))
|
||||||
|
episode_throughputs.append(float(row["throughput"]))
|
||||||
|
episode_mean_speeds.append(float(row["mean_speed"]))
|
||||||
|
if row["policy_loss"]:
|
||||||
|
policy_losses.append(float(row["policy_loss"]))
|
||||||
|
if row["value_loss"]:
|
||||||
|
value_losses.append(float(row["value_loss"]))
|
||||||
|
best_reward = max(episode_rewards) if episode_rewards else -float("inf")
|
||||||
|
print(f"已加载 {len(episode_rewards)} 条历史记录")
|
||||||
|
|
||||||
print("开始训练...\n")
|
print("开始训练...\n")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -95,6 +95,22 @@ def train_sumo_dqn(resume_checkpoint=None):
|
||||||
losses = []
|
losses = []
|
||||||
best_reward = -float("inf")
|
best_reward = -float("inf")
|
||||||
|
|
||||||
|
# 加载历史训练数据
|
||||||
|
if resume_checkpoint:
|
||||||
|
import csv
|
||||||
|
log_file = os.path.join(log_dir, "dqn_training_log.csv")
|
||||||
|
if os.path.exists(log_file):
|
||||||
|
with open(log_file, "r", encoding="utf-8") as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
for row in reader:
|
||||||
|
episode_rewards.append(float(row["reward"]))
|
||||||
|
episode_throughputs.append(float(row["throughput"]))
|
||||||
|
episode_mean_speeds.append(float(row["mean_speed"]))
|
||||||
|
if row["value_loss"]:
|
||||||
|
losses.append(float(row["value_loss"]))
|
||||||
|
best_reward = max(episode_rewards) if episode_rewards else -float("inf")
|
||||||
|
print(f"已加载 {len(episode_rewards)} 条历史记录")
|
||||||
|
|
||||||
print("开始训练...\n")
|
print("开始训练...\n")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -104,6 +104,24 @@ def train_sumo_ppo(resume_checkpoint=None):
|
||||||
entropies = []
|
entropies = []
|
||||||
best_reward = -float("inf")
|
best_reward = -float("inf")
|
||||||
|
|
||||||
|
# 加载历史训练数据
|
||||||
|
if resume_checkpoint:
|
||||||
|
import csv
|
||||||
|
log_file = os.path.join(log_dir, "ppo_training_log.csv")
|
||||||
|
if os.path.exists(log_file):
|
||||||
|
with open(log_file, "r", encoding="utf-8") as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
for row in reader:
|
||||||
|
episode_rewards.append(float(row["reward"]))
|
||||||
|
episode_throughputs.append(float(row["throughput"]))
|
||||||
|
episode_mean_speeds.append(float(row["mean_speed"]))
|
||||||
|
if row["policy_loss"]:
|
||||||
|
policy_losses.append(float(row["policy_loss"]))
|
||||||
|
if row["value_loss"]:
|
||||||
|
value_losses.append(float(row["value_loss"]))
|
||||||
|
best_reward = max(episode_rewards) if episode_rewards else -float("inf")
|
||||||
|
print(f"已加载 {len(episode_rewards)} 条历史记录")
|
||||||
|
|
||||||
print("开始训练...\n")
|
print("开始训练...\n")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue