手动适配sb3
This commit is contained in:
parent
b48caa9a05
commit
fe60757f0b
|
|
@ -8,6 +8,8 @@ from stable_baselines3.common.noise import NormalActionNoise
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
|
|
||||||
|
from utils.sb3_manual import ensure_manual_logger, sync_manual_timesteps
|
||||||
|
|
||||||
|
|
||||||
class MultiDiscreteWrapper(gym.Env):
|
class MultiDiscreteWrapper(gym.Env):
|
||||||
def __init__(self, state_dim, action_dims):
|
def __init__(self, state_dim, action_dims):
|
||||||
|
|
@ -67,6 +69,7 @@ class DDPGAgent:
|
||||||
device=device,
|
device=device,
|
||||||
verbose=0,
|
verbose=0,
|
||||||
)
|
)
|
||||||
|
ensure_manual_logger(self.model)
|
||||||
|
|
||||||
def select_action(self, state: np.ndarray, deterministic: bool = False):
|
def select_action(self, state: np.ndarray, deterministic: bool = False):
|
||||||
if not deterministic and self.total_steps < self.learning_starts:
|
if not deterministic and self.total_steps < self.learning_starts:
|
||||||
|
|
@ -88,6 +91,7 @@ class DDPGAgent:
|
||||||
|
|
||||||
def store_transition(self, state, action, reward, next_state, done):
|
def store_transition(self, state, action, reward, next_state, done):
|
||||||
self.total_steps += 1
|
self.total_steps += 1
|
||||||
|
sync_manual_timesteps(self.model, self.total_steps)
|
||||||
continuous_action = np.array([
|
continuous_action = np.array([
|
||||||
action[i] / (self.action_dims[i] - 1)
|
action[i] / (self.action_dims[i] - 1)
|
||||||
for i in range(self.num_zones)
|
for i in range(self.num_zones)
|
||||||
|
|
@ -105,3 +109,5 @@ class DDPGAgent:
|
||||||
|
|
||||||
def load(self, path: str):
|
def load(self, path: str):
|
||||||
self.model = DDPG.load(path, device=self.device)
|
self.model = DDPG.load(path, device=self.device)
|
||||||
|
ensure_manual_logger(self.model)
|
||||||
|
self.total_steps = int(getattr(self.model, "num_timesteps", 0))
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,8 @@ from stable_baselines3 import TD3
|
||||||
from stable_baselines3.common.noise import NormalActionNoise
|
from stable_baselines3.common.noise import NormalActionNoise
|
||||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
||||||
|
|
||||||
|
from utils.sb3_manual import ensure_manual_logger, sync_manual_timesteps
|
||||||
|
|
||||||
|
|
||||||
class MultiDiscreteWrapper(gym.Env):
|
class MultiDiscreteWrapper(gym.Env):
|
||||||
"""Wrap a MultiDiscrete action space as a continuous Box for TD3."""
|
"""Wrap a MultiDiscrete action space as a continuous Box for TD3."""
|
||||||
|
|
@ -289,6 +291,7 @@ class SCTD3Agent:
|
||||||
verbose=0,
|
verbose=0,
|
||||||
policy_kwargs=policy_kwargs,
|
policy_kwargs=policy_kwargs,
|
||||||
)
|
)
|
||||||
|
ensure_manual_logger(self.model)
|
||||||
|
|
||||||
def select_action(self, state: np.ndarray, deterministic: bool = False):
|
def select_action(self, state: np.ndarray, deterministic: bool = False):
|
||||||
if not deterministic and self.total_steps < self.learning_starts:
|
if not deterministic and self.total_steps < self.learning_starts:
|
||||||
|
|
@ -314,6 +317,7 @@ class SCTD3Agent:
|
||||||
|
|
||||||
def store_transition(self, state, action, reward, next_state, done):
|
def store_transition(self, state, action, reward, next_state, done):
|
||||||
self.total_steps += 1
|
self.total_steps += 1
|
||||||
|
sync_manual_timesteps(self.model, self.total_steps)
|
||||||
continuous_action = np.array(
|
continuous_action = np.array(
|
||||||
[action[i] / (self.action_dims[i] - 1) for i in range(self.num_zones)],
|
[action[i] / (self.action_dims[i] - 1) for i in range(self.num_zones)],
|
||||||
dtype=np.float32,
|
dtype=np.float32,
|
||||||
|
|
@ -333,3 +337,5 @@ class SCTD3Agent:
|
||||||
|
|
||||||
def load(self, path: str):
|
def load(self, path: str):
|
||||||
self.model = TD3.load(path, device=self.device)
|
self.model = TD3.load(path, device=self.device)
|
||||||
|
ensure_manual_logger(self.model)
|
||||||
|
self.total_steps = int(getattr(self.model, "num_timesteps", 0))
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,8 @@ from stable_baselines3 import TD3
|
||||||
from stable_baselines3.common.noise import NormalActionNoise
|
from stable_baselines3.common.noise import NormalActionNoise
|
||||||
from stable_baselines3.common.torch_layers import FlattenExtractor
|
from stable_baselines3.common.torch_layers import FlattenExtractor
|
||||||
|
|
||||||
|
from utils.sb3_manual import ensure_manual_logger, sync_manual_timesteps
|
||||||
|
|
||||||
|
|
||||||
class MultiDiscreteWrapper(gym.Env):
|
class MultiDiscreteWrapper(gym.Env):
|
||||||
"""Wrap a MultiDiscrete action space as a continuous Box for TD3."""
|
"""Wrap a MultiDiscrete action space as a continuous Box for TD3."""
|
||||||
|
|
@ -113,6 +115,7 @@ class TD3Agent:
|
||||||
verbose=0,
|
verbose=0,
|
||||||
policy_kwargs=policy_kwargs,
|
policy_kwargs=policy_kwargs,
|
||||||
)
|
)
|
||||||
|
ensure_manual_logger(self.model)
|
||||||
|
|
||||||
def select_action(self, state: np.ndarray, deterministic: bool = False):
|
def select_action(self, state: np.ndarray, deterministic: bool = False):
|
||||||
if not deterministic and self.total_steps < self.learning_starts:
|
if not deterministic and self.total_steps < self.learning_starts:
|
||||||
|
|
@ -138,6 +141,7 @@ class TD3Agent:
|
||||||
|
|
||||||
def store_transition(self, state, action, reward, next_state, done):
|
def store_transition(self, state, action, reward, next_state, done):
|
||||||
self.total_steps += 1
|
self.total_steps += 1
|
||||||
|
sync_manual_timesteps(self.model, self.total_steps)
|
||||||
continuous_action = np.array(
|
continuous_action = np.array(
|
||||||
[action[i] / (self.action_dims[i] - 1) for i in range(self.num_zones)],
|
[action[i] / (self.action_dims[i] - 1) for i in range(self.num_zones)],
|
||||||
dtype=np.float32,
|
dtype=np.float32,
|
||||||
|
|
@ -157,3 +161,5 @@ class TD3Agent:
|
||||||
|
|
||||||
def load(self, path: str):
|
def load(self, path: str):
|
||||||
self.model = TD3.load(path, device=self.device)
|
self.model = TD3.load(path, device=self.device)
|
||||||
|
ensure_manual_logger(self.model)
|
||||||
|
self.total_steps = int(getattr(self.model, "num_timesteps", 0))
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
"""Compatibility helpers for manual Stable-Baselines3 off-policy training."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from stable_baselines3.common.logger import Logger
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_manual_logger(model) -> None:
|
||||||
|
"""Attach a no-op logger when training without calling ``model.learn()``."""
|
||||||
|
if not hasattr(model, "_logger"):
|
||||||
|
model.set_logger(Logger(folder=None, output_formats=[]))
|
||||||
|
|
||||||
|
|
||||||
|
def sync_manual_timesteps(model, total_steps: int) -> None:
|
||||||
|
"""Keep core SB3 counters roughly aligned with externally driven training."""
|
||||||
|
model.num_timesteps = int(total_steps)
|
||||||
Loading…
Reference in New Issue