Skip to content

Commit

Permalink
More squeeze_batch_dim tests; fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
cmdupuis3 committed Nov 19, 2021
1 parent fb29cba commit da42a9c
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 5 deletions.
13 changes: 8 additions & 5 deletions xbatcher/generators.py
Original file line number Diff line number Diff line change
def _maybe_stack_batch_dims(
    ds, input_dims, squeeze_batch_dim, stacked_dim_name='sample'
):
    """Collapse every non-input dimension of ``ds`` into one batch dimension.

    Parameters
    ----------
    ds
        Object exposing ``dims``, ``expand_dims``, ``stack`` and
        ``transpose`` (presumably an xarray Dataset -- confirm with callers).
    input_dims
        Container of dimension names. Only membership tests and iteration
        order are used here, so a dict keyed by dimension name works.
    squeeze_batch_dim : bool
        Only consulted when ``ds`` has no batch dimension at all: if true,
        ``ds`` is returned as-is; otherwise a new ``stacked_dim_name``
        dimension is inserted at axis 0.
    stacked_dim_name : str
        Name given to the inserted or stacked batch dimension.

    Returns
    -------
    ``ds`` itself when it has exactly one batch dimension (or none and
    ``squeeze_batch_dim`` is true); otherwise a new object whose batch
    dimension is ``stacked_dim_name``.
    """
    batch_dims = [d for d in ds.dims if d not in input_dims]

    if len(batch_dims) == 0:
        # Nothing to stack; optionally add a placeholder batch dimension.
        if squeeze_batch_dim:
            return ds
        return ds.expand_dims(stacked_dim_name, 0)

    if len(batch_dims) == 1:
        # A single batch dimension already plays the role of the sample
        # axis; it must not be expanded or renamed, whatever the flag says.
        return ds

    ds_stack = ds.stack(**{stacked_dim_name: batch_dims})
    # ensure correct order: stacked batch dimension first, then input dims
    dim_order = (stacked_dim_name,) + tuple(input_dims)
    return ds_stack.transpose(*dim_order)


class BatchGenerator:
Expand Down
40 changes: 40 additions & 0 deletions xbatcher/tests/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,46 @@ def test_batch_1d_squeeze_batch_dim(sample_ds_1d, bsize):
assert list(ds_batch['foo'].shape) == [xbsize]


@pytest.mark.parametrize('bsize', [5, 10])
def test_batch_3d_squeeze_batch_dim(sample_ds_3d, bsize):
    """Batches over (y, x) keep the same shape whichever way the flag is set.

    With ``y`` and ``x`` as input dims, every batch's ``foo`` is expected to
    have shape ``[10, bsize, xbsize]`` for both ``squeeze_batch_dim=False``
    and ``squeeze_batch_dim=True``. The leading 10 is presumably the one
    remaining dimension of the ``sample_ds_3d`` fixture -- confirm there.
    """
    xbsize = 20
    expected_shape = [10, bsize, xbsize]

    # Same order as before: the non-squeezing generator is built and
    # exhausted first, then the squeezing one.
    for squeeze in (False, True):
        batches = BatchGenerator(
            sample_ds_3d,
            input_dims={'y': bsize, 'x': xbsize},
            squeeze_batch_dim=squeeze,
        )
        for ds_batch in batches:
            assert list(ds_batch['foo'].shape) == expected_shape


@pytest.mark.parametrize('bsize', [5, 10])
def test_batch_3d_squeeze_batch_dim2(sample_ds_3d, bsize):
    """Batches over x alone have shape [500, xbsize] for either flag value.

    Only ``x`` is an input dim here, so every batch's ``foo`` is expected to
    have shape ``[500, xbsize]`` for both ``squeeze_batch_dim=False`` and
    ``squeeze_batch_dim=True``. The 500 is presumably the product of the
    fixture's two remaining dimensions -- confirm against ``sample_ds_3d``.
    ``bsize`` is parametrized but not used by the assertions.
    """
    xbsize = 20
    expected_shape = [500, xbsize]

    # Same order as before: the non-squeezing generator is built and
    # exhausted first, then the squeezing one.
    for squeeze in (False, True):
        batches = BatchGenerator(
            sample_ds_3d,
            input_dims={'x': xbsize},
            squeeze_batch_dim=squeeze,
        )
        for ds_batch in batches:
            assert list(ds_batch['foo'].shape) == expected_shape


def test_preload_batch_false(sample_ds_1d):
sample_ds_1d_dask = sample_ds_1d.chunk({'x': 2})
bg = BatchGenerator(
Expand Down

0 comments on commit da42a9c

Please sign in to comment.