diff --git a/model/pytorch/dcrnn_supervisor.py b/model/pytorch/dcrnn_supervisor.py index 4ce2369..5f78ab3 100644 --- a/model/pytorch/dcrnn_supervisor.py +++ b/model/pytorch/dcrnn_supervisor.py @@ -121,7 +121,7 @@ class DCRNNSupervisor: def _train(self, base_lr, steps, patience=50, epochs=100, lr_decay_ratio=0.1, log_every=1, save_model=1, - test_every_n_epochs=10): + test_every_n_epochs=10, **kwargs): # steps is used in learning rate - will see if need to use it? min_val_loss = float('inf') wait = 0