diff --git a/corrct/alignment/fitting.py b/corrct/alignment/fitting.py
index 96c79e6..e26bd2d 100644
--- a/corrct/alignment/fitting.py
+++ b/corrct/alignment/fitting.py
@@ -8,11 +8,10 @@
 and ESRF - The European Synchrotron, Grenoble, France
 """
 
-from collections.abc import Sequence
-from typing import Optional, Union
+from typing import Literal, Optional, Union, Sequence
 
 import numpy as np
-import numpy.polynomial
+from numpy.polynomial import Polynomial
 import scipy.ndimage as spimg
 import scipy.optimize as spopt
 from numpy.typing import ArrayLike, NDArray
@@ -69,6 +68,7 @@ def fit_shifts_vu_xc(
     pad_u: bool = False,
     normalize_fourier: bool = True,
     use_rfft: bool = True,
+    stack_axis: int = -2,
     decimals: int = 2,
 ) -> NDArrayFloat:
     """
@@ -86,6 +86,8 @@
         Whether to normalize the Fourier representation of the cross-correlation. The default is True.
     use_rfft : bool, optional
        Whether to use the `rfft` transform in place of the complex `fft` transform. The default is True.
+    stack_axis : int, optional
+        The axis along which the VU images are stacked. The default is -2.
     decimals : int, optional
         Decimals for the truncation of the sub-pixel The default is 2.
 
@@ -94,7 +96,7 @@
     NDArrayFloat
         The VU shifts.
     """
-    num_angles = data_vwu.shape[-2]
+    num_angles = data_vwu.shape[stack_axis]
 
     if use_rfft:
         local_fftn = np.fft.rfftn
@@ -103,27 +105,36 @@
         local_fftn = np.fft.fftn
         local_ifftn = np.fft.ifftn
 
-    fft_dims = np.delete(np.arange(-len(data_vwu.shape), 0), -2)
+    fft_dims = np.delete(np.arange(-len(data_vwu.shape), 0), stack_axis)
+    u_axis = fft_dims[-1]
+
     old_fft_shapes = np.array(np.array(data_vwu.shape)[fft_dims], ndmin=1, dtype=int)
     new_fft_shapes = old_fft_shapes.copy()
     if pad_u:
-        new_fft_shapes[-1] *= 2
+        new_fft_shapes[u_axis] *= 2
     cc_coords = [np.fft.fftfreq(s, 1 / s) for s in new_fft_shapes]
 
     if len(fft_dims) == 2:
         shifts_vu = np.empty((len(data_vwu.shape) - 1, num_angles))
-        for ii in range(num_angles):
+        slices = [slice(None)] * len(data_vwu.shape)
+        for ii_a in range(num_angles):
             # For performance reasons, it is better to do the fft on each image
-            data_vwu_f = local_fftn(data_vwu[..., ii, :], s=list(new_fft_shapes))
-            proj_vwu_f = local_fftn(proj_vwu[..., ii, :], s=list(new_fft_shapes))
+            slices[stack_axis] = slice(ii_a, ii_a + 1)
+            data_vu = data_vwu[tuple(slices)].squeeze(axis=stack_axis)
+            if proj_vwu.shape[stack_axis] == 1:
+                proj_vu = proj_vwu.squeeze(axis=stack_axis)
+            else:
+                proj_vu = proj_vwu[tuple(slices)].squeeze(axis=stack_axis)
+            data_vwu_f = local_fftn(data_vu, s=list(new_fft_shapes))
+            proj_vwu_f = local_fftn(proj_vu, s=list(new_fft_shapes))
 
             cc_f = data_vwu_f * proj_vwu_f.conj()
             if normalize_fourier:
                 cc_f /= np.fmax(np.abs(cc_f), eps)
-            cc: NDArrayFloat = local_ifftn(cc_f).real
+            cc_r: NDArrayFloat = local_ifftn(cc_f).real
 
-            f_vals, f_coords = extract_peak_region_nd(cc, cc_coords=cc_coords)
-            shifts_vu[..., ii] = np.array([f_coords[0][1], f_coords[1][1]])
+            f_vals, f_coords = extract_peak_region_nd(cc_r, cc_coords=cc_coords)
+            shifts_vu[..., ii_a] = np.array([f_coords[0][1], f_coords[1][1]])
 
             if decimals > 0:
                 f_vals_v = f_vals[:, 1]
@@ -132,18 +143,18 @@
                 sub_pixel_v = refine_max_position_1d(f_vals_v, decimals=decimals)
                 sub_pixel_u = refine_max_position_1d(f_vals_u, decimals=decimals)
 
-                shifts_vu[..., ii] += [sub_pixel_v, sub_pixel_u]
+                shifts_vu[..., ii_a] += [sub_pixel_v, sub_pixel_u]
     else:
         data_vwu_f = local_fftn(data_vwu, s=list(new_fft_shapes), axes=list(fft_dims))
         proj_vwu_f = local_fftn(proj_vwu, s=list(new_fft_shapes), axes=list(fft_dims))
 
         ccs_f = data_vwu_f * proj_vwu_f.conj()
         if normalize_fourier:
-            ccs_f /= np.fmax(np.abs(ccs_f).max(axis=-1, keepdims=True), eps)
+            ccs_f /= np.fmax(np.abs(ccs_f).max(axis=u_axis, keepdims=True), eps)
         ccs = local_ifftn(ccs_f, axes=fft_dims).real
 
-        f_vals, fh = extract_peak_regions_1d(ccs, axis=-1, cc_coords=cc_coords[-1])
-        shifts_vu = fh[1, :]
+        f_vals, f_h = extract_peak_regions_1d(ccs, axis=u_axis, cc_coords=cc_coords[u_axis])
+        shifts_vu = f_h[1, :]
 
         if decimals > 0:
             shifts_vu += refine_max_position_1d(f_vals, decimals=decimals)
@@ -347,7 +358,7 @@ def extract_peak_regions_1d(
 
 
 def refine_max_position_1d(
-    f_vals: NDArrayFloat, fx: Union[ArrayLike, NDArray, None] = None, return_vertex_val: bool = False, decimals: int = 2
+    f_vals: NDArrayFloat, f_x: Union[ArrayLike, NDArray, None] = None, return_vertex_val: bool = False, decimals: int = 2
 ) -> Union[NDArrayFloat, tuple[NDArrayFloat, NDArrayFloat]]:
     """Compute the sub-pixel max position of the given function sampling.
 
@@ -378,24 +389,24 @@
         )
 
     num_vals = f_vals.shape[0]
-    if fx is None:
-        fx_half_size = (num_vals - 1) / 2
-        fx = np.linspace(-fx_half_size, fx_half_size, num_vals)
+    if f_x is None:
+        f_x_half_size = (num_vals - 1) / 2
+        f_x = np.linspace(-f_x_half_size, f_x_half_size, num_vals)
     else:
-        fx = np.squeeze(fx)
-        if not (len(fx.shape) == 1 and np.all(fx.size == num_vals)):
+        f_x = np.squeeze(f_x)
+        if not (len(f_x.shape) == 1 and np.all(f_x.size == num_vals)):
             raise ValueError(
                 "Base coordinates should have the same length as values array. Sizes of fx: %d, f_vals: %d"
-                % (fx.size, num_vals)
+                % (f_x.size, num_vals)
             )
 
     if len(f_vals.shape) == 1:
         # using Polynomial.fit, because supposed to be more numerically
         # stable than previous solutions (according to numpy).
-        poly = np.polynomial.Polynomial.fit(fx, f_vals, deg=2)
+        poly = Polynomial.fit(f_x, f_vals, deg=2)
         coeffs = poly.convert().coef
     else:
-        coords = np.array([np.ones(num_vals), fx, fx**2])
+        coords = np.array([np.ones(num_vals), f_x, f_x**2])
         coeffs = np.linalg.lstsq(coords.T, f_vals, rcond=None)[0]
 
     # For a 1D parabola `f(x) = c + bx + ax^2`, the vertex position is:
@@ -403,22 +414,20 @@
     vertex_x = -coeffs[1, ...] / (2 * coeffs[2, ...])
     vertex_x = np.around(vertex_x, decimals=decimals)
 
-    vertex_min_x = np.min(fx)
-    vertex_max_x = np.max(fx)
+    vertex_min_x = np.min(f_x)
+    vertex_max_x = np.max(f_x)
     lower_bound_ok = vertex_min_x < vertex_x
     upper_bound_ok = vertex_x < vertex_max_x
     if not np.all(lower_bound_ok * upper_bound_ok):
         if len(f_vals.shape) == 1:
             message = (
                 f"Fitted position {vertex_x} is outide the input margins [{vertex_min_x}, {vertex_max_x}]."
-                + f" Input values: {f_vals}"
+                f" Input values: {f_vals}"
            )
         else:
-            message = "Fitted positions outide the input margins [{}, {}]: {} below and {} above".format(
-                vertex_min_x,
-                vertex_max_x,
-                np.sum(1 - lower_bound_ok),
-                np.sum(1 - upper_bound_ok),
+            message = (
+                f"Fitted positions outide the input margins [{vertex_min_x}, {vertex_max_x}]:"
+                f" {np.sum(1 - lower_bound_ok)} below and {np.sum(1 - upper_bound_ok)} above"
             )
         raise ValueError(message)
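
Reviewer note: a minimal usage sketch of the new `stack_axis` parameter introduced by this patch. The shapes, the random data, and the `np.moveaxis` rearrangement are illustrative assumptions, not taken from the library's tests.

# Hypothetical example: fit VU shifts by cross-correlation on a small synthetic stack.
import numpy as np
from corrct.alignment.fitting import fit_shifts_vu_xc

num_v, num_angles, num_u = 16, 8, 32
data_vwu = np.random.rand(num_v, num_angles, num_u)  # VWU stack, images stacked on axis -2
proj_vwu = np.roll(data_vwu, shift=3, axis=-1)       # reference stack, shifted along U

# Default behavior, unchanged by this patch: images are stacked along axis -2.
shifts_vu = fit_shifts_vu_xc(data_vwu, proj_vwu)
print(shifts_vu.shape)  # (2, num_angles): one (V, U) shift per image

# New in this patch: the stacking axis can be selected explicitly.
shifts_vu_axis0 = fit_shifts_vu_xc(
    np.moveaxis(data_vwu, 1, 0), np.moveaxis(proj_vwu, 1, 0), stack_axis=0
)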