Merge 0e57379372 into d92490b808
This commit is contained in:
commit
6d336630be
|
|
@ -186,7 +186,7 @@ def load_dataset(dataset_dir, batch_size, test_batch_size=None, **kwargs):
|
||||||
for category in ['train', 'val', 'test']:
|
for category in ['train', 'val', 'test']:
|
||||||
data['x_' + category][..., 0] = scaler.transform(data['x_' + category][..., 0])
|
data['x_' + category][..., 0] = scaler.transform(data['x_' + category][..., 0])
|
||||||
data['y_' + category][..., 0] = scaler.transform(data['y_' + category][..., 0])
|
data['y_' + category][..., 0] = scaler.transform(data['y_' + category][..., 0])
|
||||||
data['train_loader'] = DataLoader(data['x_train'], data['y_train'], batch_size, shuffle=True)
|
data['train_loader'] = DataLoader(data['x_train'], data['y_train'], batch_size, shuffle=False)
|
||||||
data['val_loader'] = DataLoader(data['x_val'], data['y_val'], test_batch_size, shuffle=False)
|
data['val_loader'] = DataLoader(data['x_val'], data['y_val'], test_batch_size, shuffle=False)
|
||||||
data['test_loader'] = DataLoader(data['x_test'], data['y_test'], test_batch_size, shuffle=False)
|
data['test_loader'] = DataLoader(data['x_test'], data['y_test'], test_batch_size, shuffle=False)
|
||||||
data['scaler'] = scaler
|
data['scaler'] = scaler
|
||||||
|
|
|
||||||
|
|
@ -174,12 +174,23 @@ class DCRNNSupervisor:
|
||||||
|
|
||||||
self.dcrnn_model = self.dcrnn_model.train()
|
self.dcrnn_model = self.dcrnn_model.train()
|
||||||
|
|
||||||
train_iterator = self._data['train_loader'].get_iterator()
|
|
||||||
losses = []
|
|
||||||
|
|
||||||
|
# shuffle the batches
|
||||||
|
train_iterator = self._data['train_loader'].get_iterator()
|
||||||
|
all_train = np.array([(x,y) for _, (x, y) in enumerate(train_iterator)])
|
||||||
|
permutation = np.random.permutation(all_train.shape[0])
|
||||||
|
all_train = all_train[permutation]
|
||||||
|
|
||||||
|
|
||||||
|
losses = []
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
for _, (x, y) in enumerate(train_iterator):
|
for batch in all_train:
|
||||||
|
x = batch[0]
|
||||||
|
y = batch[1]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
x, y = self._prepare_data(x, y)
|
x, y = self._prepare_data(x, y)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue