From aa56230221763bb03873c80cb94828d1cb91d1bc Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Fri, 5 Apr 2024 11:44:51 +0200 Subject: [PATCH] Add basic Affine Transformed Coordinates to Xarray datasets --- tests/{test_axes.py => test_grid.py} | 37 ++++++++++++++------- xarrayfits/fits.py | 20 +++++------- xarrayfits/{axes.py => grid.py} | 49 ++++++++++++++-------------- 3 files changed, 60 insertions(+), 46 deletions(-) rename tests/{test_axes.py => test_grid.py} (50%) rename xarrayfits/{axes.py => grid.py} (63%) diff --git a/tests/test_axes.py b/tests/test_grid.py similarity index 50% rename from tests/test_axes.py rename to tests/test_grid.py index 44ed19a..99f52f3 100644 --- a/tests/test_axes.py +++ b/tests/test_grid.py @@ -1,7 +1,9 @@ from astropy.io import fits +import numpy as np +from numpy.testing import assert_array_equal import pytest -from xarrayfits.axes import Axes +from xarrayfits.grid import AffineGrid @pytest.fixture(params=[(10, 20, 30)]) @@ -30,14 +32,27 @@ def header(request): ) -def test_axes(header): - axes = Axes(header) - ndims = axes.ndims +def test_affine_grid(header): + grid = AffineGrid(header) + ndims = grid.ndims assert ndims == header["NAXIS"] - assert axes.naxis == [10, 20, 30] - assert axes.crpix == [7, 6, 5] - assert axes.crval == [3.0, 2.0, 1.0] - assert axes.cdelt == [4.0, 3.0, 2.0] - assert axes.cname == [header[f"CNAME{ndims - i}"] for i in range(ndims)] - assert axes.cunit == [header[f"CUNIT{ndims - i}"] for i in range(ndims)] - assert axes.ctype == [header[f"CTYPE{ndims - i}"] for i in range(ndims)] + assert grid.naxis == [10, 20, 30] + assert grid.crpix == [7, 6, 5] + assert grid.crval == [3.0, 2.0, 1.0] + assert grid.cdelt == [4.0, 3.0, 2.0] + assert grid.cname == [header[f"CNAME{ndims - i}"] for i in range(ndims)] + assert grid.cunit == [header[f"CUNIT{ndims - i}"] for i in range(ndims)] + assert grid.ctype == [header[f"CTYPE{ndims - i}"] for i in range(ndims)] + + # Worked coordinate example + assert_array_equal(grid.coords(0), 3.0 + (np.arange(1, 10 + 1) - 7) * 4.0) + assert_array_equal(grid.coords(1), 2.0 + (np.arange(1, 20 + 1) - 6) * 3.0) + assert_array_equal(grid.coords(2), 1.0 + (np.arange(1, 30 + 1) - 5) * 2.0) + + # More automatic version + for d in range(ndims): + assert_array_equal( + grid.coords(d), + grid.crval[d] + + (np.arange(1, grid.naxis[d] + 1) - grid.crpix[d]) * grid.cdelt[d], + ) diff --git a/xarrayfits/fits.py b/xarrayfits/fits.py index 38da14c..1715927 100644 --- a/xarrayfits/fits.py +++ b/xarrayfits/fits.py @@ -17,7 +17,7 @@ import xarray as xr -from xarrayfits.axes import Axes, UndefinedGridError +from xarrayfits.grid import AffineGrid from xarrayfits.fits_proxy import FitsProxy log = logging.getLogger("xarray-fits") @@ -161,17 +161,17 @@ def array_from_fits_hdu( shape = [] flat_chunks = [] - axes = Axes(hdu.header) + grid = AffineGrid(hdu.header) # Determine shapes and apply chunking - for i in range(axes.ndims): - shape.append(axes.naxis[i]) + for i in range(grid.ndims): + shape.append(grid.naxis[i]) try: # Try add existing chunking strategies to the list flat_chunks.append(chunks[i]) except KeyError: - flat_chunks.append(axes.naxis[i]) + flat_chunks.append(grid.naxis[i]) array = generate_slice_gets( fits_proxy, @@ -182,14 +182,12 @@ def array_from_fits_hdu( ) dims = tuple( - f"{name}{hdu_index}" if (name := axes.name(i)) else f"{prefix}{hdu_index}-{i}" - for i in range(axes.ndims) + f"{name}{hdu_index}" if (name := grid.name(i)) else f"{prefix}{hdu_index}-{i}" + for i in range(grid.ndims) ) - - coords = {d: (d, axes.grid(i)) for i, d in enumerate(dims)} - + coords = {d: (d, grid.coords(i)) for i, d in enumerate(dims)} attrs = {"header": {k: v for k, v in sorted(hdu.header.items())}} - return xr.DataArray(array, dims=tuple(dims), coords=coords, attrs=attrs) + return xr.DataArray(array, dims=dims, coords=coords, attrs=attrs) def xds_from_fits(fits_filename, hdus=None, prefix="hdu", chunks=None): diff --git a/xarrayfits/axes.py b/xarrayfits/grid.py similarity index 63% rename from xarrayfits/axes.py rename to xarrayfits/grid.py index 322222a..1cf0b03 100644 --- a/xarrayfits/axes.py +++ b/xarrayfits/grid.py @@ -1,32 +1,34 @@ +from collections.abc import Mapping import numpy as np HEADER_PREFIXES = ["NAXIS", "CTYPE", "CRPIX", "CRVAL", "CDELT", "CUNIT", "CNAME"] -def property_factory(prefix): +class UndefinedGridError(ValueError): + pass + + +def property_factory(prefix: str): def impl(self): return getattr(self, f"_{prefix}") return property(impl) -class UndefinedGridError(ValueError): - pass - - -class AxesMetaClass(type): +class AffineGridMetaclass(type): def __new__(cls, name, bases, dct): for prefix in (p.lower() for p in HEADER_PREFIXES): dct[prefix] = property_factory(prefix) return type.__new__(cls, name, bases, dct) -class Axes(metaclass=AxesMetaClass): +class AffineGrid(metaclass=AffineGridMetaclass): """Presents a C-ordered view over FITS Header grid attributes""" - def __init__(self, header): + def __init__(self, header: Mapping): self._ndims = ndims = header["NAXIS"] axr = tuple(range(1, ndims + 1)) + h = header # Read headers into C-order for prefix in HEADER_PREFIXES: @@ -39,20 +41,26 @@ def __init__(self, header): if a is None: raise UndefinedGridError(f"NAXIS{ndims - i} undefined") - # Fill in any None CRVAL - self._crval = [0 if v is None else v for v in self._crval] - # Fill in any None CRPIX + # Fill in any missing CRVAL + self._crval = [0.0 if v is None else v for v in self._crval] + # Fill in any missing CRPIX self._crpix = [1 if p is None else p for p in self._crpix] - # Fill in any None CDELT - self._cdelt = [1 if d is None else d for d in self._cdelt] + # Fill in any missing CDELT + self._cdelt = [1.0 if d is None else d for d in self._cdelt] - self._grid = [None] * ndims + self._grid = [] + + for d in range(ndims): + pixels = np.arange(1, self._naxis[d] + 1, dtype=np.float64) + self._grid.append( + (pixels - self._crpix[d]) * self._cdelt[d] + self._crval[d] + ) @property def ndims(self): return self._ndims - def name(self, dim): + def name(self, dim: int): """Return a name for dimension :code:`dim`""" if result := self.cname[dim]: return result @@ -61,13 +69,6 @@ def name(self, dim): else: return None - def grid(self, dim): - """Return the axis grid for dimension :code:`dim`""" - if self._grid[dim] is None: - # Create the grid - pixels = np.arange(1, self.naxis[dim] + 1) - self._grid[dim] = (pixels - self.crpix[dim]) * self.cdelt[dim] + self.crval[ - dim - ] - + def coords(self, dim: int): + """Return the affine coordinates for dimension :code:`dim`""" return self._grid[dim]