16 lines
579 B
Python
16 lines
579 B
Python
"""Compatibility helpers for manual Stable-Baselines3 off-policy training."""
|
|
from __future__ import annotations
|
|
|
|
from stable_baselines3.common.logger import Logger
|
|
|
|
|
|
def ensure_manual_logger(model) -> None:
    """Attach a no-op logger when training without calling ``model.learn()``.

    If ``model._logger`` is already set to something other than ``None`` this
    does nothing. Otherwise it installs a ``Logger`` with no folder and no
    output formats, so anything recorded through it is discarded.

    Args:
        model: A Stable-Baselines3 algorithm instance (must expose
            ``set_logger``).
    """
    # A missing attribute and an explicit ``None`` both mean "no usable logger".
    if getattr(model, "_logger", None) is not None:
        return

    # ``set_logger`` marks the logger as user-supplied (``_custom_logger``),
    # which would make a later ``model.learn()`` keep this silent placeholder
    # instead of configuring a real logger. The placeholder is not a user
    # choice, so restore the previous flag afterwards.
    previous_custom_flag = getattr(model, "_custom_logger", None)
    model.set_logger(Logger(folder=None, output_formats=[]))
    if previous_custom_flag is not None:
        model._custom_logger = previous_custom_flag
|
|
|
|
|
|
def sync_manual_timesteps(model, total_steps: int) -> None:
    """Mirror an externally tracked step count onto ``model.num_timesteps``.

    Only ``num_timesteps`` is updated; no other model counters are touched.

    Args:
        model: Object whose ``num_timesteps`` attribute is overwritten.
        total_steps: Step count to record; coerced with ``int()`` first.
    """
    step_count = int(total_steps)
    setattr(model, "num_timesteps", step_count)
|