Skip to content

Commit

Permalink
Fix a TypeError that occurred when using predefined_split
Browse files Browse the repository at this point in the history
This bug was introduced by PR 646 after a change in the function
signature of utils._make_split. Here the bug is fixed by returning to
the old signature but giving default values for valid_ds and y, so
that the desired behavior introduced by 646 is still possible.

A test is introduced to catch this bug.

As a cleanup, the unrelated test_predefined_split now uses the
already defined dataset_cls fixture instead of importing Dataset again.
  • Loading branch information
BenjaminBossan authored and BenjaminBossan committed Aug 26, 2020
1 parent 807974c commit ae2b4ac
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 5 deletions.
27 changes: 23 additions & 4 deletions skorch/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -2281,15 +2281,18 @@ def get_step(self):
flatten(net.history[:, 'batches', :, 'train_loss']))
assert np.allclose(side_effect[2::3], expected_losses)

def test_predefined_split(self, net_cls, module_cls, data):
from skorch.dataset import Dataset
@pytest.fixture
def predefined_split(self):
from skorch.helper import predefined_split
return predefined_split

def test_predefined_split(
self, net_cls, module_cls, data, predefined_split, dataset_cls):
train_loader_mock = Mock(side_effect=torch.utils.data.DataLoader)
valid_loader_mock = Mock(side_effect=torch.utils.data.DataLoader)

train_ds = Dataset(*data)
valid_ds = Dataset(*data)
train_ds = dataset_cls(*data)
valid_ds = dataset_cls(*data)

net = net_cls(
module_cls, max_epochs=1,
Expand All @@ -2306,6 +2309,22 @@ def test_predefined_split(self, net_cls, module_cls, data):
assert train_loader_ds == train_ds
assert valid_loader_ds == valid_ds

def test_predefined_split_with_y(
self, net_cls, module_cls, data, predefined_split, dataset_cls):
# A change in the signature of utils._make_split in #646 led
# to a bug reported in #681, namely `TypeError: _make_split()
# got multiple values for argument 'valid_ds'`. This is a test
# for the bug.
X, y = data
X_train, y_train, X_valid, y_valid = X[:800], y[:800], X[800:], y[800:]
valid_ds = dataset_cls(X_valid, y_valid)
net = net_cls(
module_cls,
max_epochs=1,
train_split=predefined_split(valid_ds),
)
net.fit(X_train, y_train)

def test_set_lr_at_runtime_doesnt_reinitialize(self, net_fit):
with patch('skorch.NeuralNet.initialize_optimizer') as f:
net_fit.set_params(lr=0.9)
Expand Down
2 changes: 1 addition & 1 deletion skorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def get_step(self):
return self.step


def _make_split(X, valid_ds, **kwargs):
def _make_split(X, y=None, valid_ds=None, **kwargs):
"""Used by ``predefined_split`` to allow for pickling"""
return X, valid_ds

Expand Down

0 comments on commit ae2b4ac

Please sign in to comment.