shuffle the order of the batches, but not the entries within a batch
This commit is contained in:
parent
8440002db2
commit
0e57379372
|
|
@ -174,12 +174,23 @@ class DCRNNSupervisor:
|
|||
|
||||
self.dcrnn_model = self.dcrnn_model.train()
|
||||
|
||||
|
||||
# shuffle the batches
|
||||
train_iterator = self._data['train_loader'].get_iterator()
|
||||
all_train = np.array([(x,y) for _, (x, y) in enumerate(train_iterator)])
|
||||
permutation = np.random.permutation(all_train.shape[0])
|
||||
all_train = all_train[permutation]
|
||||
|
||||
|
||||
losses = []
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
for _, (x, y) in enumerate(train_iterator):
|
||||
for batch in all_train:
|
||||
x = batch[0]
|
||||
y = batch[1]
|
||||
|
||||
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
x, y = self._prepare_data(x, y)
|
||||
|
|
|
|||
Loading…
Reference in New Issue