添加随机种子,自动保存最优模型
This commit is contained in:
parent
fd0dc80a91
commit
39b0134609
|
|
@ -54,6 +54,7 @@ training:
|
|||
log_freq: 10 # Logging frequency (episodes)
|
||||
checkpoint_dir: "checkpoints" # Checkpoint directory
|
||||
log_dir: "logs" # Log directory
|
||||
random_seed: 42 # Random seed for reproducibility
|
||||
|
||||
# Testing Parameters
|
||||
testing:
|
||||
|
|
|
|||
33
train.py
33
train.py
|
|
@ -2,7 +2,9 @@
|
|||
Training script for DQN-based speed limit control.
|
||||
"""
|
||||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import matplotlib.pyplot as plt
|
||||
from utils import load_config, create_directories
|
||||
|
|
@ -10,11 +12,28 @@ from environment import TrafficEnvironment
|
|||
from dqn_agent import DQNAgent
|
||||
|
||||
|
||||
def set_random_seed(seed: int) -> None:
    """Seed every random number generator used during training.

    Seeds Python's ``random`` module, NumPy's global generator, and
    PyTorch. When CUDA is available, all CUDA devices are seeded too and
    cuDNN is switched to deterministic mode so that repeated runs with the
    same seed pick the same convolution algorithms.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        # manual_seed_all covers every visible GPU, including the current
        # one, so a separate torch.cuda.manual_seed call is not needed.
        torch.cuda.manual_seed_all(seed)
        # Trade some speed for reproducibility: disable cuDNN autotuning
        # and force deterministic kernels.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
def train(config_path: str = "config.yaml"):
|
||||
"""Train DQN agent."""
|
||||
config = load_config(config_path)
|
||||
create_directories(config)
|
||||
|
||||
# Set random seed for reproducibility
|
||||
random_seed = config["training"].get("random_seed", 42)
|
||||
set_random_seed(random_seed)
|
||||
print(f"Random seed set to: {random_seed}")
|
||||
|
||||
env = TrafficEnvironment(config)
|
||||
agent = DQNAgent(
|
||||
state_dim=env.state_dim,
|
||||
|
|
@ -40,6 +59,10 @@ def train(config_path: str = "config.yaml"):
|
|||
episode_losses = []
|
||||
episode_throughputs = []
|
||||
|
||||
# Track best model
|
||||
best_reward = float('-inf')
|
||||
best_model_path = os.path.join(checkpoint_dir, "model_best.pt")
|
||||
|
||||
print(f"Starting training for {num_episodes} episodes...")
|
||||
print(f"State dim: {env.state_dim}, Action dim: {env.action_dim}")
|
||||
print(f"Device: {agent.device}")
|
||||
|
|
@ -74,6 +97,12 @@ def train(config_path: str = "config.yaml"):
|
|||
avg_throughput = np.mean([m["throughput"] for m in env.episode_metrics])
|
||||
episode_throughputs.append(avg_throughput)
|
||||
|
||||
# Save best model based on episode reward
|
||||
if episode_reward > best_reward:
|
||||
best_reward = episode_reward
|
||||
agent.save(best_model_path)
|
||||
print(f"\n*** New best model saved! Episode {episode + 1}, Reward: {episode_reward:.2f} ***")
|
||||
|
||||
if (episode + 1) % log_freq == 0:
|
||||
avg_reward = np.mean(episode_rewards[-log_freq:])
|
||||
avg_loss = np.mean(episode_losses[-log_freq:])
|
||||
|
|
@ -95,7 +124,9 @@ def train(config_path: str = "config.yaml"):
|
|||
|
||||
final_model_path = os.path.join(checkpoint_dir, "model_final.pt")
|
||||
agent.save(final_model_path)
|
||||
print(f"\nTraining completed! Final model saved to {final_model_path}")
|
||||
print(f"\nTraining completed!")
|
||||
print(f"Final model saved to {final_model_path}")
|
||||
print(f"Best model saved to {best_model_path} (Best Reward: {best_reward:.2f})")
|
||||
|
||||
plot_training_results(
|
||||
episode_rewards, episode_losses, episode_throughputs, config["training"]["log_dir"]
|
||||
|
|
|
|||
Loading…
Reference in New Issue