添加随机种子,自动保存最优模型
This commit is contained in:
parent
fd0dc80a91
commit
39b0134609
|
|
@ -54,6 +54,7 @@ training:
|
||||||
log_freq: 10 # Logging frequency (episodes)
|
log_freq: 10 # Logging frequency (episodes)
|
||||||
checkpoint_dir: "checkpoints" # Checkpoint directory
|
checkpoint_dir: "checkpoints" # Checkpoint directory
|
||||||
log_dir: "logs" # Log directory
|
log_dir: "logs" # Log directory
|
||||||
|
random_seed: 42 # Random seed for reproducibility
|
||||||
|
|
||||||
# Testing Parameters
|
# Testing Parameters
|
||||||
testing:
|
testing:
|
||||||
|
|
|
||||||
33
train.py
33
train.py
|
|
@ -2,7 +2,9 @@
|
||||||
Training script for DQN-based speed limit control.
|
Training script for DQN-based speed limit control.
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from utils import load_config, create_directories
|
from utils import load_config, create_directories
|
||||||
|
|
@ -10,11 +12,28 @@ from environment import TrafficEnvironment
|
||||||
from dqn_agent import DQNAgent
|
from dqn_agent import DQNAgent
|
||||||
|
|
||||||
|
|
||||||
|
def set_random_seed(seed: int):
|
||||||
|
"""Set random seed for reproducibility."""
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
|
||||||
|
|
||||||
def train(config_path: str = "config.yaml"):
|
def train(config_path: str = "config.yaml"):
|
||||||
"""Train DQN agent."""
|
"""Train DQN agent."""
|
||||||
config = load_config(config_path)
|
config = load_config(config_path)
|
||||||
create_directories(config)
|
create_directories(config)
|
||||||
|
|
||||||
|
# Set random seed for reproducibility
|
||||||
|
random_seed = config["training"].get("random_seed", 42)
|
||||||
|
set_random_seed(random_seed)
|
||||||
|
print(f"Random seed set to: {random_seed}")
|
||||||
|
|
||||||
env = TrafficEnvironment(config)
|
env = TrafficEnvironment(config)
|
||||||
agent = DQNAgent(
|
agent = DQNAgent(
|
||||||
state_dim=env.state_dim,
|
state_dim=env.state_dim,
|
||||||
|
|
@ -40,6 +59,10 @@ def train(config_path: str = "config.yaml"):
|
||||||
episode_losses = []
|
episode_losses = []
|
||||||
episode_throughputs = []
|
episode_throughputs = []
|
||||||
|
|
||||||
|
# Track best model
|
||||||
|
best_reward = float('-inf')
|
||||||
|
best_model_path = os.path.join(checkpoint_dir, "model_best.pt")
|
||||||
|
|
||||||
print(f"Starting training for {num_episodes} episodes...")
|
print(f"Starting training for {num_episodes} episodes...")
|
||||||
print(f"State dim: {env.state_dim}, Action dim: {env.action_dim}")
|
print(f"State dim: {env.state_dim}, Action dim: {env.action_dim}")
|
||||||
print(f"Device: {agent.device}")
|
print(f"Device: {agent.device}")
|
||||||
|
|
@ -74,6 +97,12 @@ def train(config_path: str = "config.yaml"):
|
||||||
avg_throughput = np.mean([m["throughput"] for m in env.episode_metrics])
|
avg_throughput = np.mean([m["throughput"] for m in env.episode_metrics])
|
||||||
episode_throughputs.append(avg_throughput)
|
episode_throughputs.append(avg_throughput)
|
||||||
|
|
||||||
|
# Save best model based on episode reward
|
||||||
|
if episode_reward > best_reward:
|
||||||
|
best_reward = episode_reward
|
||||||
|
agent.save(best_model_path)
|
||||||
|
print(f"\n*** New best model saved! Episode {episode + 1}, Reward: {episode_reward:.2f} ***")
|
||||||
|
|
||||||
if (episode + 1) % log_freq == 0:
|
if (episode + 1) % log_freq == 0:
|
||||||
avg_reward = np.mean(episode_rewards[-log_freq:])
|
avg_reward = np.mean(episode_rewards[-log_freq:])
|
||||||
avg_loss = np.mean(episode_losses[-log_freq:])
|
avg_loss = np.mean(episode_losses[-log_freq:])
|
||||||
|
|
@ -95,7 +124,9 @@ def train(config_path: str = "config.yaml"):
|
||||||
|
|
||||||
final_model_path = os.path.join(checkpoint_dir, "model_final.pt")
|
final_model_path = os.path.join(checkpoint_dir, "model_final.pt")
|
||||||
agent.save(final_model_path)
|
agent.save(final_model_path)
|
||||||
print(f"\nTraining completed! Final model saved to {final_model_path}")
|
print(f"\nTraining completed!")
|
||||||
|
print(f"Final model saved to {final_model_path}")
|
||||||
|
print(f"Best model saved to {best_model_path} (Best Reward: {best_reward:.2f})")
|
||||||
|
|
||||||
plot_training_results(
|
plot_training_results(
|
||||||
episode_rewards, episode_losses, episode_throughputs, config["training"]["log_dir"]
|
episode_rewards, episode_losses, episode_throughputs, config["training"]["log_dir"]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue