40 lines
1.9 KiB
Python
40 lines
1.9 KiB
Python
"""训练日志CSV记录器"""
|
|
import csv
|
|
import os
|
|
|
|
|
|
class TrainingLogger:
    """Append per-episode training metrics to a CSV file.

    The log file is ``<log_dir>/<model_name>_training_log.csv`` and has one
    column per entry in ``self.fieldnames``. Each call to :meth:`log` opens
    the file in append mode, writes a single row and closes it again, so
    every row logged so far is already on disk if training is interrupted.
    """

    def __init__(self, log_dir, model_name, resume=False):
        """Prepare the CSV log file, writing the header row when needed.

        Args:
            log_dir: Directory that holds the log file. It is created
                (including missing parents) if it does not exist yet.
            model_name: Prefix used to build the file name.
            resume: If True and a non-empty log file already exists, its
                contents are kept and new rows are appended to it.
                Otherwise the file is created or truncated so that it
                contains only the header row.
        """
        # open() below cannot create missing directories, so a fresh run
        # directory would raise FileNotFoundError. An empty log_dir means
        # "current directory", which needs no creation (and makedirs("")
        # would itself raise).
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        self.log_path = os.path.join(log_dir, f"{model_name}_training_log.csv")
        self.fieldnames = [
            "episode", "reward", "throughput", "mean_speed", "speed_std",
            "r_flow", "r_var", "r_brake", "r_penalty", "hard_brakes",
            "policy_loss", "value_loss", "entropy",
        ]

        # When resuming onto a file that exists but is empty, still write
        # the header; otherwise appended rows would have no column names.
        has_content = (
            os.path.exists(self.log_path)
            and os.path.getsize(self.log_path) > 0
        )
        if not resume or not has_content:
            with open(self.log_path, "w", newline="", encoding="utf-8") as f:
                writer = csv.DictWriter(f, fieldnames=self.fieldnames)
                writer.writeheader()

    @staticmethod
    def _fmt(value, spec):
        """Return ``format(value, spec)``, or an empty string for None."""
        return "" if value is None else format(value, spec)

    def log(self, episode, reward, throughput, mean_speed, speed_std=None,
            r_flow=None, r_var=None, r_brake=None, r_penalty=None, hard_brakes=0,
            policy_loss=None, value_loss=None, entropy=None):
        """Append one row of episode metrics to the CSV log file.

        Args:
            episode: Episode identifier, written unchanged.
            reward: Episode reward, written with 4 decimals.
            throughput: Written with 2 decimals.
            mean_speed: Written with 2 decimals.
            speed_std: Optional, written with 4 decimals.
            r_flow, r_var, r_brake, r_penalty: Optional reward components,
                written with 6 decimals each.
            hard_brakes: Count written with 0 decimals (defaults to 0).
            policy_loss, value_loss, entropy: Optional training statistics,
                written with 6 decimals each.

        Optional arguments passed as None are written as empty cells.

        Raises:
            TypeError: If ``reward``, ``throughput`` or ``mean_speed`` is
                None; these columns are required.
        """
        fmt = self._fmt
        # Build the row before touching the file, so a formatting error
        # leaves the log untouched.
        row = {
            "episode": episode,
            "reward": format(reward, ".4f"),
            "throughput": format(throughput, ".2f"),
            "mean_speed": format(mean_speed, ".2f"),
            "speed_std": fmt(speed_std, ".4f"),
            "r_flow": fmt(r_flow, ".6f"),
            "r_var": fmt(r_var, ".6f"),
            "r_brake": fmt(r_brake, ".6f"),
            "r_penalty": fmt(r_penalty, ".6f"),
            "hard_brakes": fmt(hard_brakes, ".0f"),
            "policy_loss": fmt(policy_loss, ".6f"),
            "value_loss": fmt(value_loss, ".6f"),
            "entropy": fmt(entropy, ".6f"),
        }
        with open(self.log_path, "a", newline="", encoding="utf-8") as f:
            csv.DictWriter(f, fieldnames=self.fieldnames).writerow(row)
|