Skip to content

Commit

Permalink
Add basic Affine Transformed Coordinates to Xarray datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins committed Apr 5, 2024
1 parent 4026dbd commit aa56230
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 46 deletions.
37 changes: 26 additions & 11 deletions tests/test_axes.py → tests/test_grid.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from astropy.io import fits
import numpy as np
from numpy.testing import assert_array_equal
import pytest

from xarrayfits.axes import Axes
from xarrayfits.grid import AffineGrid


@pytest.fixture(params=[(10, 20, 30)])
Expand Down Expand Up @@ -30,14 +32,27 @@ def header(request):
)


def test_axes(header):
axes = Axes(header)
ndims = axes.ndims
def test_affine_grid(header):
grid = AffineGrid(header)
ndims = grid.ndims
assert ndims == header["NAXIS"]
assert axes.naxis == [10, 20, 30]
assert axes.crpix == [7, 6, 5]
assert axes.crval == [3.0, 2.0, 1.0]
assert axes.cdelt == [4.0, 3.0, 2.0]
assert axes.cname == [header[f"CNAME{ndims - i}"] for i in range(ndims)]
assert axes.cunit == [header[f"CUNIT{ndims - i}"] for i in range(ndims)]
assert axes.ctype == [header[f"CTYPE{ndims - i}"] for i in range(ndims)]
assert grid.naxis == [10, 20, 30]
assert grid.crpix == [7, 6, 5]
assert grid.crval == [3.0, 2.0, 1.0]
assert grid.cdelt == [4.0, 3.0, 2.0]
assert grid.cname == [header[f"CNAME{ndims - i}"] for i in range(ndims)]
assert grid.cunit == [header[f"CUNIT{ndims - i}"] for i in range(ndims)]
assert grid.ctype == [header[f"CTYPE{ndims - i}"] for i in range(ndims)]

# Worked coordinate example
assert_array_equal(grid.coords(0), 3.0 + (np.arange(1, 10 + 1) - 7) * 4.0)
assert_array_equal(grid.coords(1), 2.0 + (np.arange(1, 20 + 1) - 6) * 3.0)
assert_array_equal(grid.coords(2), 1.0 + (np.arange(1, 30 + 1) - 5) * 2.0)

# More automatic version
for d in range(ndims):
assert_array_equal(
grid.coords(d),
grid.crval[d]
+ (np.arange(1, grid.naxis[d] + 1) - grid.crpix[d]) * grid.cdelt[d],
)
20 changes: 9 additions & 11 deletions xarrayfits/fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import xarray as xr

from xarrayfits.axes import Axes, UndefinedGridError
from xarrayfits.grid import AffineGrid
from xarrayfits.fits_proxy import FitsProxy

log = logging.getLogger("xarray-fits")
Expand Down Expand Up @@ -161,17 +161,17 @@ def array_from_fits_hdu(

shape = []
flat_chunks = []
axes = Axes(hdu.header)
grid = AffineGrid(hdu.header)

# Determine shapes and apply chunking
for i in range(axes.ndims):
shape.append(axes.naxis[i])
for i in range(grid.ndims):
shape.append(grid.naxis[i])

try:
# Try add existing chunking strategies to the list
flat_chunks.append(chunks[i])
except KeyError:
flat_chunks.append(axes.naxis[i])
flat_chunks.append(grid.naxis[i])

array = generate_slice_gets(
fits_proxy,
Expand All @@ -182,14 +182,12 @@ def array_from_fits_hdu(
)

dims = tuple(
f"{name}{hdu_index}" if (name := axes.name(i)) else f"{prefix}{hdu_index}-{i}"
for i in range(axes.ndims)
f"{name}{hdu_index}" if (name := grid.name(i)) else f"{prefix}{hdu_index}-{i}"
for i in range(grid.ndims)
)

coords = {d: (d, axes.grid(i)) for i, d in enumerate(dims)}

coords = {d: (d, grid.coords(i)) for i, d in enumerate(dims)}
attrs = {"header": {k: v for k, v in sorted(hdu.header.items())}}
return xr.DataArray(array, dims=tuple(dims), coords=coords, attrs=attrs)
return xr.DataArray(array, dims=dims, coords=coords, attrs=attrs)


def xds_from_fits(fits_filename, hdus=None, prefix="hdu", chunks=None):
Expand Down
49 changes: 25 additions & 24 deletions xarrayfits/axes.py → xarrayfits/grid.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,34 @@
from collections.abc import Mapping
import numpy as np

HEADER_PREFIXES = ["NAXIS", "CTYPE", "CRPIX", "CRVAL", "CDELT", "CUNIT", "CNAME"]


def property_factory(prefix):
class UndefinedGridError(ValueError):
pass


def property_factory(prefix: str):
def impl(self):
return getattr(self, f"_{prefix}")

return property(impl)


class UndefinedGridError(ValueError):
pass


class AxesMetaClass(type):
class AffineGridMetaclass(type):
def __new__(cls, name, bases, dct):
for prefix in (p.lower() for p in HEADER_PREFIXES):
dct[prefix] = property_factory(prefix)
return type.__new__(cls, name, bases, dct)


class Axes(metaclass=AxesMetaClass):
class AffineGrid(metaclass=AffineGridMetaclass):
"""Presents a C-ordered view over FITS Header grid attributes"""

def __init__(self, header):
def __init__(self, header: Mapping):
self._ndims = ndims = header["NAXIS"]
axr = tuple(range(1, ndims + 1))
h = header

# Read headers into C-order
for prefix in HEADER_PREFIXES:
Expand All @@ -39,20 +41,26 @@ def __init__(self, header):
if a is None:
raise UndefinedGridError(f"NAXIS{ndims - i} undefined")

# Fill in any None CRVAL
self._crval = [0 if v is None else v for v in self._crval]
# Fill in any None CRPIX
# Fill in any missing CRVAL
self._crval = [0.0 if v is None else v for v in self._crval]
# Fill in any missing CRPIX
self._crpix = [1 if p is None else p for p in self._crpix]
# Fill in any None CDELT
self._cdelt = [1 if d is None else d for d in self._cdelt]
# Fill in any missing CDELT
self._cdelt = [1.0 if d is None else d for d in self._cdelt]

self._grid = [None] * ndims
self._grid = []

for d in range(ndims):
pixels = np.arange(1, self._naxis[d] + 1, dtype=np.float64)
self._grid.append(
(pixels - self._crpix[d]) * self._cdelt[d] + self._crval[d]
)

@property
def ndims(self):
return self._ndims

def name(self, dim):
def name(self, dim: int):
"""Return a name for dimension :code:`dim`"""
if result := self.cname[dim]:
return result
Expand All @@ -61,13 +69,6 @@ def name(self, dim):
else:
return None

def grid(self, dim):
"""Return the axis grid for dimension :code:`dim`"""
if self._grid[dim] is None:
# Create the grid
pixels = np.arange(1, self.naxis[dim] + 1)
self._grid[dim] = (pixels - self.crpix[dim]) * self.cdelt[dim] + self.crval[
dim
]

def coords(self, dim: int):
"""Return the affine coordinates for dimension :code:`dim`"""
return self._grid[dim]

0 comments on commit aa56230

Please sign in to comment.