手动适配sb3

This commit is contained in:
Zihan Ye 2026-04-10 00:21:05 +08:00
parent b48caa9a05
commit fe60757f0b
4 changed files with 33 additions and 0 deletions

View File

@ -8,6 +8,8 @@ from stable_baselines3.common.noise import NormalActionNoise
import gymnasium as gym
from gymnasium import spaces
from utils.sb3_manual import ensure_manual_logger, sync_manual_timesteps
class MultiDiscreteWrapper(gym.Env):
def __init__(self, state_dim, action_dims):
@ -67,6 +69,7 @@ class DDPGAgent:
device=device,
verbose=0,
)
ensure_manual_logger(self.model)
def select_action(self, state: np.ndarray, deterministic: bool = False):
if not deterministic and self.total_steps < self.learning_starts:
@ -88,6 +91,7 @@ class DDPGAgent:
def store_transition(self, state, action, reward, next_state, done):
self.total_steps += 1
sync_manual_timesteps(self.model, self.total_steps)
continuous_action = np.array([
action[i] / (self.action_dims[i] - 1)
for i in range(self.num_zones)
@ -105,3 +109,5 @@ class DDPGAgent:
def load(self, path: str):
    """Restore a saved DDPG model and resume the external step counter.

    The model is read from ``path`` onto ``self.device``. It is then passed
    to ``ensure_manual_logger`` and ``self.total_steps`` is taken from the
    model's ``num_timesteps`` attribute (falling back to 0 when absent).
    """
    restored = DDPG.load(path, device=self.device)
    self.model = restored
    # Loading yields a fresh model object, so the logger helper runs again.
    ensure_manual_logger(restored)
    saved_steps = getattr(restored, "num_timesteps", 0)
    self.total_steps = int(saved_steps)

View File

@ -15,6 +15,8 @@ from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from utils.sb3_manual import ensure_manual_logger, sync_manual_timesteps
class MultiDiscreteWrapper(gym.Env):
"""Wrap a MultiDiscrete action space as a continuous Box for TD3."""
@ -289,6 +291,7 @@ class SCTD3Agent:
verbose=0,
policy_kwargs=policy_kwargs,
)
ensure_manual_logger(self.model)
def select_action(self, state: np.ndarray, deterministic: bool = False):
if not deterministic and self.total_steps < self.learning_starts:
@ -314,6 +317,7 @@ class SCTD3Agent:
def store_transition(self, state, action, reward, next_state, done):
self.total_steps += 1
sync_manual_timesteps(self.model, self.total_steps)
continuous_action = np.array(
[action[i] / (self.action_dims[i] - 1) for i in range(self.num_zones)],
dtype=np.float32,
@ -333,3 +337,5 @@ class SCTD3Agent:
def load(self, path: str):
    """Restore a saved TD3 model and resume the external step counter.

    The model is read from ``path`` onto ``self.device``. It is then passed
    to ``ensure_manual_logger`` and ``self.total_steps`` is taken from the
    model's ``num_timesteps`` attribute (falling back to 0 when absent).
    """
    restored = TD3.load(path, device=self.device)
    self.model = restored
    # Loading yields a fresh model object, so the logger helper runs again.
    ensure_manual_logger(restored)
    saved_steps = getattr(restored, "num_timesteps", 0)
    self.total_steps = int(saved_steps)

View File

@ -14,6 +14,8 @@ from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.torch_layers import FlattenExtractor
from utils.sb3_manual import ensure_manual_logger, sync_manual_timesteps
class MultiDiscreteWrapper(gym.Env):
"""Wrap a MultiDiscrete action space as a continuous Box for TD3."""
@ -113,6 +115,7 @@ class TD3Agent:
verbose=0,
policy_kwargs=policy_kwargs,
)
ensure_manual_logger(self.model)
def select_action(self, state: np.ndarray, deterministic: bool = False):
if not deterministic and self.total_steps < self.learning_starts:
@ -138,6 +141,7 @@ class TD3Agent:
def store_transition(self, state, action, reward, next_state, done):
self.total_steps += 1
sync_manual_timesteps(self.model, self.total_steps)
continuous_action = np.array(
[action[i] / (self.action_dims[i] - 1) for i in range(self.num_zones)],
dtype=np.float32,
@ -157,3 +161,5 @@ class TD3Agent:
def load(self, path: str):
    """Restore a saved TD3 model and resume the external step counter.

    The model is read from ``path`` onto ``self.device``. It is then passed
    to ``ensure_manual_logger`` and ``self.total_steps`` is taken from the
    model's ``num_timesteps`` attribute (falling back to 0 when absent).
    """
    restored = TD3.load(path, device=self.device)
    self.model = restored
    # Loading yields a fresh model object, so the logger helper runs again.
    ensure_manual_logger(restored)
    saved_steps = getattr(restored, "num_timesteps", 0)
    self.total_steps = int(saved_steps)

15
utils/sb3_manual.py Normal file
View File

@ -0,0 +1,15 @@
"""Compatibility helpers for manual Stable-Baselines3 off-policy training."""
from __future__ import annotations
from stable_baselines3.common.logger import Logger
def ensure_manual_logger(model) -> None:
    """Attach a no-op logger when training without calling ``model.learn()``.

    Args:
        model: Object exposing ``set_logger`` and, optionally, a ``_logger``
            attribute (presumably an SB3 algorithm instance; the call sites
            pass DDPG/TD3 models).

    If ``model._logger`` already holds a non-``None`` value the model is left
    untouched; otherwise a ``Logger`` with no folder and no output formats
    is installed through ``model.set_logger``.
    """
    # Treat "attribute missing" and "attribute present but None" the same way:
    # a bare ``hasattr`` check would skip the second case and leave the model
    # without a usable logger.
    # NOTE(review): some SB3 releases appear to initialise ``_logger`` to
    # ``None`` in ``__init__`` -- confirm against the pinned version.
    if getattr(model, "_logger", None) is None:
        model.set_logger(Logger(folder=None, output_formats=[]))
def sync_manual_timesteps(model, total_steps: int) -> None:
    """Mirror an externally maintained step count onto the model.

    Args:
        model: Object whose ``num_timesteps`` attribute is overwritten.
        total_steps: Step count tracked by the caller; coerced with ``int``
            before being stored.
    """
    step_count = int(total_steps)
    setattr(model, "num_timesteps", step_count)