From 999430422dbb5b14ebf8c33ff51e3df801785550 Mon Sep 17 00:00:00 2001 From: Maple-YZ Date: Wed, 25 Mar 2026 13:54:47 +0800 Subject: [PATCH] =?UTF-8?q?=E6=81=A2=E5=A4=8D=E8=AE=AD=E7=BB=83=E6=97=B6?= =?UTF-8?q?=E5=90=8C=E6=AD=A5=E6=81=A2=E5=A4=8D=E5=8E=86=E5=8F=B2=E8=AE=B0?= =?UTF-8?q?=E5=BD=95=E4=BB=A5=E4=BE=BF=E7=BB=98=E5=9B=BE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train_sumo_appo_v2.py | 18 ++++++++++++++++++ train_sumo_dqn.py | 16 ++++++++++++++++ train_sumo_ppo.py | 18 ++++++++++++++++++ 3 files changed, 52 insertions(+) diff --git a/train_sumo_appo_v2.py b/train_sumo_appo_v2.py index d705f7c..5b67d08 100644 --- a/train_sumo_appo_v2.py +++ b/train_sumo_appo_v2.py @@ -100,6 +100,24 @@ def train_sumo_appo(resume_checkpoint=None): entropies = [] best_reward = -float("inf") + # 加载历史训练数据 + if resume_checkpoint: + import csv + log_file = os.path.join(log_dir, "appo_training_log.csv") + if os.path.exists(log_file): + with open(log_file, "r", encoding="utf-8") as f: + reader = csv.DictReader(f) + for row in reader: + episode_rewards.append(float(row["reward"])) + episode_throughputs.append(float(row["throughput"])) + episode_mean_speeds.append(float(row["mean_speed"])) + if row["policy_loss"]: + policy_losses.append(float(row["policy_loss"])) + if row["value_loss"]: + value_losses.append(float(row["value_loss"])) + best_reward = max(episode_rewards) if episode_rewards else -float("inf") + print(f"已加载 {len(episode_rewards)} 条历史记录") + print("开始训练...\n") try: diff --git a/train_sumo_dqn.py b/train_sumo_dqn.py index 02692f0..12a8a38 100644 --- a/train_sumo_dqn.py +++ b/train_sumo_dqn.py @@ -95,6 +95,22 @@ def train_sumo_dqn(resume_checkpoint=None): losses = [] best_reward = -float("inf") + # 加载历史训练数据 + if resume_checkpoint: + import csv + log_file = os.path.join(log_dir, "dqn_training_log.csv") + if os.path.exists(log_file): + with open(log_file, "r", encoding="utf-8") as f: + reader = csv.DictReader(f) + for row in reader: + episode_rewards.append(float(row["reward"])) + episode_throughputs.append(float(row["throughput"])) + episode_mean_speeds.append(float(row["mean_speed"])) + if row["value_loss"]: + losses.append(float(row["value_loss"])) + best_reward = max(episode_rewards) if episode_rewards else -float("inf") + print(f"已加载 {len(episode_rewards)} 条历史记录") + print("开始训练...\n") try: diff --git a/train_sumo_ppo.py b/train_sumo_ppo.py index 961fdaf..365a427 100644 --- a/train_sumo_ppo.py +++ b/train_sumo_ppo.py @@ -104,6 +104,24 @@ def train_sumo_ppo(resume_checkpoint=None): entropies = [] best_reward = -float("inf") + # 加载历史训练数据 + if resume_checkpoint: + import csv + log_file = os.path.join(log_dir, "ppo_training_log.csv") + if os.path.exists(log_file): + with open(log_file, "r", encoding="utf-8") as f: + reader = csv.DictReader(f) + for row in reader: + episode_rewards.append(float(row["reward"])) + episode_throughputs.append(float(row["throughput"])) + episode_mean_speeds.append(float(row["mean_speed"])) + if row["policy_loss"]: + policy_losses.append(float(row["policy_loss"])) + if row["value_loss"]: + value_losses.append(float(row["value_loss"])) + best_reward = max(episode_rewards) if episode_rewards else -float("inf") + print(f"已加载 {len(episode_rewards)} 条历史记录") + print("开始训练...\n") try: