shuffle the order of the batches, but not the entries within a batch

This commit is contained in:
Seth Ockerman 2024-07-17 13:48:33 -05:00
parent 8440002db2
commit 0e57379372
1 changed files with 13 additions and 2 deletions

View File

@ -174,12 +174,23 @@ class DCRNNSupervisor:
self.dcrnn_model = self.dcrnn_model.train()
# shuffle the batches
train_iterator = self._data['train_loader'].get_iterator()
all_train = np.array([(x,y) for _, (x, y) in enumerate(train_iterator)])
permutation = np.random.permutation(all_train.shape[0])
all_train = all_train[permutation]
losses = []
start_time = time.time()
for _, (x, y) in enumerate(train_iterator):
for batch in all_train:
x = batch[0]
y = batch[1]
optimizer.zero_grad()
x, y = self._prepare_data(x, y)