ctm-dqn/README.md

306 lines
8.3 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# CTM-DQN: 基于深度强化学习的动态限速控制系统
基于深度Q网络(DQN)和元胞传输模型(CTM)的高速公路动态限速控制系统。
> 声明:本 Readme 由AI辅助生成进行了人工校对并调整了小部分内容
## 项目结构
```
ctm/
├── config.yaml # 配置文件
├── main.py # 主入口
├── ctm_model.py # 元胞传输模型实现
├── dqn_agent.py # DQN智能体与经验回放
├── environment.py # 训练环境
├── demand_loader.py # CSV流量数据加载器
├── train.py # 训练脚本
├── test.py # 测试/评估脚本
├── utils.py # 工具函数
├── latest/ # 最新训练运行目录(自动创建)
│ ├── checkpoints/ # 模型检查点
│ ├── logs/ # 训练日志和图表
│ └── config.yaml # 训练时使用的配置
└── runs/ # 历史训练运行归档(自动创建)
```
## 主要特性
- **CTM交通模型**: 真实的高速公路交通流仿真
- **DQN智能体**: 基于深度强化学习的限速控制
- **灵活配置**: 通过YAML配置文件轻松调整参数
- **训练与测试**: 独立的训练和评估模式
- **Baseline对比**: 自动对比无控制策略的性能
- **可视化**: 自动生成训练结果和交通模式图表
- **检查点保存**: 训练过程中定期保存模型
- **CSV流量输入**: 支持从CSV文件读取真实交通流量数据
- **随机种子固定**: 确保训练结果可重复
- **最佳模型保存**: 自动保存效果最好的模型
- **运行管理**: 自动管理训练运行最新运行保存在latest/历史运行归档到runs/
## 安装
1. 使用 uv 安装依赖:
```bash
uv sync
```
或手动安装:
```bash
pip install torch numpy matplotlib pyyaml tqdm pandas
```
## 快速开始
### 训练
使用默认配置训练DQN智能体
```bash
python main.py --mode train
```
使用自定义配置训练:
```bash
python main.py --mode train --config custom_config.yaml
```
### 测试
测试训练好的模型:
```bash
python main.py --mode test
```
使用特定检查点测试:
```bash
python main.py --mode test --model checkpoints/model_best.pt
```
## 配置说明
所有参数都可以在 `config.yaml` 中调整。主要配置部分:
### 环境参数
- `num_cells`: 道路单元数量默认10
- `cell_length`: 每个单元长度单位米默认500.0
- `free_flow_speed`: 自由流速度单位m/s默认30.0
- `demand_pattern`: 交通需求模式 - "constant"、"sine"、"random"、"csv"
- `demand_csv_path`: CSV流量文件路径当使用csv模式时
- `num_speed_actions`: 离散限速动作数量默认5
- `episode_length`: 每个episode的时间步数默认360
### DQN智能体参数
- `hidden_layers`: 神经网络架构(默认:[256, 256]
- `learning_rate`: 学习率默认0.0001
- `gamma`: 折扣因子默认0.99
- `epsilon_start/end/decay`: 探索参数
- `buffer_size`: 经验回放缓冲区容量默认100000
- `batch_size`: 训练批量大小默认128
- `target_update_freq`: 目标网络更新频率默认10
### 训练参数
- `num_episodes`: 训练episode数量默认500
- `save_freq`: 模型检查点保存频率默认50
- `log_freq`: 日志记录频率默认10
- `random_seed`: 随机种子默认42
- `train_freq`: 训练频率每N步训练一次默认1
### 奖励函数权重
- `throughput_weight`: 通行量奖励权重默认1.0
- `speed_weight`: 平均速度奖励权重默认0.5
- `density_weight`: 密度惩罚权重(默认:-0.3
- `action_change_weight`: 动作变化惩罚权重(默认:-0.1
---
## 高级功能
### 1. CSV流量输入
系统支持从CSV文件读取真实交通流量数据。
#### CSV文件格式
**格式1单列格式推荐**
```csv
demand
1500
1600
1700
1800
```
**格式2带时间列**
```csv
time,demand
0,1500
10,1600
20,1700
```
**注意事项:**
- 流量单位:车辆/小时 (vehicles/hour)
- 每行代表一个时间步的流量需求
- 流量值必须为非负数
- 如果episode长度超过CSV数据长度数据会循环使用
#### 配置方法
`config.yaml` 中设置:
```yaml
environment:
demand_pattern: "csv"
demand_csv_path: "demand_example.csv"
demand_csv_column: "demand"
```
#### 使用示例
使用提供的示例文件:
```yaml
environment:
demand_pattern: "csv"
demand_csv_path: "demand_example.csv"
```
使用自定义CSV文件
```yaml
environment:
demand_pattern: "csv"
demand_csv_path: "data/my_traffic_data.csv"
demand_csv_column: "flow"
```
切换回内置流量模式:
```yaml
environment:
demand_pattern: "sine" # 或 "constant"、"random"
```
---
## 模型架构
### DQN智能体
- **状态**: 所有单元的交通密度和限速值的拼接
- **动作**: 离散的限速值(在最小和最大限速之间均匀分布)
- **网络**: 全连接层 + ReLU激活函数
- **训练**: 经验回放 + 目标网络,实现稳定学习
### CTM模型
元胞传输模型基于以下原理模拟交通流:
- **发送流**: 受密度和限速限制
- **接收流**: 受下游容量限制
- **守恒性**: 车辆在单元边界守恒
- **基本图**: 密度、流量和速度之间的关系
---
## 输出文件
训练和测试后会生成以下文件:
- `checkpoints/model_episode_*.pt`: 训练过程中保存的模型检查点
- `checkpoints/model_best.pt`: 效果最好的模型基于episode奖励
- `checkpoints/model_final.pt`: 最终训练的模型
- `logs/training_results.png`: 训练曲线(奖励、损失、通行量)
- `logs/test_results.png`: 测试可视化(密度热图和限速控制)
---
## 使用示例工作流
1. **调整配置**: 根据你的场景修改 `config.yaml`
2. **训练模型**: `python main.py --mode train`
3. **监控进度**: 查看控制台输出和 `logs/training_results.png`
4. **测试模型**: `python main.py --mode test`
5. **分析结果**: 查看 `logs/test_results.png`
6. **迭代优化**: 根据需要调整参数并重新训练
---
## 故障排查
### 训练相关
**问题CUDA内存不足**
```yaml
agent:
batch_size: 64 # 减小批量大小
device: "cpu" # 或切换到CPU
```
**问题:训练速度慢**
```yaml
training:
num_episodes: 200 # 减少episode数量
train_freq: 5 # 降低训练频率
environment:
episode_length: 180 # 减少episode长度
```
**问题:训练不稳定**
```yaml
agent:
learning_rate: 0.00005 # 降低学习率
target_update_freq: 20 # 增加目标网络更新频率
```
**问题:性能不佳**
```yaml
reward:
throughput_weight: 2.0 # 调整奖励权重
speed_weight: 1.0
agent:
hidden_layers: [256, 256, 128] # 增加网络容量
```
### CSV流量输入相关
**问题找不到CSV文件**
- 检查文件路径是否正确(相对路径或绝对路径)
- 确保文件存在于指定位置
**问题CSV数据格式错误**
- 确保CSV文件包含指定的列名
- 检查数据是否为非负数
- 验证CSV文件编码建议使用UTF-8
---
## 技术细节
### 运行管理
系统自动管理训练运行,保持目录整洁:
- 最新训练保存在 `latest/` 目录
- 每次新训练开始时,旧的 `latest/` 会自动移动到 `runs/run_YYYYMMDD_HHMMSS/`
- 每个运行包含:配置文件、模型检查点、训练日志和图表
- 测试时自动使用 `latest/checkpoints/model_best.pt``latest/config.yaml`
### Baseline 对比测试
测试模式会自动运行 baseline 对比实验:
- Baseline 使用固定的最大速度限制(无动态控制)
- 自动计算 DQN 相对于 baseline 的性能提升百分比
- 对比指标包括:平均奖励、平均吞吐量
- 帮助评估 DQN 控制策略的实际效果
### 随机种子固定
系统在训练开始时会固定所有随机种子Python、NumPy、PyTorch确保
- 训练结果可重复
- 便于调试和对比实验
- 默认种子值为42可在配置文件中修改
### 最佳模型保存
训练过程中会自动跟踪并保存效果最好的模型:
- 基于episode总奖励评估
- 保存为 `latest/checkpoints/model_best.pt`
- 训练结束时显示最佳奖励值
- 训练完成后自动使用最佳模型进行测试
---
## 许可证
本项目仅供学术研究使用。