Zihan Ye 39b0134609 Add random seed, automatically save the best model 2026-01-05 17:04:27 +08:00
.gitignore First runnable 2026-01-05 16:15:53 +08:00
.python-version First runnable 2026-01-05 16:15:53 +08:00
README.md First runnable 2026-01-05 16:15:53 +08:00
config.yaml Add random seed, automatically save the best model 2026-01-05 17:04:27 +08:00
ctm_model.py First runnable 2026-01-05 16:15:53 +08:00
dqn_agent.py First runnable 2026-01-05 16:15:53 +08:00
environment.py First runnable 2026-01-05 16:15:53 +08:00
main.py First runnable 2026-01-05 16:15:53 +08:00
pyproject.toml First runnable 2026-01-05 16:15:53 +08:00
test.py First runnable 2026-01-05 16:15:53 +08:00
train.py Add random seed, automatically save the best model 2026-01-05 17:04:27 +08:00
utils.py First runnable 2026-01-05 16:15:53 +08:00
uv.lock First runnable 2026-01-05 16:15:53 +08:00

README.md

CTM-DQN: Dynamic Speed Limit Control System

A dynamic speed limit control system in which a Deep Q-Network (DQN) agent sets speed limits on a highway simulated with the Cell Transmission Model (CTM).

Project Structure

ctm/
├── config.yaml          # Configuration file for all parameters
├── main.py              # Main entry point
├── ctm_model.py         # Cell Transmission Model implementation
├── dqn_agent.py         # DQN agent with replay buffer
├── environment.py       # Training environment
├── train.py             # Training script
├── test.py              # Testing/evaluation script
├── utils.py             # Utility functions
├── checkpoints/         # Saved model checkpoints (created automatically)
└── logs/                # Training logs and plots (created automatically)

Features

  • CTM Traffic Model: Macroscopic, cell-based simulation of highway traffic flow
  • DQN Agent: Deep reinforcement learning for speed limit control
  • Flexible Configuration: Easy parameter adjustment via YAML config
  • Training & Testing: Separate modes for model training and evaluation
  • Visualization: Automatic plotting of training results and traffic patterns
  • Checkpointing: Regular model saving during training

Installation

Install dependencies using uv:

uv sync

Or install them manually with pip:

pip install torch numpy matplotlib pyyaml tqdm

Quick Start

Training

Train the DQN agent with default configuration:

python main.py --mode train

Train with custom configuration:

python main.py --mode train --config custom_config.yaml

Testing

Test the trained model:

python main.py --mode test

Test with specific model checkpoint:

python main.py --mode test --model checkpoints/model_episode_500.pt
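
The commands above use three flags: --mode, --config, and --model. main.py itself is not shown here, so the following is only a minimal sketch of how such a command line could be parsed with argparse. The default values and help texts are assumptions, not the project's actual code.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags used in the commands above."""
    parser = argparse.ArgumentParser(description="CTM-DQN speed limit control")
    parser.add_argument("--mode", choices=["train", "test"], required=True,
                        help="run training or evaluation")
    # Assumed default: fall back to config.yaml when --config is omitted.
    parser.add_argument("--config", default="config.yaml",
                        help="path to the YAML configuration file")
    # Assumed optional: only meaningful in test mode.
    parser.add_argument("--model", default=None,
                        help="checkpoint to load in test mode")
    return parser


# Parse an explicit argument list so the sketch runs without a real command line.
args = build_parser().parse_args(
    ["--mode", "test", "--model", "checkpoints/model_episode_500.pt"]
)
print(args.mode, args.config, args.model)
```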

Configuration

All parameters can be adjusted in config.yaml. Key configuration sections:

Environment Parameters

  • num_cells: Number of road cells (default: 10)
  • cell_length: Length of each cell in meters (default: 500.0)
  • free_flow_speed: Free flow speed in m/s (default: 30.0)
  • demand_pattern: Traffic demand pattern - "constant", "sine", or "random"
  • num_speed_actions: Number of discrete speed limit actions (default: 5)
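
To illustrate num_speed_actions, the sketch below maps a discrete action index to a speed limit, assuming the limits are spaced evenly between a minimum and a maximum. The function name and the min_speed value are hypothetical, and max_speed is assumed to equal the default free_flow_speed.

```python
def speed_limit_for_action(action: int, num_speed_actions: int = 5,
                           min_speed: float = 10.0,
                           max_speed: float = 30.0) -> float:
    """Map an action index to a speed limit in m/s (evenly spaced, assumed)."""
    if not 0 <= action < num_speed_actions:
        raise ValueError(f"action must be in [0, {num_speed_actions - 1}]")
    if num_speed_actions == 1:
        return max_speed
    # Spacing between neighbouring speed limits.
    step = (max_speed - min_speed) / (num_speed_actions - 1)
    return min_speed + action * step


limits = [speed_limit_for_action(a) for a in range(5)]
print(limits)  # [10.0, 15.0, 20.0, 25.0, 30.0]
```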

DQN Agent Parameters

  • hidden_layers: Neural network architecture (default: [128, 128])
  • learning_rate: Learning rate (default: 0.0001)
  • gamma: Discount factor (default: 0.99)
  • epsilon_start, epsilon_end, epsilon_decay: Epsilon-greedy exploration schedule (initial value, final value, and decay rate)
  • buffer_size: Replay buffer capacity (default: 50000)
  • batch_size: Training batch size (default: 64)
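
The exploration parameters can be pictured with a small epsilon-greedy sketch. It assumes a multiplicative per-episode decay and example values of 1.0, 0.01, and 0.995; the README does not state the actual decay rule or defaults, so treat both as placeholders.

```python
import random


def epsilon_schedule(epsilon_start=1.0, epsilon_end=0.01,
                     epsilon_decay=0.995, num_episodes=500):
    """Exploration rate per episode, assuming multiplicative decay with a floor."""
    epsilon = epsilon_start
    schedule = []
    for _ in range(num_episodes):
        schedule.append(epsilon)
        epsilon = max(epsilon_end, epsilon * epsilon_decay)
    return schedule


def select_action(q_values, epsilon):
    """Epsilon-greedy: random action with probability epsilon, else the greedy one."""
    if random.random() < epsilon:
        return random.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])


schedule = epsilon_schedule()
print(f"first episode: {schedule[0]:.3f}, last episode: {schedule[-1]:.3f}")
```

With epsilon at 0 the agent always picks the action with the highest Q-value; with epsilon at 1 every action is random.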

Training Parameters

  • num_episodes: Number of training episodes (default: 500)
  • save_freq: Checkpoint interval in episodes (default: 50)
  • log_freq: Logging frequency (default: 10)

Reward Function Weights

  • throughput_weight: Weight for throughput reward (default: 1.0)
  • speed_weight: Weight for average speed reward (default: 0.5)
  • density_weight: Weight for density penalty (default: -0.3)
  • action_change_weight: Weight for action change penalty (default: -0.1)
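
One plausible way these weights combine is a weighted sum of per-step metrics. The sketch below assumes each metric is already normalized to roughly [0, 1] and that the action change term is a 0/1 indicator; the actual reward computation in environment.py may differ.

```python
DEFAULT_WEIGHTS = {
    "throughput_weight": 1.0,
    "speed_weight": 0.5,
    "density_weight": -0.3,
    "action_change_weight": -0.1,
}


def compute_reward(throughput, avg_speed, avg_density, action_changed,
                   weights=DEFAULT_WEIGHTS):
    """Weighted sum of normalized metrics (assumed form of the reward)."""
    return (weights["throughput_weight"] * throughput
            + weights["speed_weight"] * avg_speed
            + weights["density_weight"] * avg_density
            # Negative weight: changing the speed limit costs a small penalty.
            + weights["action_change_weight"] * float(action_changed))


reward = compute_reward(throughput=0.8, avg_speed=0.6,
                        avg_density=0.5, action_changed=True)
print(f"{reward:.2f}")  # 0.85
```

Because the density and action change weights are negative, they act as penalties without any sign flip in the formula.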

Customization Guide

Changing Traffic Scenarios

Edit the environment parameters in config.yaml:

environment:
  demand_pattern: "sine"  # Change to "constant" or "random"
  demand_mean: 2000.0     # Adjust traffic demand
  num_cells: 15           # Increase road length
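
For intuition about the three demand_pattern options, here is a rough sketch of how the inflow demand could vary over time. Only demand_mean comes from the configuration above; the function name, amplitude, period, and the Gaussian noise model are hypothetical.

```python
import math
import random


def demand_at(step, pattern="sine", demand_mean=2000.0,
              amplitude=500.0, period=360):
    """Demand at a given time step for each pattern (illustrative only)."""
    if pattern == "constant":
        return demand_mean
    if pattern == "sine":
        # One full oscillation around the mean every `period` steps.
        return demand_mean + amplitude * math.sin(2 * math.pi * step / period)
    if pattern == "random":
        # Gaussian noise around the mean, clipped so demand is never negative.
        return max(0.0, random.gauss(demand_mean, amplitude))
    raise ValueError(f"unknown demand pattern: {pattern}")


for name in ("constant", "sine", "random"):
    print(name, round(demand_at(90, name), 1))
```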

Modifying DQN Architecture

Adjust the neural network structure:

agent:
  hidden_layers: [256, 256, 128]  # Deeper network
  learning_rate: 0.0005            # Larger step size than the 0.0001 default
  gamma: 0.95                      # Lower discount factor: favors near-term reward

Tuning Reward Function

Balance different objectives by adjusting weights:

reward:
  throughput_weight: 2.0      # Prioritize throughput
  speed_weight: 1.0           # Increase speed importance
  density_weight: -0.5        # Stronger density penalty
  action_change_weight: -0.2  # Discourage frequent changes

Output Files

Training and testing generate the following files:

  • checkpoints/model_episode_*.pt: Model checkpoints saved during training
  • checkpoints/model_final.pt: Final trained model
  • logs/training_results.png: Training curves (rewards, loss, throughput)
  • logs/test_results.png: Test visualization (density heatmap and speed control)

Model Architecture

The DQN agent uses:

  • State: Concatenation of traffic densities and speed limits for all cells
  • Action: Discrete speed limit values (evenly spaced between the minimum and maximum speed limit)
  • Network: Fully connected layers with ReLU activation
  • Training: Experience replay + target network for stable learning
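
The two stabilizing techniques named above, experience replay and a target network, can be sketched without any deep learning library. The class and function names here are illustrative and are not taken from dqn_agent.py; next_q_values stands in for the target network's output for the next state.

```python
import random
from collections import deque


class ReplayBuffer:
    """Fixed-capacity store of transitions; the oldest entries are dropped first."""

    def __init__(self, capacity=50000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Uniform sampling breaks the correlation between consecutive steps.
        return random.sample(list(self.buffer), batch_size)

    def __len__(self):
        return len(self.buffer)


def td_target(reward, next_q_values, done, gamma=0.99):
    """One-step TD target: r + gamma * max_a Q_target(s', a), or r if terminal."""
    if done:
        return reward
    return reward + gamma * max(next_q_values)


buffer = ReplayBuffer(capacity=3)
for t in range(5):
    buffer.push([0.1 * t], t % 2, 1.0, [0.1 * (t + 1)], t == 4)
batch = buffer.sample(2)
print(len(buffer), len(batch))  # 3 2
```

The Q-network is then trained to move its prediction for the taken action toward td_target, while the target network's weights are refreshed only periodically.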

CTM Model

The Cell Transmission Model simulates traffic flow based on:

  • Sending flow: Limited by density and speed limit
  • Receiving flow: Limited by downstream capacity
  • Conservation: Vehicles are conserved across cell boundaries
  • Fundamental diagram: Relationship between density, flow, and speed
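
The four ingredients above can be combined into a single update step. The following is a minimal sketch, not the code in ctm_model.py: it assumes densities in veh/m, flows in veh/s, a free outflow at the downstream end, and illustrative values for dt, capacity, wave_speed, and jam_density.

```python
def ctm_step(density, speed_limit, inflow, dt=10.0, cell_length=500.0,
             capacity=0.5, wave_speed=6.0, jam_density=0.15):
    """Advance cell densities by one time step of a basic CTM (assumed form)."""
    n = len(density)
    # Sending flow: what each cell can emit, limited by density and speed limit.
    sending = [min(speed_limit[i] * density[i], capacity) for i in range(n)]
    # Receiving flow: what each cell can absorb, limited by its remaining space.
    receiving = [min(capacity, wave_speed * (jam_density - density[i]))
                 for i in range(n)]
    # Flow across each of the n + 1 cell boundaries.
    flows = [min(inflow, receiving[0])]
    for i in range(n - 1):
        flows.append(min(sending[i], receiving[i + 1]))
    flows.append(sending[-1])  # assumed: unrestricted outflow downstream
    # Conservation: density changes by (inflow - outflow) * dt / cell_length.
    return [density[i] + dt / cell_length * (flows[i] - flows[i + 1])
            for i in range(n)]


density = [0.01] * 10
new_density = ctm_step(density, speed_limit=[30.0] * 10, inflow=0.3)
print([round(d, 4) for d in new_density])
```

With these example values the inflow equals each cell's sending flow, so the densities stay at their initial level. Lowering a cell's speed limit reduces its sending flow, which is the lever the DQN agent controls. For numerical stability, the time step should satisfy speed_limit * dt <= cell_length.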

Example Workflow

  1. Adjust configuration for your scenario in config.yaml
  2. Train the model: python main.py --mode train
  3. Monitor progress in console output and logs/training_results.png
  4. Test the model: python main.py --mode test
  5. Analyze results in logs/test_results.png
  6. Iterate: Adjust parameters and retrain as needed

Troubleshooting

  • CUDA out of memory: Reduce batch_size or set device: "cpu" in config.yaml
  • Slow training: Reduce num_episodes or episode_length
  • Poor performance: Adjust reward weights or increase network capacity
  • Unstable training: Reduce learning_rate or increase target_update_freq