ctm-dqn/utils/sb3_manual.py

16 lines
579 B
Python

"""Compatibility helpers for manual Stable-Baselines3 off-policy training."""
from __future__ import annotations
from stable_baselines3.common.logger import Logger
def ensure_manual_logger(model) -> None:
    """Give ``model`` a silent logger if it does not have one yet.

    Training loops that never call ``model.learn()`` skip the step where
    SB3 would normally install a logger. This installs a ``Logger`` with no
    output folder and no output formats, so logged values go nowhere.

    Args:
        model: Stable-Baselines3 algorithm instance (anything exposing
            ``set_logger``).

    NOTE(review): in SB3, ``set_logger`` also flags the logger as
    user-supplied, so a later ``model.learn()`` may keep this silent logger
    instead of configuring its own. Confirm that is intended.
    """
    # A model that already carries ``_logger`` (e.g. from a prior learn()
    # or set_logger call) is left untouched.
    if hasattr(model, "_logger"):
        return
    attach = model.set_logger
    attach(Logger(folder=None, output_formats=[]))
def sync_manual_timesteps(
    model,
    total_steps: int,
    total_timesteps: int | None = None,
) -> None:
    """Align SB3 step bookkeeping with an externally driven training loop.

    ``model.learn()`` normally keeps ``num_timesteps`` and
    ``_current_progress_remaining`` up to date. When training is driven
    manually, they have to be synced here instead.

    Args:
        model: Stable-Baselines3 algorithm instance being trained manually.
        total_steps: Environment steps taken so far. This is converted with
            ``int()`` and stored as ``model.num_timesteps``.
        total_timesteps: Optional planned training budget. When given,
            ``model._current_progress_remaining`` is set to
            ``1 - total_steps / total_timesteps`` (the same formula SB3 uses),
            so schedules that SB3 evaluates from that value can advance.
            When omitted, only ``num_timesteps`` is updated, matching the
            previous behaviour.

    Raises:
        ValueError: If ``total_timesteps`` is given and is not positive.
    """
    steps = int(total_steps)
    budget = None
    if total_timesteps is not None:
        budget = int(total_timesteps)
        # Validate before mutating so a bad budget leaves the model unchanged
        # and never triggers a division by zero.
        if budget <= 0:
            raise ValueError(
                f"total_timesteps must be positive, got {total_timesteps!r}"
            )

    model.num_timesteps = steps
    if budget is not None:
        # Not clamped, mirroring SB3: progress may go below 0 if training
        # overruns the planned budget.
        model._current_progress_remaining = 1.0 - float(steps) / float(budget)