Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correct FITS Proxy usage #8

Merged
merged 2 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ History

X.Y.Z (YYYY-MM-DD)
------------------
* Correct FITS Proxy Usage (:pr:`8`)
* Update ruff settings (:pr:`7`)
* Update Github Actions Deployment (:pr:`6`)
* Modernise xarray-fits (:pr:`5`)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ astropy = "^6.0.0"
dask = {extras = ["array"], version = "^2024.3.1"}
xarray = "^2024.2.0"
pytest = {version = "^8.1.1", optional = true, extras = ["testing"]}
distributed = {version = "^2024.3.1", extras = ["testing"]}

[tool.poetry.extras]
testing = ["pytest"]
Expand Down
23 changes: 19 additions & 4 deletions tests/test_xarrayfits.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@

"""Tests for `xarrayfits` package."""

import os
from contextlib import ExitStack

from astropy.io import fits
from dask.distributed import Client, LocalCluster
import numpy as np
import pytest

from xarrayfits import xds_from_fits


@pytest.fixture
def beam_cube(tmp_path):
@pytest.fixture(scope="session")
def beam_cube(tmp_path_factory):
frequency = np.linspace(0.856e9, 0.856e9 * 2, 32, endpoint=True)
bandwidth_delta = (frequency[-1] - frequency[0]) / frequency.size
dtype = np.float64
Expand Down Expand Up @@ -83,7 +84,8 @@ def beam_cube(tmp_path):
]
header.update(ax_info)

filename = os.path.join(str(tmp_path), "beam.fits")
filename = tmp_path_factory.mktemp("beam") / "beam.fits"
filename = str(filename)
# Write some data to it
data = np.arange(np.prod(shape), dtype=dtype)
primary_hdu = fits.PrimaryHDU(data.reshape(shape), header=header)
Expand Down Expand Up @@ -126,3 +128,16 @@ def test_beam_creation(beam_cube):
"NAXIS2": 257,
"NAXIS3": 257,
}


def test_distributed(beam_cube):
    """Smoke test: reading the beam cube through a dask distributed cluster
    must reproduce the same values as the local case."""
    with LocalCluster(n_workers=8, processes=True) as cluster:
        with Client(cluster):
            xds = xds_from_fits(
                beam_cube, chunks={"NAXIS1": 10, "NAXIS2": 10, "NAXIS3": 10}
            )
            # The fixture wrote arange(prod(shape)) into the primary HDU.
            shape = xds.hdu0.shape
            expected = np.arange(np.prod(shape)).reshape(shape)
            np.testing.assert_array_equal(expected, xds.hdu0.data)
56 changes: 21 additions & 35 deletions xarrayfits/fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,21 +83,16 @@ def slices(r):
return (slice(s, e) for s, e in zip(r[:-1], r[1:]))


def _get_fn(fp, h, i):
return fp("__getitem__", h).data.__getitem__(i)
def _get_data_function(fp, h, i):
return fp.hdu[h].data[i]


def generate_slice_gets(fits_filename, fits_key, fits_graph, hdu, shape, dtype, chunks):
def generate_slice_gets(fits_proxy, hdu, shape, dtype, chunks):
"""
Parameters
----------
fits_filename : str
FITS filename
fits_key : tuple
dask key referencing an opened FITS file object
fits_graph : dict
dask graph containing ``fits_key`` referencing an
opened FITS file object
fits_proxy : FitsProxy
FITS Proxy
hdu : integer
FITS HDU for which to generate a dask array
shape : tuple
Expand All @@ -114,8 +109,8 @@ def generate_slice_gets(fits_filename, fits_key, fits_graph, hdu, shape, dtype,
with the ``hdu``.
"""

token = dask.base.tokenize(fits_filename)
name = "-".join((short_fits_file(fits_filename), "slice", token))
token = dask.base.tokenize(fits_proxy)
name = "-".join((short_fits_file(fits_proxy._filename), "slice", token))

dsk_chunks = da.core.normalize_chunks(chunks, shape)

Expand All @@ -124,15 +119,16 @@ def generate_slice_gets(fits_filename, fits_key, fits_graph, hdu, shape, dtype,
slices_ = product(*[slices(tuple(ranges(c))) for c in dsk_chunks])

# Create dask graph
dsk = {key: (_get_fn, fits_key, hdu, slice_) for key, slice_ in zip(keys, slices_)}
dsk = {
key: (_get_data_function, fits_proxy, hdu, slice_)
for key, slice_ in zip(keys, slices_)
}

return da.Array({**dsk, **fits_graph}, name, dsk_chunks, dtype)
return da.Array(dsk, name, dsk_chunks, dtype)


def _xarray_from_fits_hdu(
fits_filename,
fits_key,
fits_graph,
def array_from_fits_hdu(
fits_proxy,
name_prefix,
hdu_list,
hdu_index,
Expand All @@ -141,13 +137,8 @@ def _xarray_from_fits_hdu(
"""
Parameters
----------
fits_filename : str
FITS filename
fits_key : tuple
dask key referencing an opened FITS file object
fits_graph : dict
dask graph containing ``fits_key`` referencing an
opened FITS file object
fits_proxy : FitsProxy
The FITS proxy
hdu_list : :class:`astropy.io.fits.hdu.hdulist.HDUList`
FITS HDU list
hdu_index : integer
Expand Down Expand Up @@ -201,9 +192,7 @@ def _xarray_from_fits_hdu(
flat_chunks = tuple(reversed(flat_chunks))

array = generate_slice_gets(
fits_filename,
fits_key,
fits_graph,
fits_proxy,
hdu_index,
shape,
dtype,
Expand Down Expand Up @@ -260,15 +249,12 @@ def xds_from_fits(fits_filename, hdus=None, name_prefix="hdu", chunks=None):
f"chunks ({len(chunks)})"
)

fits_key, fits_graph = fits_open_graph(fits_filename)

fn = _xarray_from_fits_hdu
fits_proxy = FitsProxy(fits_filename)

# Generate xarray datavars for each hdu
xarrays = {
f"{name_prefix}{hdu_index}": fn(
fits_filename,
fits_key,
fits_graph,
f"{name_prefix}{hdu_index}": array_from_fits_hdu(
fits_proxy,
name_prefix,
hdu_list,
hdu_index,
Expand Down
47 changes: 39 additions & 8 deletions xarrayfits/fits_proxy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,28 @@
from threading import Lock
import weakref

from astropy.io import fits

# Process-wide cache of live proxy instances; weak values mean a proxy is
# dropped automatically once no user code holds a reference to it.
TABLE_CACHE_LOCK = Lock()
TABLE_CACHE = weakref.WeakValueDictionary()


class FitsProxyMetaClass(type):
    """Multiton metaclass: at most one live instance per unique set of
    constructor arguments.

    https://en.wikipedia.org/wiki/Multiton_pattern
    """

    def __call__(cls, *args, **kwargs):
        # Use sorted() rather than set(): set iteration order is arbitrary,
        # so two identical calls could produce different cache keys and
        # silently defeat the multiton. Sorting on the (unique) keyword
        # names yields a deterministic, kwarg-order-independent key.
        key = (cls,) + args + tuple(sorted(kwargs.items()))

        with TABLE_CACHE_LOCK:
            try:
                # Fast path: an equivalent instance is still alive.
                return TABLE_CACHE[key]
            except KeyError:
                instance = type.__call__(cls, *args, **kwargs)
                TABLE_CACHE[key] = instance
                return instance


class FitsProxy(object):
class FitsProxy(metaclass=FitsProxyMetaClass):
"""
Picklable object proxying a :class:`astropy.io.fits` class
"""
Expand All @@ -18,13 +39,23 @@ def __init__(self, filename, **kwargs):
"""
self._filename = filename
self._kwargs = kwargs
self._fits_file = fits.open(filename, **kwargs)
self._lock = Lock()

def __setstate__(self, state):
self.__init__(*state)
@staticmethod
def from_reduce_args(filename, kw):
    """Unpickling helper: rebuild a :class:`FitsProxy` from a filename and
    the keyword arguments it was created with (see :meth:`__reduce__`)."""
    proxy = FitsProxy(filename, **kw)
    return proxy

def __getstate__(self):
return (self._filename, self._kwargs)
@property
def hdu(self):
    """The proxied HDU list, opened lazily on first access.

    Uses double-checked locking: the first, unlocked ``try`` is the
    fast path once ``_hdul`` exists; ``_lock`` serialises the one-time
    ``fits.open`` call between threads.
    """
    try:
        # Fast path: file already opened, no locking required.
        return self._hdul
    except AttributeError:
        with self._lock:
            try:
                # Re-check: another thread may have opened the file
                # while we were waiting for the lock.
                return self._hdul
            except AttributeError:
                self._hdul = fits.open(self._filename, **self._kwargs)
                return self._hdul

def __call__(self, fn, *args, **kwargs):
return getattr(self._fits_file, fn)(*args, **kwargs)
def __reduce__(self):
    """Pickle only the filename and open() kwargs; the receiving process
    rebuilds the proxy via :meth:`from_reduce_args`, reopening the FITS
    file lazily on first use."""
    state = (self._filename, self._kwargs)
    return (FitsProxy.from_reduce_args, state)