From fe60757f0b8ac7925c72c6ce258b754031671d77 Mon Sep 17 00:00:00 2001 From: Maple-YZ Date: Fri, 10 Apr 2026 00:21:05 +0800 Subject: [PATCH] =?UTF-8?q?=E6=89=8B=E5=8A=A8=E9=80=82=E9=85=8Dsb3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agents/ddpg_agent.py | 6 ++++++ agents/sctd3_agent.py | 6 ++++++ agents/td3_agent.py | 6 ++++++ utils/sb3_manual.py | 15 +++++++++++++++ 4 files changed, 33 insertions(+) create mode 100644 utils/sb3_manual.py diff --git a/agents/ddpg_agent.py b/agents/ddpg_agent.py index 4e3a798..7e36cf1 100644 --- a/agents/ddpg_agent.py +++ b/agents/ddpg_agent.py @@ -8,6 +8,8 @@ from stable_baselines3.common.noise import NormalActionNoise import gymnasium as gym from gymnasium import spaces +from utils.sb3_manual import ensure_manual_logger, sync_manual_timesteps + class MultiDiscreteWrapper(gym.Env): def __init__(self, state_dim, action_dims): @@ -67,6 +69,7 @@ class DDPGAgent: device=device, verbose=0, ) + ensure_manual_logger(self.model) def select_action(self, state: np.ndarray, deterministic: bool = False): if not deterministic and self.total_steps < self.learning_starts: @@ -88,6 +91,7 @@ class DDPGAgent: def store_transition(self, state, action, reward, next_state, done): self.total_steps += 1 + sync_manual_timesteps(self.model, self.total_steps) continuous_action = np.array([ action[i] / (self.action_dims[i] - 1) for i in range(self.num_zones) @@ -105,3 +109,5 @@ class DDPGAgent: def load(self, path: str): self.model = DDPG.load(path, device=self.device) + ensure_manual_logger(self.model) + self.total_steps = int(getattr(self.model, "num_timesteps", 0)) diff --git a/agents/sctd3_agent.py b/agents/sctd3_agent.py index ab0f22f..079ac8d 100644 --- a/agents/sctd3_agent.py +++ b/agents/sctd3_agent.py @@ -15,6 +15,8 @@ from stable_baselines3 import TD3 from stable_baselines3.common.noise import NormalActionNoise from stable_baselines3.common.torch_layers import BaseFeaturesExtractor +from utils.sb3_manual import ensure_manual_logger, sync_manual_timesteps + class MultiDiscreteWrapper(gym.Env): """Wrap a MultiDiscrete action space as a continuous Box for TD3.""" @@ -289,6 +291,7 @@ class SCTD3Agent: verbose=0, policy_kwargs=policy_kwargs, ) + ensure_manual_logger(self.model) def select_action(self, state: np.ndarray, deterministic: bool = False): if not deterministic and self.total_steps < self.learning_starts: @@ -314,6 +317,7 @@ class SCTD3Agent: def store_transition(self, state, action, reward, next_state, done): self.total_steps += 1 + sync_manual_timesteps(self.model, self.total_steps) continuous_action = np.array( [action[i] / (self.action_dims[i] - 1) for i in range(self.num_zones)], dtype=np.float32, @@ -333,3 +337,5 @@ class SCTD3Agent: def load(self, path: str): self.model = TD3.load(path, device=self.device) + ensure_manual_logger(self.model) + self.total_steps = int(getattr(self.model, "num_timesteps", 0)) diff --git a/agents/td3_agent.py b/agents/td3_agent.py index 9ffbe9d..231b6f7 100644 --- a/agents/td3_agent.py +++ b/agents/td3_agent.py @@ -14,6 +14,8 @@ from stable_baselines3 import TD3 from stable_baselines3.common.noise import NormalActionNoise from stable_baselines3.common.torch_layers import FlattenExtractor +from utils.sb3_manual import ensure_manual_logger, sync_manual_timesteps + class MultiDiscreteWrapper(gym.Env): """Wrap a MultiDiscrete action space as a continuous Box for TD3.""" @@ -113,6 +115,7 @@ class TD3Agent: verbose=0, policy_kwargs=policy_kwargs, ) + ensure_manual_logger(self.model) def select_action(self, state: np.ndarray, deterministic: bool = False): if not deterministic and self.total_steps < self.learning_starts: @@ -138,6 +141,7 @@ class TD3Agent: def store_transition(self, state, action, reward, next_state, done): self.total_steps += 1 + sync_manual_timesteps(self.model, self.total_steps) continuous_action = np.array( [action[i] / (self.action_dims[i] - 1) for i in range(self.num_zones)], dtype=np.float32, @@ -157,3 +161,5 @@ class TD3Agent: def load(self, path: str): self.model = TD3.load(path, device=self.device) + ensure_manual_logger(self.model) + self.total_steps = int(getattr(self.model, "num_timesteps", 0)) diff --git a/utils/sb3_manual.py b/utils/sb3_manual.py new file mode 100644 index 0000000..bbc420e --- /dev/null +++ b/utils/sb3_manual.py @@ -0,0 +1,15 @@ +"""Compatibility helpers for manual Stable-Baselines3 off-policy training.""" +from __future__ import annotations + +from stable_baselines3.common.logger import Logger + + +def ensure_manual_logger(model) -> None: + """Attach a no-op logger when training without calling ``model.learn()``.""" + if not hasattr(model, "_logger"): + model.set_logger(Logger(folder=None, output_formats=[])) + + +def sync_manual_timesteps(model, total_steps: int) -> None: + """Keep core SB3 counters roughly aligned with externally driven training.""" + model.num_timesteps = int(total_steps)