添加随机种子,自动保存最优模型

This commit is contained in:
Zihan Ye 2026-01-05 17:04:27 +08:00
parent fd0dc80a91
commit 39b0134609
2 changed files with 33 additions and 1 deletions

View File

@ -54,6 +54,7 @@ training:
log_freq: 10 # Logging frequency (episodes) log_freq: 10 # Logging frequency (episodes)
checkpoint_dir: "checkpoints" # Checkpoint directory checkpoint_dir: "checkpoints" # Checkpoint directory
log_dir: "logs" # Log directory log_dir: "logs" # Log directory
random_seed: 42 # Random seed for reproducibility
# Testing Parameters # Testing Parameters
testing: testing:

View File

@ -2,7 +2,9 @@
Training script for DQN-based speed limit control. Training script for DQN-based speed limit control.
""" """
import os import os
import random
import numpy as np import numpy as np
import torch
from tqdm import tqdm from tqdm import tqdm
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from utils import load_config, create_directories from utils import load_config, create_directories
@ -10,11 +12,28 @@ from environment import TrafficEnvironment
from dqn_agent import DQNAgent from dqn_agent import DQNAgent
def set_random_seed(seed: int):
"""Set random seed for reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def train(config_path: str = "config.yaml"): def train(config_path: str = "config.yaml"):
"""Train DQN agent.""" """Train DQN agent."""
config = load_config(config_path) config = load_config(config_path)
create_directories(config) create_directories(config)
# Set random seed for reproducibility
random_seed = config["training"].get("random_seed", 42)
set_random_seed(random_seed)
print(f"Random seed set to: {random_seed}")
env = TrafficEnvironment(config) env = TrafficEnvironment(config)
agent = DQNAgent( agent = DQNAgent(
state_dim=env.state_dim, state_dim=env.state_dim,
@ -40,6 +59,10 @@ def train(config_path: str = "config.yaml"):
episode_losses = [] episode_losses = []
episode_throughputs = [] episode_throughputs = []
# Track best model
best_reward = float('-inf')
best_model_path = os.path.join(checkpoint_dir, "model_best.pt")
print(f"Starting training for {num_episodes} episodes...") print(f"Starting training for {num_episodes} episodes...")
print(f"State dim: {env.state_dim}, Action dim: {env.action_dim}") print(f"State dim: {env.state_dim}, Action dim: {env.action_dim}")
print(f"Device: {agent.device}") print(f"Device: {agent.device}")
@ -74,6 +97,12 @@ def train(config_path: str = "config.yaml"):
avg_throughput = np.mean([m["throughput"] for m in env.episode_metrics]) avg_throughput = np.mean([m["throughput"] for m in env.episode_metrics])
episode_throughputs.append(avg_throughput) episode_throughputs.append(avg_throughput)
# Save best model based on episode reward
if episode_reward > best_reward:
best_reward = episode_reward
agent.save(best_model_path)
print(f"\n*** New best model saved! Episode {episode + 1}, Reward: {episode_reward:.2f} ***")
if (episode + 1) % log_freq == 0: if (episode + 1) % log_freq == 0:
avg_reward = np.mean(episode_rewards[-log_freq:]) avg_reward = np.mean(episode_rewards[-log_freq:])
avg_loss = np.mean(episode_losses[-log_freq:]) avg_loss = np.mean(episode_losses[-log_freq:])
@ -95,7 +124,9 @@ def train(config_path: str = "config.yaml"):
final_model_path = os.path.join(checkpoint_dir, "model_final.pt") final_model_path = os.path.join(checkpoint_dir, "model_final.pt")
agent.save(final_model_path) agent.save(final_model_path)
print(f"\nTraining completed! Final model saved to {final_model_path}") print(f"\nTraining completed!")
print(f"Final model saved to {final_model_path}")
print(f"Best model saved to {best_model_path} (Best Reward: {best_reward:.2f})")
plot_training_results( plot_training_results(
episode_rewards, episode_losses, episode_throughputs, config["training"]["log_dir"] episode_rewards, episode_losses, episode_throughputs, config["training"]["log_dir"]