手动适配sb3
This commit is contained in:
parent
b48caa9a05
commit
fe60757f0b
|
|
@ -8,6 +8,8 @@ from stable_baselines3.common.noise import NormalActionNoise
|
|||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
|
||||
from utils.sb3_manual import ensure_manual_logger, sync_manual_timesteps
|
||||
|
||||
|
||||
class MultiDiscreteWrapper(gym.Env):
|
||||
def __init__(self, state_dim, action_dims):
|
||||
|
|
@ -67,6 +69,7 @@ class DDPGAgent:
|
|||
device=device,
|
||||
verbose=0,
|
||||
)
|
||||
ensure_manual_logger(self.model)
|
||||
|
||||
def select_action(self, state: np.ndarray, deterministic: bool = False):
|
||||
if not deterministic and self.total_steps < self.learning_starts:
|
||||
|
|
@ -88,6 +91,7 @@ class DDPGAgent:
|
|||
|
||||
def store_transition(self, state, action, reward, next_state, done):
|
||||
self.total_steps += 1
|
||||
sync_manual_timesteps(self.model, self.total_steps)
|
||||
continuous_action = np.array([
|
||||
action[i] / (self.action_dims[i] - 1)
|
||||
for i in range(self.num_zones)
|
||||
|
|
@ -105,3 +109,5 @@ class DDPGAgent:
|
|||
|
||||
def load(self, path: str):
    """Replace the wrapped model with a DDPG checkpoint read from ``path``.

    The checkpoint is loaded onto ``self.device``. The restored model then
    gets a logger via ``ensure_manual_logger`` and the agent's own step
    counter is re-seeded from the model's ``num_timesteps`` attribute
    (``0`` when the attribute is absent).
    """
    self.model = DDPG.load(path, device=self.device)
    restored = self.model
    ensure_manual_logger(restored)
    # Resume external step counting from wherever the checkpoint stopped.
    saved_steps = getattr(restored, "num_timesteps", 0)
    self.total_steps = int(saved_steps)
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ from stable_baselines3 import TD3
|
|||
from stable_baselines3.common.noise import NormalActionNoise
|
||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
||||
|
||||
from utils.sb3_manual import ensure_manual_logger, sync_manual_timesteps
|
||||
|
||||
|
||||
class MultiDiscreteWrapper(gym.Env):
|
||||
"""Wrap a MultiDiscrete action space as a continuous Box for TD3."""
|
||||
|
|
@ -289,6 +291,7 @@ class SCTD3Agent:
|
|||
verbose=0,
|
||||
policy_kwargs=policy_kwargs,
|
||||
)
|
||||
ensure_manual_logger(self.model)
|
||||
|
||||
def select_action(self, state: np.ndarray, deterministic: bool = False):
|
||||
if not deterministic and self.total_steps < self.learning_starts:
|
||||
|
|
@ -314,6 +317,7 @@ class SCTD3Agent:
|
|||
|
||||
def store_transition(self, state, action, reward, next_state, done):
|
||||
self.total_steps += 1
|
||||
sync_manual_timesteps(self.model, self.total_steps)
|
||||
continuous_action = np.array(
|
||||
[action[i] / (self.action_dims[i] - 1) for i in range(self.num_zones)],
|
||||
dtype=np.float32,
|
||||
|
|
@ -333,3 +337,5 @@ class SCTD3Agent:
|
|||
|
||||
def load(self, path: str):
    """Replace the wrapped model with a TD3 checkpoint read from ``path``.

    The checkpoint is loaded onto ``self.device``, handed to
    ``ensure_manual_logger``, and its ``num_timesteps`` value (``0`` when
    the attribute is absent) becomes the agent's ``total_steps``.
    """
    self.model = TD3.load(path, device=self.device)
    restored = self.model
    ensure_manual_logger(restored)
    # Keep the agent's own counter in step with the restored model.
    saved_steps = getattr(restored, "num_timesteps", 0)
    self.total_steps = int(saved_steps)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ from stable_baselines3 import TD3
|
|||
from stable_baselines3.common.noise import NormalActionNoise
|
||||
from stable_baselines3.common.torch_layers import FlattenExtractor
|
||||
|
||||
from utils.sb3_manual import ensure_manual_logger, sync_manual_timesteps
|
||||
|
||||
|
||||
class MultiDiscreteWrapper(gym.Env):
|
||||
"""Wrap a MultiDiscrete action space as a continuous Box for TD3."""
|
||||
|
|
@ -113,6 +115,7 @@ class TD3Agent:
|
|||
verbose=0,
|
||||
policy_kwargs=policy_kwargs,
|
||||
)
|
||||
ensure_manual_logger(self.model)
|
||||
|
||||
def select_action(self, state: np.ndarray, deterministic: bool = False):
|
||||
if not deterministic and self.total_steps < self.learning_starts:
|
||||
|
|
@ -138,6 +141,7 @@ class TD3Agent:
|
|||
|
||||
def store_transition(self, state, action, reward, next_state, done):
|
||||
self.total_steps += 1
|
||||
sync_manual_timesteps(self.model, self.total_steps)
|
||||
continuous_action = np.array(
|
||||
[action[i] / (self.action_dims[i] - 1) for i in range(self.num_zones)],
|
||||
dtype=np.float32,
|
||||
|
|
@ -157,3 +161,5 @@ class TD3Agent:
|
|||
|
||||
def load(self, path: str):
    """Replace the wrapped model with a TD3 checkpoint read from ``path``.

    After loading onto ``self.device`` the model is passed through
    ``ensure_manual_logger``; ``total_steps`` is then taken from the
    model's ``num_timesteps`` attribute, defaulting to ``0`` if missing.
    """
    self.model = TD3.load(path, device=self.device)
    restored = self.model
    ensure_manual_logger(restored)
    # Continue counting environment steps from the checkpointed value.
    saved_steps = getattr(restored, "num_timesteps", 0)
    self.total_steps = int(saved_steps)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,15 @@
|
|||
"""Compatibility helpers for manual Stable-Baselines3 off-policy training."""
|
||||
from __future__ import annotations
|
||||
|
||||
from stable_baselines3.common.logger import Logger
|
||||
|
||||
|
||||
def ensure_manual_logger(model) -> None:
    """Attach a no-op logger when training without calling ``model.learn()``.

    A logger with no folder and no output formats is installed through
    ``model.set_logger`` only when the model has no usable logger yet. An
    already configured logger is left untouched.

    Args:
        model: Object exposing ``set_logger`` and, optionally, a private
            ``_logger`` attribute (presumably a Stable-Baselines3
            algorithm instance; confirm against the callers).
    """
    # A missing attribute and an attribute that is still ``None`` both mean
    # "no logger configured". ``hasattr`` alone would skip the ``None`` case
    # and leave the model without a logger.
    # NOTE(review): some SB3 releases appear to initialise ``_logger`` to
    # ``None`` in the constructor - confirm for the pinned version.
    if getattr(model, "_logger", None) is None:
        model.set_logger(Logger(folder=None, output_formats=[]))
|
||||
|
||||
|
||||
def sync_manual_timesteps(model, total_steps: int) -> None:
    """Mirror an externally maintained step count onto ``model``.

    Args:
        model: Object whose ``num_timesteps`` attribute is overwritten.
        total_steps: Step count tracked by the caller; coerced with
            ``int()`` before it is stored.
    """
    step_count = int(total_steps)
    setattr(model, "num_timesteps", step_count)
|
||||
Loading…
Reference in New Issue