diff --git a/train_sumo_appo_v2.py b/train_sumo_appo_v2.py index d705f7c..5b67d08 100644 --- a/train_sumo_appo_v2.py +++ b/train_sumo_appo_v2.py @@ -100,6 +100,24 @@ def train_sumo_appo(resume_checkpoint=None): entropies = [] best_reward = -float("inf") + # 加载历史训练数据 + if resume_checkpoint: + import csv + log_file = os.path.join(log_dir, "appo_training_log.csv") + if os.path.exists(log_file): + with open(log_file, "r", encoding="utf-8") as f: + reader = csv.DictReader(f) + for row in reader: + episode_rewards.append(float(row["reward"])) + episode_throughputs.append(float(row["throughput"])) + episode_mean_speeds.append(float(row["mean_speed"])) + if row["policy_loss"]: + policy_losses.append(float(row["policy_loss"])) + if row["value_loss"]: + value_losses.append(float(row["value_loss"])) + best_reward = max(episode_rewards) if episode_rewards else -float("inf") + print(f"已加载 {len(episode_rewards)} 条历史记录") + print("开始训练...\n") try: diff --git a/train_sumo_dqn.py b/train_sumo_dqn.py index 02692f0..12a8a38 100644 --- a/train_sumo_dqn.py +++ b/train_sumo_dqn.py @@ -95,6 +95,22 @@ def train_sumo_dqn(resume_checkpoint=None): losses = [] best_reward = -float("inf") + # 加载历史训练数据 + if resume_checkpoint: + import csv + log_file = os.path.join(log_dir, "dqn_training_log.csv") + if os.path.exists(log_file): + with open(log_file, "r", encoding="utf-8") as f: + reader = csv.DictReader(f) + for row in reader: + episode_rewards.append(float(row["reward"])) + episode_throughputs.append(float(row["throughput"])) + episode_mean_speeds.append(float(row["mean_speed"])) + if row["value_loss"]: + losses.append(float(row["value_loss"])) + best_reward = max(episode_rewards) if episode_rewards else -float("inf") + print(f"已加载 {len(episode_rewards)} 条历史记录") + print("开始训练...\n") try: diff --git a/train_sumo_ppo.py b/train_sumo_ppo.py index 961fdaf..365a427 100644 --- a/train_sumo_ppo.py +++ b/train_sumo_ppo.py @@ -104,6 +104,24 @@ def train_sumo_ppo(resume_checkpoint=None): entropies = [] best_reward = -float("inf") + # 加载历史训练数据 + if resume_checkpoint: + import csv + log_file = os.path.join(log_dir, "ppo_training_log.csv") + if os.path.exists(log_file): + with open(log_file, "r", encoding="utf-8") as f: + reader = csv.DictReader(f) + for row in reader: + episode_rewards.append(float(row["reward"])) + episode_throughputs.append(float(row["throughput"])) + episode_mean_speeds.append(float(row["mean_speed"])) + if row["policy_loss"]: + policy_losses.append(float(row["policy_loss"])) + if row["value_loss"]: + value_losses.append(float(row["value_loss"])) + best_reward = max(episode_rewards) if episode_rewards else -float("inf") + print(f"已加载 {len(episode_rewards)} 条历史记录") + print("开始训练...\n") try: