28 lines
983 B
Python
28 lines
983 B
Python
import copy
|
|
|
|
|
|
def get_agent_config(config: dict, model_name: str) -> dict:
    """Return model-specific agent config with backward compatibility.

    Two config layouts are supported:

    * New layout: ``config["agents"]`` is a dict holding a ``"common"``
      section plus one section per model name. The result is the
      ``"common"`` section with the ``model_name`` section's top-level keys
      laid over it. This is a shallow merge: a nested dict in the model
      section replaces the matching nested dict from ``"common"`` wholesale.
    * Legacy layout: when ``config["agents"]`` is missing or not a dict,
      the single ``config["agent"]`` section is returned.

    Args:
        config: Full configuration mapping.
        model_name: Key of the model-specific section under ``"agents"``.

    Returns:
        A new dict that shares no nested objects with ``config``, so callers
        may mutate it freely. Missing sections, and sections whose value is
        ``None`` (e.g. an empty block in a YAML file), are treated as empty.
    """
    agents_cfg = config.get("agents")

    if isinstance(agents_cfg, dict):
        # `or {}` also covers keys that are present but set to None.
        common_cfg = agents_cfg.get("common") or {}
        model_cfg = agents_cfg.get(model_name) or {}

        merged = copy.deepcopy(common_cfg)
        # Deep-copy the override as well; otherwise nested values coming
        # from the model section would alias the caller's config.
        merged.update(copy.deepcopy(model_cfg))
        return merged

    legacy_cfg = config.get("agent") or {}
    return copy.deepcopy(legacy_cfg)
|
|
|
|
|
|
def get_training_config(config: dict) -> dict:
    """Normalize training keys across old/new config formats.

    Legacy ``*_interval`` keys are mirrored onto their current ``*_freq``
    names when the ``*_freq`` key is absent; an explicit ``*_freq`` always
    wins. Legacy keys are left in place.

    Args:
        config: Full configuration mapping; only ``config["training"]`` is read.

    Returns:
        A deep copy of the training section with the aliases applied. A
        missing section, or one whose value is ``None``, yields an empty dict.
    """
    # (current key, legacy key) pairs, applied in this order.
    key_aliases = (
        ("log_freq", "log_interval"),
        ("save_freq", "save_interval"),
    )

    # `or {}` also covers a `training` key that is present but set to None.
    training_cfg = copy.deepcopy(config.get("training") or {})

    for new_key, old_key in key_aliases:
        if new_key not in training_cfg and old_key in training_cfg:
            training_cfg[new_key] = training_cfg[old_key]

    return training_cfg
|