Solved the Original Repo Errors
This commit is contained in:
parent
d92490b808
commit
5537ea44ba
|
|
@ -0,0 +1,17 @@
|
||||||
|
Quick Run:(All based on METR-LA)
|
||||||
|
1. pip install -r requirements.txt
|
||||||
|
2. Download the data from https://drive.google.com/open?id=10FOTa6HXPqX8Pf5WRoRwcFnW9BrNZEIX
|
||||||
|
3. Put the data into data/ for making the training data
|
||||||
|
4. Create data directories
|
||||||
|
>> mkdir -p data/{METR-LA,PEMS-BAY}
|
||||||
|
5. generate train/test/val dataset at data/{METR-LA,PEMS-BAY}/{train,val,test}.npz
|
||||||
|
>> python -m scripts.generate_training_data --output_dir=data/METR-LA --traffic_df_filename=data/metr-la.h5
|
||||||
|
6. Constructing the Graph
|
||||||
|
>> python -m scripts.gen_adj_mx --sensor_ids_filename=data/sensor_graph/graph_sensor_ids.txt --normalized_k=0.1\
|
||||||
|
--output_pkl_filename=data/sensor_graph/adj_mx.pkl
|
||||||
|
7. Run the pre-trained model:
|
||||||
|
>> python run_demo_pytorch.py --config_filename=data/model/pretrained/METR-LA/config.yaml
|
||||||
|
8. Model Training
|
||||||
|
>> python dcrnn_train_pytorch.py --config_filename=data/model/dcrnn_la.yaml
|
||||||
|
9. Evaluating the baselines:
|
||||||
|
>> python -m scripts.eval_baseline_methods --traffic_reading_filename=data/metr-la.h5
|
||||||
Binary file not shown.
|
|
@ -21,7 +21,7 @@ model:
|
||||||
train:
|
train:
|
||||||
base_lr: 0.01
|
base_lr: 0.01
|
||||||
dropout: 0
|
dropout: 0
|
||||||
epoch: 64
|
epoch: 0
|
||||||
epochs: 100
|
epochs: 100
|
||||||
epsilon: 0.001
|
epsilon: 0.001
|
||||||
global_step: 24375
|
global_step: 24375
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from model.pytorch.dcrnn_supervisor import DCRNNSupervisor
|
||||||
|
|
||||||
def run_dcrnn(args):
|
def run_dcrnn(args):
|
||||||
with open(args.config_filename) as f:
|
with open(args.config_filename) as f:
|
||||||
supervisor_config = yaml.load(f)
|
supervisor_config = yaml.load(f, Loader=yaml.Loader)
|
||||||
|
|
||||||
graph_pkl_filename = supervisor_config['data'].get('graph_pkl_filename')
|
graph_pkl_filename = supervisor_config['data'].get('graph_pkl_filename')
|
||||||
sensor_ids, sensor_id_to_ind, adj_mx = load_graph_data(graph_pkl_filename)
|
sensor_ids, sensor_id_to_ind, adj_mx = load_graph_data(graph_pkl_filename)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue