Skip to content

Commit

Permalink
Test stacking via the use of xarray.Dataset.expand_dims (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins authored Mar 19, 2024
1 parent 8f2ff56 commit 234aa0c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
4 changes: 4 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
History
=======

X.Y.Z (YYYY-MM-DD)
------------------
* Test stacking in the globbing case (:pr:`20`)

0.2.1 (2024-03-19)
------------------
* Make distributed an optional package (:pr:`19`)
Expand Down
23 changes: 14 additions & 9 deletions tests/test_xarrayfits.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from astropy.io import fits
from dask.distributed import Client, LocalCluster
import numpy as np
from numpy.testing import assert_array_equal
import pytest
import xarray

Expand Down Expand Up @@ -36,16 +37,20 @@ def test_globbing(multiple_files):
for xds in datasets:
expected = np.arange(np.prod(xds.hdu0.shape), dtype=np.float64)
expected = expected.reshape(xds.hdu0.shape)
np.testing.assert_array_equal(xds.hdu0.data, expected)
assert_array_equal(xds.hdu0.data, expected)

combined = xarray.concat(datasets, dim="hdu0-0")
np.testing.assert_array_equal(
combined.hdu0.data, np.concatenate([expected] * 3, axis=0)
)
assert_array_equal(combined.hdu0.data, np.concatenate([expected] * 3, axis=0))
assert combined.hdu0.dims == ("hdu0-0", "hdu0-1")

combined = xarray.concat(datasets, dim="hdu0-1")
np.testing.assert_array_equal(
combined.hdu0.data, np.concatenate([expected] * 3, axis=1)
)
assert_array_equal(combined.hdu0.data, np.concatenate([expected] * 3, axis=1))
assert combined.hdu0.dims == ("hdu0-0", "hdu0-1")

tds = [ds.expand_dims(dim="time", axis=0) for ds in datasets]
combined = xarray.concat(tds, dim="time")
assert_array_equal(combined.hdu0.data, np.stack([expected] * 3, axis=0))
assert combined.hdu0.dims == ("time", "hdu0-0", "hdu0-1")


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -139,7 +144,7 @@ def test_beam_creation(beam_cube):
(xds,) = xds_from_fits(beam_cube)
cmp_data = np.arange(np.prod(xds.hdu0.shape), dtype=np.float64)
cmp_data = cmp_data.reshape(xds.hdu0.shape)
np.testing.assert_array_equal(xds.hdu0.data, cmp_data)
assert_array_equal(xds.hdu0.data, cmp_data)
assert xds.hdu0.data.shape == (257, 257, 32)
assert xds.hdu0.dims == ("hdu0-0", "hdu0-1", "hdu0-2")
assert xds.hdu0.attrs == {
Expand Down Expand Up @@ -179,5 +184,5 @@ def test_distributed(beam_cube):

(xds,) = xds_from_fits(beam_cube, chunks={0: 100, 1: 100, 2: 15})
expected = np.arange(np.prod(xds.hdu0.shape)).reshape(xds.hdu0.shape)
np.testing.assert_array_equal(expected, xds.hdu0.data)
assert_array_equal(expected, xds.hdu0.data)
assert xds.hdu0.data.chunks == ((100, 100, 57), (100, 100, 57), (15, 15, 2))

0 comments on commit 234aa0c

Please sign in to comment.