# ctm-dqn/utils/config.py
# (file-viewer metadata: 28 lines, 983 B, Python)

import copy
def get_agent_config(config: dict, model_name: str) -> dict:
    """Return the model-specific agent config, supporting old and new layouts.

    New layout: ``config["agents"]`` is a dict holding an optional ``"common"``
    section plus one section per model name. The result is ``common``
    shallow-merged with the model's section (top-level keys from the model
    section override those in ``common``; nested dicts are replaced, not
    merged).

    Legacy layout: when ``config["agents"]`` is missing or not a dict, the
    single ``config["agent"]`` section is returned.

    Args:
        config: Full parsed configuration mapping.
        model_name: Key of the model-specific section under ``"agents"``.

    Returns:
        A new dict that shares no mutable state with ``config``. Sections that
        are absent or explicitly null (e.g. an empty YAML section) count as
        empty.
    """
    agents_cfg = config.get("agents")
    if isinstance(agents_cfg, dict):
        # `or {}` treats a present-but-null section the same as a missing one.
        merged = copy.deepcopy(agents_cfg.get("common") or {})
        # Deep-copy the overrides as well, so that mutating nested values of
        # the result cannot leak back into the caller's config.
        merged.update(copy.deepcopy(agents_cfg.get(model_name) or {}))
        return merged
    return copy.deepcopy(config.get("agent") or {})


def get_training_config(config: dict) -> dict:
    """Return a copy of the training section with old/new key names normalized.

    Legacy configs used ``log_interval`` / ``save_interval``; newer ones use
    ``log_freq`` / ``save_freq``. When only the legacy name is present, the
    new name is filled in from it. When both are present, the new name wins.
    Legacy keys are kept in the result unchanged.

    Args:
        config: Full parsed configuration mapping.

    Returns:
        A deep copy of ``config["training"]``, or an empty dict if that
        section is absent or explicitly null.
    """
    # `or {}` guards against a present-but-null section (e.g. `training:` in
    # YAML), which would otherwise make the membership tests below raise.
    training_cfg = copy.deepcopy(config.get("training") or {})
    # (new-style key, legacy key) pairs.
    aliases = (("log_freq", "log_interval"), ("save_freq", "save_interval"))
    for new_key, legacy_key in aliases:
        if new_key not in training_cfg and legacy_key in training_cfg:
            training_cfg[new_key] = training_cfg[legacy_key]
    return training_cfg