Add an MLEM solver #20

Merged
merged 19 commits on Dec 9, 2023
corrct/solvers.py: 177 changes (172 additions, 5 deletions)
@@ -230,12 +230,14 @@ def _initialize_data_fidelity_function(data_term: Union[str, DataFidelityBase])
if isinstance(data_term, str):
if data_term.lower() == "l2":
return data_terms.DataFidelity_l2()
+ elif data_term.lower() == "kl":
+ return data_terms.DataFidelity_KL()
else:
- raise ValueError('Unknown data term: "%s", only accepted terms are: "l2".' % data_term)
+ raise ValueError(f"Unknown data term: '{data_term}', only accepted terms are: 'l2' | 'kl'.")
elif isinstance(data_term, (data_terms.DataFidelity_l2, data_terms.DataFidelity_KL)):
return cp.deepcopy(data_term)
else:
- raise ValueError('Unsupported data term: "%s", only accepted terms are "l2"-based.' % data_term.info())
+ raise ValueError(f"Unsupported data term: '{data_term.info()}', only accepted terms are 'kl' and 'l2'-based.")

@staticmethod
def _initialize_regularizer(
@@ -249,8 +251,8 @@ def _initialize_regularizer(
check_regs_ok = [isinstance(r, BaseRegularizer) for r in regularizer]
if not np.all(check_regs_ok):
raise ValueError(
- "The following regularizers are not derived from the BaseRegularizer class: %s"
- % np.array(np.arange(len(check_regs_ok))[np.array(check_regs_ok, dtype=bool)])
+ "The following regularizers are not derived from the BaseRegularizer class: "
+ f"{np.array(np.arange(len(check_regs_ok))[np.array(check_regs_ok, dtype=bool)])}"
)
else:
return list(regularizer)
@@ -551,6 +553,171 @@ def __call__( # noqa: C901
return x, info


class MLEM(Solver):
"""
Initialize the MLEM solver class.

This class implements the Maximum Likelihood Expectation Maximization (MLEM) algorithm.

Parameters
----------
verbose : bool, optional
Turn on verbose output. The default is False.
tolerance : Optional[float], optional
Tolerance on the data residual for computing when to stop iterations.
The default is None.
regularizer : Sequence[BaseRegularizer] | BaseRegularizer | None, optional
Regularizer to be used. The default is None.
data_term : Union[str, DataFidelityBase], optional
Data fidelity term for computing the data residual. The default is "l2".
data_term_test : Optional[DataFidelityBase], optional
The data fidelity to be used for the test set.
If None, it will use the same as for the rest of the data.
The default is None.
"""

def __init__(
self,
verbose: bool = False,
tolerance: Optional[float] = None,
regularizer: Union[Sequence[BaseRegularizer], BaseRegularizer, None] = None,
data_term: Union[str, DataFidelityBase] = "kl",
data_term_test: Union[str, DataFidelityBase, None] = None,
):
super().__init__(verbose=verbose, tolerance=tolerance, data_term=data_term, data_term_test=data_term_test)
self.regularizer = self._initialize_regularizer(regularizer)

def info(self) -> str:
"""
Return the MLEM info.

Returns
-------
str
info string.
"""
return Solver.info(self) + (f"(B:{self.data_term.background:g})" if self.data_term.background is not None else "")

def __call__( # noqa: C901
self,
A: operators.BaseTransform,
b: NDArrayFloat,
iterations: int,
x0: Optional[NDArrayFloat] = None,
lower_limit: Union[float, NDArrayFloat, None] = None,
upper_limit: Union[float, NDArrayFloat, None] = None,
x_mask: Optional[NDArrayFloat] = None,
b_mask: Optional[NDArrayFloat] = None,
b_test_mask: Optional[NDArrayFloat] = None,
) -> Tuple[NDArrayFloat, SolutionInfo]:
"""
Reconstruct the data, using the MLEM algorithm.

Parameters
----------
A : BaseTransform
Projection operator.
b : NDArrayFloat
Data to reconstruct.
iterations : int
Number of iterations.
x0 : Optional[NDArrayFloat], optional
Initial solution. The default is None.
lower_limit : Union[float, NDArrayFloat], optional
Lower clipping value. The default is None.
upper_limit : Union[float, NDArrayFloat], optional
Upper clipping value. The default is None.
x_mask : Optional[NDArrayFloat], optional
Solution mask. The default is None.
b_mask : Optional[NDArrayFloat], optional
Data mask. The default is None.
b_test_mask : Optional[NDArrayFloat], optional
Test data mask. The default is None.

Returns
-------
Tuple[NDArrayFloat, SolutionInfo]
The reconstruction, and the residuals.
"""
b = np.array(b)

(b_mask, b_test_mask) = self._initialize_b_masks(b, b_mask, b_test_mask)

# Back-projection diagonal re-scaling
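# tau = A^T(1) (restricted to the data mask): the sensitivity image that normalizes the multiplicative update below.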
b_ones = np.ones_like(b)
if b_mask is not None:
b_ones *= b_mask
tau = A.T(b_ones)

# Forward-projection diagonal re-scaling
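# sigma approximates 1 / |A(1)| on the solution mask; near-zero projection values are replaced by 1 before inversion to avoid division blow-ups.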
x_ones = np.ones_like(tau)
if x_mask is not None:
x_ones *= x_mask
sigma = np.abs(A(x_ones))
sigma[(sigma / np.max(sigma)) < 1e-5] = 1
sigma = 1 / sigma

if x0 is None:
x = np.ones_like(tau)
else:
x = np.array(x0).copy()
if x_mask is not None:
x *= x_mask

self.data_term.assign_data(b)

info = SolutionInfo(self.info(), max_iterations=iterations, tolerance=self.tolerance)
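# SolutionInfo gathers the iteration count and the residual histories that are returned to the caller.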

if b_test_mask is not None or self.tolerance is not None:
Ax = A(x)

if b_test_mask is not None:
if self.data_term_test.background != self.data_term.background:
print("WARNING - the data_term and and data_term_test should have the same background. Making them equal.")
self.data_term_test.background = self.data_term.background
self.data_term_test.assign_data(b)

res_test_0 = self.data_term_test.compute_residual(Ax, mask=b_test_mask)
info.residual0_cv = self.data_term_test.compute_residual_norm(res_test_0)

if self.tolerance is not None:
res_0 = self.data_term.compute_residual(Ax, mask=b_mask)
info.residual0 = self.data_term.compute_residual_norm(res_0)

reg_info = "".join(["-" + r.info().upper() for r in self.regularizer])
algo_info = "- Performing %s-%s%s iterations: " % (self.upper(), self.data_term.upper(), reg_info)

for ii in tqdm(range(iterations), desc=algo_info, disable=(not self.verbose)):
info.iterations += 1

# The MLEM update
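# Multiplicative update: x_{k+1} = (x_k / tau) * A^T( b / (A x_k + background) ),
# with the denominator clipped from below to avoid division by zero.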
Ax = A(x)

if b_test_mask is not None:
res_test = self.data_term_test.compute_residual(Ax, mask=b_test_mask)
info.residuals_cv[ii] = self.data_term_test.compute_residual_norm(res_test)

if self.tolerance is not None:
res = self.data_term.compute_residual(Ax, mask=b_mask)
info.residuals[ii] = self.data_term.compute_residual_norm(res)
if self.tolerance > info.residuals[ii]:
break

if self.data_term.background is not None:
Ax = Ax + self.data_term.background
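# eps is assumed to be the small positive constant defined at module level in solvers.py (not visible in this diff).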
Ax = Ax.clip(eps, None)

upd = A.T(b / Ax)
x *= upd / tau

if lower_limit is not None or upper_limit is not None:
x = x.clip(lower_limit, upper_limit)
if x_mask is not None:
x *= x_mask

return x, info


class SIRT(Solver):
"""
Initialize the SIRT solver class.
@@ -730,7 +897,7 @@ class PDHG(Solver):
"""
Initialize the PDHG solver class.

- PDHG stands for primal-dual hybridg gradient algorithm from Chambolle and Pock.
+ PDHG stands for primal-dual hybrid gradient algorithm from Chambolle and Pock.

Parameters
----------
examples/example_10_synthetic_phantom_MLEM_vs_SIRT.py: 125 changes (125 additions, 0 deletions)
@@ -0,0 +1,125 @@
"""
This example compares the MLEM solver against SIRT and weighted least-squares
reconstructions (implemented with the PDHG algorithm) of the Shepp-Logan phantom.

@author: Jérome Lesaint, ESRF - The European Synchrotron, Grenoble, France
"""

import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import ArrayLike
import corrct as cct


try:
import phantom
except ImportError:
cct.testing.download_phantom()
import phantom


def cm2inch(x: ArrayLike) -> tuple[float, float]:
"""Convert cm to inch.

Parameters
----------
x : ArrayLike
Sizes in cm.

Returns
-------
Tuple[float, float]
Sizes in inch.
"""
return tuple(np.array(x) / 2.54)


vol_shape = [256, 256, 3]
data_type = np.float32

ph_or = np.squeeze(phantom.modified_shepp_logan(vol_shape).astype(data_type))
ph_or = ph_or[:, :, 1]

(ph, vol_att_in, vol_att_out) = cct.testing.phantom_assign_concentration(ph_or)
# Create the noisy sinogram: Poisson noise plus a small, nearly constant background.
(sino, angles, expected_ph, background_avg) = cct.testing.create_sino(
ph, 30, psf=None, add_poisson=True, dwell_time_s=1e-2, background_avg=1e-2, background_std=1e-4
)

num_iterations = 100
lower_limit = 0
vol_mask = cct.processing.circular_mask(ph_or.shape)

sino_variance = cct.processing.compute_variance_poisson(sino)
sino_weights = cct.processing.compute_variance_weight(sino_variance)
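# These weights, derived from the estimated Poisson variance, feed the weighted least-squares data term below.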

data_term_ls = cct.solvers.DataFidelity_l2(background=background_avg)
data_term_lsw = cct.solvers.DataFidelity_wl2(sino_weights, background=background_avg)
data_term_kl = cct.solvers.DataFidelity_KL(background=background_avg)
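# Each solver is paired with a different data-fidelity term: PDHG with weighted l2, SIRT with plain l2, MLEM with Kullback-Leibler (KL).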

with cct.projectors.ProjectorUncorrected(ph.shape, angles) as A:
solver_pdhg = cct.solvers.PDHG(verbose=True, data_term=data_term_lsw)
rec_pdhg, _ = solver_pdhg(A, sino, num_iterations, x_mask=vol_mask)

solver_sirt = cct.solvers.SIRT(verbose=True, data_term=data_term_ls)
rec_sirt, _ = solver_sirt(A, sino, num_iterations, x_mask=vol_mask)

solver_mlem = cct.solvers.MLEM(verbose=True, data_term=data_term_kl)
rec_mlem, _ = solver_mlem(A, sino, num_iterations, x_mask=vol_mask)


# Reconstructions
fig = plt.figure(figsize=cm2inch([36, 24]))
gs = fig.add_gridspec(8, 3)
ax_ph = fig.add_subplot(gs[:4, 0])
im_ph = ax_ph.imshow(expected_ph, vmin=0.0, vmax=3.0)
ax_ph.set_title("Phantom")
fig.colorbar(im_ph, ax=ax_ph)

ax_sino_clean = fig.add_subplot(gs[4, 0])
with cct.projectors.ProjectorUncorrected(ph_or.shape, angles) as p:
sino_clean = p.fp(expected_ph)
im_sino_clean = ax_sino_clean.imshow(sino_clean)
ax_sino_clean.set_title("Clean sinogram")

ax_sino_noise = fig.add_subplot(gs[5, 0])
im_sino_noise = ax_sino_noise.imshow(sino - background_avg)
ax_sino_noise.set_title("Noisy sinogram")

ax_sino_lines = fig.add_subplot(gs[6:, 0])
im_sino_lines = ax_sino_lines.plot(sino[9, :] - background_avg, label="Noisy")
im_sino_lines = ax_sino_lines.plot(sino_clean[9, :], label="Clean")
ax_sino_lines.set_title("Sinograms - angle: 10")
ax_sino_lines.legend()
ax_sino_lines.grid()

ax_wls_l = fig.add_subplot(gs[:4, 1], sharex=ax_ph, sharey=ax_ph)
im_wls_l = ax_wls_l.imshow(np.squeeze(rec_pdhg), vmin=0.0, vmax=3.0)
ax_wls_l.set_title(solver_pdhg.info().upper())
fig.colorbar(im_wls_l, ax=ax_wls_l)

ax_sirt = fig.add_subplot(gs[4:, 1], sharex=ax_ph, sharey=ax_ph)
im_sirt = ax_sirt.imshow(np.squeeze(rec_sirt), vmin=0.0, vmax=3.0)
ax_sirt.set_title(solver_sirt.info().upper())
fig.colorbar(im_sirt, ax=ax_sirt)

ax_mlem = fig.add_subplot(gs[:4, 2], sharex=ax_ph, sharey=ax_ph)
im_mlem = ax_mlem.imshow(np.squeeze(rec_mlem), vmin=0.0, vmax=3.0)
ax_mlem.set_title(solver_mlem.info().upper())
fig.colorbar(im_mlem, ax=ax_mlem)

axs = fig.add_subplot(gs[4:, 2])
axs.plot(np.squeeze(expected_ph[..., 172]), label="Phantom")
axs.plot(np.squeeze(rec_pdhg[..., 172]), label=solver_pdhg.info().upper())
axs.plot(np.squeeze(rec_sirt[..., 172]), label=solver_sirt.info().upper())
axs.plot(np.squeeze(rec_mlem[..., 172]), label=solver_mlem.info().upper())
axs.grid()
axs.legend()
fig.tight_layout()

# Comparing FRCs for each reconstruction
labels = [solver_pdhg.info().upper(), solver_sirt.info().upper(), solver_mlem.info().upper()]
vols = [rec_pdhg, rec_sirt, rec_mlem]
cct.processing.post.plot_frcs([(expected_ph, rec) for rec in vols], labels=labels, snrt=0.4142)

plt.show(block=False)