diff --git a/skorch/tests/test_net.py b/skorch/tests/test_net.py index bef4d21e0..0d333e347 100644 --- a/skorch/tests/test_net.py +++ b/skorch/tests/test_net.py @@ -2281,15 +2281,18 @@ def get_step(self): flatten(net.history[:, 'batches', :, 'train_loss'])) assert np.allclose(side_effect[2::3], expected_losses) - def test_predefined_split(self, net_cls, module_cls, data): - from skorch.dataset import Dataset + @pytest.fixture + def predefined_split(self): from skorch.helper import predefined_split + return predefined_split + def test_predefined_split( + self, net_cls, module_cls, data, predefined_split, dataset_cls): train_loader_mock = Mock(side_effect=torch.utils.data.DataLoader) valid_loader_mock = Mock(side_effect=torch.utils.data.DataLoader) - train_ds = Dataset(*data) - valid_ds = Dataset(*data) + train_ds = dataset_cls(*data) + valid_ds = dataset_cls(*data) net = net_cls( module_cls, max_epochs=1, @@ -2306,6 +2309,22 @@ def test_predefined_split(self, net_cls, module_cls, data): assert train_loader_ds == train_ds assert valid_loader_ds == valid_ds + def test_predefined_split_with_y( + self, net_cls, module_cls, data, predefined_split, dataset_cls): + # A change in the signature of utils._make_split in #646 led + # to a bug reported in #681, namely `TypeError: _make_split() + # got multiple values for argument 'valid_ds'`. This is a test + # for the bug. + X, y = data + X_train, y_train, X_valid, y_valid = X[:800], y[:800], X[800:], y[800:] + valid_ds = dataset_cls(X_valid, y_valid) + net = net_cls( + module_cls, + max_epochs=1, + train_split=predefined_split(valid_ds), + ) + net.fit(X_train, y_train) + def test_set_lr_at_runtime_doesnt_reinitialize(self, net_fit): with patch('skorch.NeuralNet.initialize_optimizer') as f: net_fit.set_params(lr=0.9) diff --git a/skorch/utils.py b/skorch/utils.py index 9c19bb599..d5fed89d3 100644 --- a/skorch/utils.py +++ b/skorch/utils.py @@ -491,7 +491,7 @@ def get_step(self): return self.step -def _make_split(X, valid_ds, **kwargs): +def _make_split(X, y=None, valid_ds=None, **kwargs): """Used by ``predefined_split`` to allow for pickling""" return X, valid_ds