66 lines
2.2 KiB
Python
66 lines
2.2 KiB
Python
"""Training CSV logger."""
|
|
|
|
import csv
import os
import warnings
from typing import Mapping, Optional

from envs.reward_system import REWARD_COMPONENT_COLUMNS
|
|
|
|
|
|
class TrainingLogger:
    """Append per-episode training metrics to a CSV file.

    The log is written to ``<log_dir>/<model_name>_training_log.csv``. Its
    columns are the fixed metric columns listed in ``__init__`` with
    ``REWARD_COMPONENT_COLUMNS`` (from ``envs.reward_system``) inserted after
    ``speed_variance_norm``. Each call to :meth:`log` opens the file in append
    mode, writes a single row and closes it again.
    """

    def __init__(self, log_dir, model_name, resume: bool = False):
        """Prepare the CSV log file.

        Args:
            log_dir: Directory that holds the log file; created if missing.
            model_name: Prefix of the file name
                (``<model_name>_training_log.csv``).
            resume: If False, any existing log is truncated and a fresh header
                is written. If True, an existing non-empty log is kept and
                rows are appended to it; a header is written only when the
                file is missing or empty. A ``RuntimeWarning`` is emitted when
                the existing header differs from the current columns, since
                appended rows would then not line up with it.
        """
        # An empty log_dir means "current directory"; os.makedirs("") raises.
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        self.log_path = os.path.join(log_dir, f"{model_name}_training_log.csv")
        self.fieldnames = [
            "episode",
            "reward",
            "throughput",
            "mean_speed",
            "speed_variance_norm",
            *REWARD_COMPONENT_COLUMNS,
            "hard_brakes",
            "policy_loss",
            "value_loss",
            "entropy",
        ]

        existing_header = self._read_header() if resume else None
        if existing_header is None:
            # Fresh run, or resuming onto a missing/empty file: start with a header.
            with open(self.log_path, "w", newline="", encoding="utf-8") as f:
                csv.DictWriter(f, fieldnames=self.fieldnames).writeheader()
        elif existing_header != self.fieldnames:
            # Appending under a different header silently misaligns columns.
            warnings.warn(
                f"Existing training log {self.log_path!r} has a header that "
                "differs from the current columns; appended rows will not "
                "line up with it.",
                RuntimeWarning,
                stacklevel=2,
            )

    def _read_header(self) -> Optional[list]:
        """Return the first CSV row of the existing log, or None if it is missing or empty."""
        if not os.path.exists(self.log_path):
            return None
        with open(self.log_path, newline="", encoding="utf-8") as f:
            # A blank first line parses as [], which is treated like an empty file.
            return next(csv.reader(f), None) or None

    @staticmethod
    def _fmt(value: Optional[float], spec: str) -> str:
        """Format ``value`` with ``spec``; None becomes an empty CSV cell."""
        return "" if value is None else format(value, spec)

    def log(
        self,
        episode,
        reward: float,
        throughput: float,
        mean_speed: float,
        *,
        speed_variance_norm: Optional[float] = None,
        reward_components: Optional[Mapping[str, float]] = None,
        hard_brakes: Optional[float] = 0,
        policy_loss: Optional[float] = None,
        value_loss: Optional[float] = None,
        entropy: Optional[float] = None,
    ):
        """Append one row of metrics to the log file.

        Args:
            episode: Written to the ``episode`` column as-is.
            reward: Written with 4 decimals.
            throughput: Written with 2 decimals.
            mean_speed: Written with 2 decimals.
            speed_variance_norm: Written with 6 decimals; None leaves the cell empty.
            reward_components: Values for ``REWARD_COMPONENT_COLUMNS`` (6 decimals).
                Missing columns leave empty cells; keys that are not in
                ``REWARD_COMPONENT_COLUMNS`` are ignored.
            hard_brakes: Written with no decimals; None leaves the cell empty.
            policy_loss, value_loss, entropy: Written with 6 decimals; None
                leaves the cell empty.
        """
        components = dict(reward_components or {})
        row = {
            "episode": episode,
            "reward": f"{reward:.4f}",
            "throughput": f"{throughput:.2f}",
            "mean_speed": f"{mean_speed:.2f}",
            "speed_variance_norm": self._fmt(speed_variance_norm, ".6f"),
            "hard_brakes": self._fmt(hard_brakes, ".0f"),
            "policy_loss": self._fmt(policy_loss, ".6f"),
            "value_loss": self._fmt(value_loss, ".6f"),
            "entropy": self._fmt(entropy, ".6f"),
        }
        for column in REWARD_COMPONENT_COLUMNS:
            row[column] = self._fmt(components.get(column), ".6f")

        with open(self.log_path, "a", newline="", encoding="utf-8") as f:
            csv.DictWriter(f, fieldnames=self.fieldnames).writerow(row)
|