ctm-dqn/training_logger.py

33 lines
1.4 KiB
Python

"""训练日志CSV记录器"""
import csv
import os
class TrainingLogger:
    """Append per-episode training metrics to a CSV file.

    The log lives at ``<log_dir>/<model_name>_training_log.csv``. Every call
    to :meth:`log` opens the file in append mode, writes one row and closes
    it again, so rows already written stay on disk if training is killed.
    """

    def __init__(self, log_dir, model_name, resume=False):
        """Prepare the CSV file and write the header row when needed.

        Args:
            log_dir: Directory that holds the log file; created if missing.
                An empty string means the current working directory.
            model_name: Prefix used to build the file name.
            resume: If True and the file already has content, keep it and
                append new rows after it. Otherwise the file is (re)created
                containing only the header row, discarding previous contents.
        """
        # open(..., "w") below fails if the directory does not exist yet.
        # os.makedirs("") raises, so skip it when logging into the cwd.
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        self.log_path = os.path.join(log_dir, f"{model_name}_training_log.csv")
        self.fieldnames = [
            "episode", "reward", "throughput", "mean_speed", "hard_brakes",
            "policy_loss", "value_loss", "entropy",
        ]
        # A zero-byte file (e.g. left by a run that died before writing the
        # header) must still get a header. Otherwise a CSV reader would take
        # the first appended data row as the column names.
        has_content = (
            os.path.exists(self.log_path) and os.path.getsize(self.log_path) > 0
        )
        if not resume or not has_content:
            with open(self.log_path, "w", newline="", encoding="utf-8") as f:
                writer = csv.DictWriter(f, fieldnames=self.fieldnames)
                writer.writeheader()

    @staticmethod
    def _format_optional(value, spec):
        """Format *value* with the format *spec*, or return "" for None."""
        return "" if value is None else format(value, spec)

    def log(self, episode, reward, throughput, mean_speed, hard_brakes=0,
            policy_loss=None, value_loss=None, entropy=None):
        """Append one row of metrics to the CSV log.

        ``reward``, ``throughput``, ``mean_speed`` and ``hard_brakes`` must be
        numeric. They are written with 4, 2, 2 and 0 decimal places
        respectively. ``policy_loss``, ``value_loss`` and ``entropy`` are
        written with 6 decimals, or as an empty cell when they are None
        (e.g. for episodes with no optimisation step).
        """
        row = {
            "episode": episode,
            "reward": f"{reward:.4f}",
            "throughput": f"{throughput:.2f}",
            "mean_speed": f"{mean_speed:.2f}",
            "hard_brakes": f"{hard_brakes:.0f}",
            "policy_loss": self._format_optional(policy_loss, ".6f"),
            "value_loss": self._format_optional(value_loss, ".6f"),
            "entropy": self._format_optional(entropy, ".6f"),
        }
        with open(self.log_path, "a", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=self.fieldnames)
            writer.writerow(row)