Skip to content

Commit

Permalink
minor updates
Browse files Browse the repository at this point in the history
  • Loading branch information
cmdupuis3 committed Nov 30, 2021
1 parent da42a9c commit a54a9b7
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
6 changes: 3 additions & 3 deletions xbatcher/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ class BatchGenerator:
If ``True``, each batch will be loaded into memory before reshaping /
processing, triggering any dask arrays to be computed.
squeeze_batch_dim : bool, optional
If ``False", each batch's dataset will have a "batch" dimension of size 1
prepended to the array. This functionality is useful for interoperability
with Keras / Tensorflow.
If ``False" and all dims are input dims, each batch's dataset will have a
"batch" dimension of size 1 prepended to the array. This functionality is
useful for interoperability with Keras / Tensorflow.
Yields
------
Expand Down
11 changes: 5 additions & 6 deletions xbatcher/tests/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,24 +169,23 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, bsize):
)


@pytest.mark.parametrize('bsize', [5, 10])
@pytest.mark.parametrize('bsize', [10, 20])
def test_batch_1d_squeeze_batch_dim(sample_ds_1d, bsize):
xbsize = 20
bg = BatchGenerator(
sample_ds_1d,
input_dims={'x': xbsize},
input_dims={'x': bsize},
squeeze_batch_dim=False,
)
for ds_batch in bg:
assert list(ds_batch['foo'].shape) == [1, xbsize]
assert list(ds_batch['foo'].shape) == [1, bsize]

bg2 = BatchGenerator(
sample_ds_1d,
input_dims={'x': xbsize},
input_dims={'x': bsize},
squeeze_batch_dim=True,
)
for ds_batch in bg2:
assert list(ds_batch['foo'].shape) == [xbsize]
assert list(ds_batch['foo'].shape) == [bsize]


@pytest.mark.parametrize('bsize', [5, 10])
Expand Down

0 comments on commit a54a9b7

Please sign in to comment.