Skip to content

Commit

Permalink
Create xrray coordinates based on the pixel grid
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins committed Apr 4, 2024
1 parent 5cc48da commit 4026dbd
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 10 deletions.
43 changes: 43 additions & 0 deletions tests/test_axes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from astropy.io import fits
import pytest

from xarrayfits.axes import Axes


@pytest.fixture(params=[(10, 20, 30)])
def header(request):
# Reverse into FORTRAN order
rev_dims = list(reversed(request.param))
naxes = {f"NAXIS{d + 1}": s for d, s in enumerate(rev_dims)}
crpix = {f"CRPIX{d + 1}": 5 + d for d, _ in enumerate(rev_dims)}
crval = {f"CRVAL{d + 1}": 1.0 + d for d, _ in enumerate(rev_dims)}
cdelt = {f"CDELT{d + 1}": 2.0 + d for d, _ in enumerate(rev_dims)}
cunit = {f"CUNIT{d + 1}": f"UNIT-{len(rev_dims) - d}" for d in range(len(rev_dims))}
ctype = {f"CTYPE{d + 1}": f"TYPE-{len(rev_dims) - d}" for d in range(len(rev_dims))}
cname = {f"CNAME{d + 1}": f"NAME-{len(rev_dims) - d}" for d in range(len(rev_dims))}

return fits.Header(
{
"NAXIS": len(request.param),
**naxes,
**crpix,
**crval,
**cdelt,
**cname,
**ctype,
**cunit,
}
)


def test_axes(header):
axes = Axes(header)
ndims = axes.ndims
assert ndims == header["NAXIS"]
assert axes.naxis == [10, 20, 30]
assert axes.crpix == [7, 6, 5]
assert axes.crval == [3.0, 2.0, 1.0]
assert axes.cdelt == [4.0, 3.0, 2.0]
assert axes.cname == [header[f"CNAME{ndims - i}"] for i in range(ndims)]
assert axes.cunit == [header[f"CUNIT{ndims - i}"] for i in range(ndims)]
assert axes.ctype == [header[f"CTYPE{ndims - i}"] for i in range(ndims)]
4 changes: 2 additions & 2 deletions tests/test_xarrayfits.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def beam_cube(tmp_path_factory):
def test_name_prefix(beam_cube):
"""Test specification of a name prefix"""
(xds,) = xds_from_fits(beam_cube, prefix="beam")
assert xds.beam0.dims == ("beam0-0", "beam0-1", "beam0-2")
assert xds.beam0.dims == ("X0", "Y0", "FREQ0")


def test_beam_creation(beam_cube):
Expand All @@ -162,7 +162,7 @@ def test_beam_creation(beam_cube):
cmp_data = cmp_data.reshape(xds.hdu0.shape)
assert_array_equal(xds.hdu0.data, cmp_data)
assert xds.hdu0.data.shape == (257, 257, 32)
assert xds.hdu0.dims == ("hdu0-0", "hdu0-1", "hdu0-2")
assert xds.hdu0.dims == ("X0", "Y0", "FREQ0")
assert xds.hdu0.attrs == {
"header": {
"BITPIX": -64,
Expand Down
73 changes: 73 additions & 0 deletions xarrayfits/axes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import numpy as np

HEADER_PREFIXES = ["NAXIS", "CTYPE", "CRPIX", "CRVAL", "CDELT", "CUNIT", "CNAME"]


def property_factory(prefix):
def impl(self):
return getattr(self, f"_{prefix}")

return property(impl)


class UndefinedGridError(ValueError):
pass


class AxesMetaClass(type):
def __new__(cls, name, bases, dct):
for prefix in (p.lower() for p in HEADER_PREFIXES):
dct[prefix] = property_factory(prefix)
return type.__new__(cls, name, bases, dct)


class Axes(metaclass=AxesMetaClass):
"""Presents a C-ordered view over FITS Header grid attributes"""

def __init__(self, header):
self._ndims = ndims = header["NAXIS"]
axr = tuple(range(1, ndims + 1))

# Read headers into C-order
for prefix in HEADER_PREFIXES:
values = reversed([header.get(f"{prefix}{n}") for n in axr])
values = [s.strip() if isinstance(s, str) else s for s in values]
setattr(self, f"_{prefix.lower()}", values)

# We must have all NAXIS
for i, a in enumerate(self.naxis):
if a is None:
raise UndefinedGridError(f"NAXIS{ndims - i} undefined")

# Fill in any None CRVAL
self._crval = [0 if v is None else v for v in self._crval]
# Fill in any None CRPIX
self._crpix = [1 if p is None else p for p in self._crpix]
# Fill in any None CDELT
self._cdelt = [1 if d is None else d for d in self._cdelt]

self._grid = [None] * ndims

@property
def ndims(self):
return self._ndims

def name(self, dim):
"""Return a name for dimension :code:`dim`"""
if result := self.cname[dim]:
return result
elif result := self.ctype[dim]:
return result
else:
return None

def grid(self, dim):
"""Return the axis grid for dimension :code:`dim`"""
if self._grid[dim] is None:
# Create the grid
pixels = np.arange(1, self.naxis[dim] + 1)
self._grid[dim] = (pixels - self.crpix[dim]) * self.cdelt[dim] + self.crval[
dim
]

return self._grid[dim]
22 changes: 14 additions & 8 deletions xarrayfits/fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import xarray as xr

from xarrayfits.axes import Axes, UndefinedGridError
from xarrayfits.fits_proxy import FitsProxy

log = logging.getLogger("xarray-fits")
Expand Down Expand Up @@ -160,18 +161,17 @@ def array_from_fits_hdu(

shape = []
flat_chunks = []
axes = Axes(hdu.header)

# At this point we are dealing with FORTRAN ordered axes
for i in range(naxis):
ax_key = f"NAXIS{naxis - i}"
ax_shape = hdu.header[ax_key]
shape.append(ax_shape)
# Determine shapes and apply chunking
for i in range(axes.ndims):
shape.append(axes.naxis[i])

try:
# Try add existing chunking strategies to the list
flat_chunks.append(chunks[i])
except KeyError:
flat_chunks.append(ax_shape)
flat_chunks.append(axes.naxis[i])

array = generate_slice_gets(
fits_proxy,
Expand All @@ -181,9 +181,15 @@ def array_from_fits_hdu(
tuple(flat_chunks),
)

dims = tuple(f"{prefix}{hdu_index}-{i}" for i in range(0, naxis))
dims = tuple(
f"{name}{hdu_index}" if (name := axes.name(i)) else f"{prefix}{hdu_index}-{i}"
for i in range(axes.ndims)
)

coords = {d: (d, axes.grid(i)) for i, d in enumerate(dims)}

attrs = {"header": {k: v for k, v in sorted(hdu.header.items())}}
return xr.DataArray(array, dims=dims, attrs=attrs)
return xr.DataArray(array, dims=tuple(dims), coords=coords, attrs=attrs)


def xds_from_fits(fits_filename, hdus=None, prefix="hdu", chunks=None):
Expand Down

0 comments on commit 4026dbd

Please sign in to comment.