Skip to content

Commit

Permalink
Support list of fits files
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins committed Mar 20, 2024
1 parent 234aa0c commit d653702
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 6 deletions.
20 changes: 17 additions & 3 deletions tests/test_xarrayfits.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""Tests for `xarrayfits` package."""

from contextlib import ExitStack
import os.path

from astropy.io import fits
from dask.distributed import Client, LocalCluster
Expand All @@ -22,16 +23,18 @@ def multiple_files(tmp_path_factory):
data = np.arange(np.prod(shape), dtype=np.float64)
data = data.reshape(shape)

filenames = []

for i in range(3):
filename = str(path / f"data-{i}.fits")
filenames.append(filename)
primary_hdu = fits.PrimaryHDU(data)
primary_hdu.writeto(filename, overwrite=True)

return str(path / f"data*.fits")
return filenames


def test_globbing(multiple_files):
datasets = xds_from_fits(multiple_files)
def multiple_dataset_tester(datasets):
assert len(datasets) == 3

for xds in datasets:
Expand All @@ -53,6 +56,17 @@ def test_globbing(multiple_files):
assert combined.hdu0.dims == ("time", "hdu0-0", "hdu0-1")


def test_list_files(multiple_files):
datasets = xds_from_fits(multiple_files)
return multiple_dataset_tester(datasets)


def test_globbing(multiple_files):
path, _ = os.path.split(multiple_files[0])
datasets = xds_from_fits(f"{path}{os.sep}data*.fits")
return multiple_dataset_tester(datasets)


@pytest.fixture(scope="session")
def beam_cube(tmp_path_factory):
frequency = np.linspace(0.856e9, 0.856e9 * 2, 32, endpoint=True)
Expand Down
14 changes: 11 additions & 3 deletions xarrayfits/fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import os
import os.path
from collections.abc import Sequence

import dask
import dask.array as da
Expand Down Expand Up @@ -180,8 +181,9 @@ def xds_from_fits(fits_filename, hdus=None, prefix="hdu", chunks=None):
"""
Parameters
----------
fits_filename : str
FITS filename. This can contain a globbed pattern.
fits_filename : str or list of str
FITS filename or a list of FITS filenames.
The first case supports a globbed pattern.
hdus : integer or list of integers, optional
hdus to represent on the returned Dataset.
If ``None``, all HDUs are selected
Expand All @@ -202,7 +204,13 @@ def xds_from_fits(fits_filename, hdus=None, prefix="hdu", chunks=None):
to each HDU on the FITS file.
"""

openfiles = fsspec.open_files(fits_filename)
if isinstance(fits_filename, str):
openfiles = fsspec.open_files(fits_filename)
elif isinstance(fits_filename, Sequence):
openfiles = fsspec.open_files(fits_filename)
else:
raise TypeError(f"{type(fits_filename)} is not a " f"string or Sequence")

datasets = []

for filename in (f.path for f in openfiles):
Expand Down

0 comments on commit d653702

Please sign in to comment.