Skip to content

Commit

Permalink
Alignment: added support for choosing the axis in cross-correlation s…
Browse files Browse the repository at this point in the history
…hift fitting

Signed-off-by: Nicola VIGANO <[email protected]>
  • Loading branch information
Obi-Wan committed Aug 25, 2023
1 parent a27fafd commit 01cde9c
Showing 1 changed file with 42 additions and 33 deletions.
75 changes: 42 additions & 33 deletions corrct/alignment/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
and ESRF - The European Synchrotron, Grenoble, France
"""

from collections.abc import Sequence
from typing import Optional, Union
from typing import Literal, Optional, Union, Sequence

import numpy as np
import numpy.polynomial
from numpy.polynomial import Polynomial
import scipy.ndimage as spimg
import scipy.optimize as spopt
from numpy.typing import ArrayLike, NDArray
Expand Down Expand Up @@ -69,6 +68,7 @@ def fit_shifts_vu_xc(
pad_u: bool = False,
normalize_fourier: bool = True,
use_rfft: bool = True,
stack_axis: int = -2,
decimals: int = 2,
) -> NDArrayFloat:
"""
Expand All @@ -86,6 +86,8 @@ def fit_shifts_vu_xc(
Whether to normalize the Fourier representation of the cross-correlation. The default is True.
use_rfft : bool, optional
Whether to use the `rfft` transform in place of the complex `fft` transform. The default is True.
stack_axis : int, optional
The axis along which the VU images are stacked. The default is -2.
decimals : int, optional
Decimals for the truncation of the sub-pixel The default is 2.
Expand All @@ -94,7 +96,7 @@ def fit_shifts_vu_xc(
NDArrayFloat
The VU shifts.
"""
num_angles = data_vwu.shape[-2]
num_angles = data_vwu.shape[stack_axis]

if use_rfft:
local_fftn = np.fft.rfftn
Expand All @@ -103,27 +105,36 @@ def fit_shifts_vu_xc(
local_fftn = np.fft.fftn
local_ifftn = np.fft.ifftn

fft_dims = np.delete(np.arange(-len(data_vwu.shape), 0), -2)
fft_dims = np.delete(np.arange(-len(data_vwu.shape), 0), stack_axis)
u_axis = fft_dims[-1]

old_fft_shapes = np.array(np.array(data_vwu.shape)[fft_dims], ndmin=1, dtype=int)
new_fft_shapes = old_fft_shapes.copy()
if pad_u:
new_fft_shapes[-1] *= 2
new_fft_shapes[u_axis] *= 2
cc_coords = [np.fft.fftfreq(s, 1 / s) for s in new_fft_shapes]

if len(fft_dims) == 2:
shifts_vu = np.empty((len(data_vwu.shape) - 1, num_angles))
for ii in range(num_angles):
slices = [slice(None)] * len(data_vwu.shape)
for ii_a in range(num_angles):
# For performance reasons, it is better to do the fft on each image
data_vwu_f = local_fftn(data_vwu[..., ii, :], s=list(new_fft_shapes))
proj_vwu_f = local_fftn(proj_vwu[..., ii, :], s=list(new_fft_shapes))
slices[stack_axis] = slice(ii_a, ii_a + 1)
data_vu = data_vwu[tuple(slices)].squeeze(axis=stack_axis)
if proj_vwu.shape[stack_axis] == 1:
proj_vu = proj_vwu.squeeze(axis=stack_axis)
else:
proj_vu = proj_vwu[tuple(slices)].squeeze(axis=stack_axis)
data_vwu_f = local_fftn(data_vu, s=list(new_fft_shapes))
proj_vwu_f = local_fftn(proj_vu, s=list(new_fft_shapes))

cc_f = data_vwu_f * proj_vwu_f.conj()
if normalize_fourier:
cc_f /= np.fmax(np.abs(cc_f), eps)
cc: NDArrayFloat = local_ifftn(cc_f).real
cc_r: NDArrayFloat = local_ifftn(cc_f).real

f_vals, f_coords = extract_peak_region_nd(cc, cc_coords=cc_coords)
shifts_vu[..., ii] = np.array([f_coords[0][1], f_coords[1][1]])
f_vals, f_coords = extract_peak_region_nd(cc_r, cc_coords=cc_coords)
shifts_vu[..., ii_a] = np.array([f_coords[0][1], f_coords[1][1]])

if decimals > 0:
f_vals_v = f_vals[:, 1]
Expand All @@ -132,18 +143,18 @@ def fit_shifts_vu_xc(
sub_pixel_v = refine_max_position_1d(f_vals_v, decimals=decimals)
sub_pixel_u = refine_max_position_1d(f_vals_u, decimals=decimals)

shifts_vu[..., ii] += [sub_pixel_v, sub_pixel_u]
shifts_vu[..., ii_a] += [sub_pixel_v, sub_pixel_u]
else:
data_vwu_f = local_fftn(data_vwu, s=list(new_fft_shapes), axes=list(fft_dims))
proj_vwu_f = local_fftn(proj_vwu, s=list(new_fft_shapes), axes=list(fft_dims))

ccs_f = data_vwu_f * proj_vwu_f.conj()
if normalize_fourier:
ccs_f /= np.fmax(np.abs(ccs_f).max(axis=-1, keepdims=True), eps)
ccs_f /= np.fmax(np.abs(ccs_f).max(axis=u_axis, keepdims=True), eps)
ccs = local_ifftn(ccs_f, axes=fft_dims).real

f_vals, fh = extract_peak_regions_1d(ccs, axis=-1, cc_coords=cc_coords[-1])
shifts_vu = fh[1, :]
f_vals, f_h = extract_peak_regions_1d(ccs, axis=u_axis, cc_coords=cc_coords[u_axis])
shifts_vu = f_h[1, :]

if decimals > 0:
shifts_vu += refine_max_position_1d(f_vals, decimals=decimals)
Expand Down Expand Up @@ -347,7 +358,7 @@ def extract_peak_regions_1d(


def refine_max_position_1d(
f_vals: NDArrayFloat, fx: Union[ArrayLike, NDArray, None] = None, return_vertex_val: bool = False, decimals: int = 2
f_vals: NDArrayFloat, f_x: Union[ArrayLike, NDArray, None] = None, return_vertex_val: bool = False, decimals: int = 2
) -> Union[NDArrayFloat, tuple[NDArrayFloat, NDArrayFloat]]:
"""Compute the sub-pixel max position of the given function sampling.
Expand Down Expand Up @@ -378,47 +389,45 @@ def refine_max_position_1d(
)
num_vals = f_vals.shape[0]

if fx is None:
fx_half_size = (num_vals - 1) / 2
fx = np.linspace(-fx_half_size, fx_half_size, num_vals)
if f_x is None:
f_x_half_size = (num_vals - 1) / 2
f_x = np.linspace(-f_x_half_size, f_x_half_size, num_vals)
else:
fx = np.squeeze(fx)
if not (len(fx.shape) == 1 and np.all(fx.size == num_vals)):
f_x = np.squeeze(f_x)
if not (len(f_x.shape) == 1 and np.all(f_x.size == num_vals)):
raise ValueError(
"Base coordinates should have the same length as values array. Sizes of fx: %d, f_vals: %d"
% (fx.size, num_vals)
% (f_x.size, num_vals)
)

if len(f_vals.shape) == 1:
# using Polynomial.fit, because supposed to be more numerically
# stable than previous solutions (according to numpy).
poly = np.polynomial.Polynomial.fit(fx, f_vals, deg=2)
poly = Polynomial.fit(f_x, f_vals, deg=2)
coeffs = poly.convert().coef
else:
coords = np.array([np.ones(num_vals), fx, fx**2])
coords = np.array([np.ones(num_vals), f_x, f_x**2])
coeffs = np.linalg.lstsq(coords.T, f_vals, rcond=None)[0]

# For a 1D parabola `f(x) = c + bx + ax^2`, the vertex position is:
# x_v = -b / 2a.
vertex_x = -coeffs[1, ...] / (2 * coeffs[2, ...])
vertex_x = np.around(vertex_x, decimals=decimals)

vertex_min_x = np.min(fx)
vertex_max_x = np.max(fx)
vertex_min_x = np.min(f_x)
vertex_max_x = np.max(f_x)
lower_bound_ok = vertex_min_x < vertex_x
upper_bound_ok = vertex_x < vertex_max_x
if not np.all(lower_bound_ok * upper_bound_ok):
if len(f_vals.shape) == 1:
message = (
f"Fitted position {vertex_x} is outide the input margins [{vertex_min_x}, {vertex_max_x}]."
+ f" Input values: {f_vals}"
f" Input values: {f_vals}"
)
else:
message = "Fitted positions outide the input margins [{}, {}]: {} below and {} above".format(
vertex_min_x,
vertex_max_x,
np.sum(1 - lower_bound_ok),
np.sum(1 - upper_bound_ok),
message = (
f"Fitted positions outide the input margins [{vertex_min_x}, {vertex_max_x}]:"
f" {np.sum(1 - lower_bound_ok)} below and {np.sum(1 - upper_bound_ok)} above"
)
raise ValueError(message)

Expand Down

0 comments on commit 01cde9c

Please sign in to comment.