From 0e57379372753863f9b406aa6f7c5b8764ed3db0 Mon Sep 17 00:00:00 2001 From: Seth Ockerman Date: Wed, 17 Jul 2024 13:48:33 -0500 Subject: [PATCH] shuffle the order of the batches, but not the entries within a batch --- model/pytorch/dcrnn_supervisor.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/model/pytorch/dcrnn_supervisor.py b/model/pytorch/dcrnn_supervisor.py index c2a34d4..53f6a33 100644 --- a/model/pytorch/dcrnn_supervisor.py +++ b/model/pytorch/dcrnn_supervisor.py @@ -174,12 +174,23 @@ class DCRNNSupervisor: self.dcrnn_model = self.dcrnn_model.train() + + # shuffle the batches train_iterator = self._data['train_loader'].get_iterator() + all_train = np.array([(x,y) for _, (x, y) in enumerate(train_iterator)]) + permutation = np.random.permutation(all_train.shape[0]) + all_train = all_train[permutation] + + losses = [] - start_time = time.time() - for _, (x, y) in enumerate(train_iterator): + for batch in all_train: + x = batch[0] + y = batch[1] + + + optimizer.zero_grad() x, y = self._prepare_data(x, y)