# ctm-dqn/utils/logger.py

"""Training CSV logger."""
import csv
import os
from typing import Mapping, Optional
from envs.reward_system import REWARD_COMPONENT_COLUMNS
class TrainingLogger:
    """Append per-episode training metrics to a CSV file.

    The file is ``<log_dir>/<model_name>_training_log.csv``. Its columns are,
    in order: the episode-level metrics, one column per entry of
    ``REWARD_COMPONENT_COLUMNS``, then ``hard_brakes`` and the optimizer
    statistics. Each call to :meth:`log` appends exactly one row.
    """

    def __init__(self, log_dir, model_name, resume=False):
        """Prepare the CSV log file.

        Args:
            log_dir: Directory that holds the log file; created if missing.
                An empty string means the current working directory.
            model_name: Prefix used to build the log file name.
            resume: When True and a non-empty log file already exists, keep
                it and append new rows to it. Otherwise the file is
                (re)created containing only the header row.
        """
        self.log_path = os.path.join(log_dir, f"{model_name}_training_log.csv")
        self.fieldnames = [
            "episode",
            "reward",
            "throughput",
            "mean_speed",
            "speed_variance_norm",
            *REWARD_COMPONENT_COLUMNS,
            "hard_brakes",
            "policy_loss",
            "value_loss",
            "entropy",
        ]
        # os.makedirs("") raises, so only create a directory when one is given.
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        # A zero-byte file (e.g. left by a run that died before the header was
        # written) is treated like a missing file so the header is not lost.
        has_content = (
            os.path.isfile(self.log_path) and os.path.getsize(self.log_path) > 0
        )
        # NOTE(review): when resuming, the existing header is not compared with
        # self.fieldnames; if REWARD_COMPONENT_COLUMNS changed between runs the
        # appended rows will not line up with the old header.
        if not resume or not has_content:
            with open(self.log_path, "w", newline="", encoding="utf-8") as f:
                csv.DictWriter(f, fieldnames=self.fieldnames).writeheader()

    @staticmethod
    def _fmt(value, spec):
        """Return ``value`` formatted with ``spec``, or "" when it is None."""
        return "" if value is None else format(value, spec)

    def log(
        self,
        episode,
        reward,
        throughput,
        mean_speed,
        *,
        speed_variance_norm: Optional[float] = None,
        reward_components: Optional[Mapping[str, float]] = None,
        hard_brakes=0,
        policy_loss: Optional[float] = None,
        value_loss: Optional[float] = None,
        entropy: Optional[float] = None,
    ):
        """Append one episode's metrics as a single CSV row.

        Args:
            episode: Episode identifier, written unformatted.
            reward: Episode reward, written with 4 decimals.
            throughput: Written with 2 decimals.
            mean_speed: Written with 2 decimals.
            speed_variance_norm: Written with 6 decimals; "" when None.
            reward_components: Values for the columns named in
                ``REWARD_COMPONENT_COLUMNS`` (6 decimals). Missing columns
                are written as ""; keys not in that list are ignored.
            hard_brakes: Written with 0 decimals; "" when None.
            policy_loss: Written with 6 decimals; "" when None.
            value_loss: Written with 6 decimals; "" when None.
            entropy: Written with 6 decimals; "" when None.
        """
        components = reward_components or {}
        row = {
            "episode": episode,
            "reward": f"{reward:.4f}",
            "throughput": f"{throughput:.2f}",
            "mean_speed": f"{mean_speed:.2f}",
            "speed_variance_norm": self._fmt(speed_variance_norm, ".6f"),
            "hard_brakes": self._fmt(hard_brakes, ".0f"),
            "policy_loss": self._fmt(policy_loss, ".6f"),
            "value_loss": self._fmt(value_loss, ".6f"),
            "entropy": self._fmt(entropy, ".6f"),
        }
        # Applied last so reward-component columns take precedence, as before.
        row.update(
            {
                column: self._fmt(components.get(column), ".6f")
                for column in REWARD_COMPONENT_COLUMNS
            }
        )
        # Reopen per row so each episode is on disk even if training crashes.
        with open(self.log_path, "a", newline="", encoding="utf-8") as f:
            csv.DictWriter(f, fieldnames=self.fieldnames).writerow(row)