diff --git a/How_to_Run.md b/How_to_Run.md new file mode 100644 index 0000000..caca6d9 --- /dev/null +++ b/How_to_Run.md @@ -0,0 +1,17 @@ +Quick Run:(All based on METR-LA) +1. pip install -r requirements.txt +2. Download the data from https://drive.google.com/open?id=10FOTa6HXPqX8Pf5WRoRwcFnW9BrNZEIX +3. Put the data into data/ for making the training data +4. Create data directories +>> mkdir -p data/{METR-LA,PEMS-BAY} +5. generate train/test/val dataset at data/{METR-LA,PEMS-BAY}/{train,val,test}.npz +>> python -m scripts.generate_training_data --output_dir=data/METR-LA --traffic_df_filename=data/metr-la.h5 +6. Constructing the Graph +>> python -m scripts.gen_adj_mx --sensor_ids_filename=data/sensor_graph/graph_sensor_ids.txt --normalized_k=0.1\ + --output_pkl_filename=data/sensor_graph/adj_mx.pkl +7. Run the pre-trained model: +>> python run_demo_pytorch.py --config_filename=data/model/pretrained/METR-LA/config.yaml +8. Model Training +>> python dcrnn_train_pytorch.py --config_filename=data/model/dcrnn_la.yaml +9. Evaluating the baselines: +>> python -m scripts.eval_baseline_methods --traffic_reading_filename=data/metr-la.h5 \ No newline at end of file diff --git a/data/metr-la.h5 b/data/metr-la.h5 new file mode 100644 index 0000000..7df0003 Binary files /dev/null and b/data/metr-la.h5 differ diff --git a/data/model/pretrained/METR-LA/config.yaml b/data/model/pretrained/METR-LA/config.yaml index b0227a8..157b51a 100644 --- a/data/model/pretrained/METR-LA/config.yaml +++ b/data/model/pretrained/METR-LA/config.yaml @@ -21,7 +21,7 @@ model: train: base_lr: 0.01 dropout: 0 - epoch: 64 + epoch: 0 epochs: 100 epsilon: 0.001 global_step: 24375 diff --git a/run_demo_pytorch.py b/run_demo_pytorch.py index 714facd..1a3bef6 100644 --- a/run_demo_pytorch.py +++ b/run_demo_pytorch.py @@ -10,7 +10,7 @@ from model.pytorch.dcrnn_supervisor import DCRNNSupervisor def run_dcrnn(args): with open(args.config_filename) as f: - supervisor_config = yaml.load(f) + supervisor_config = yaml.load(f, Loader=yaml.Loader) graph_pkl_filename = supervisor_config['data'].get('graph_pkl_filename') sensor_ids, sensor_id_to_ind, adj_mx = load_graph_data(graph_pkl_filename)