Commit 9451925: cursed BSREM with accumulated gradient and weird preconditioner

1 parent: 1df9cac. 2 changed files with 177 additions and 0 deletions.
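The algorithm below is a BSREM-flavoured preconditioned gradient ascent: subset gradients are averaged over a window that grows on a fixed schedule (the "accumulated gradient"), and the step is scaled elementwise by a diagonal preconditioner built from the sensitivity image and the diagonal of the relative difference prior's (RDP) Hessian (the "weird preconditioner"). In the notation of the code, one update is

```math
x_{k+1} = \left[\, x_k + \bar g_k / D \,\right]_+ , \qquad
\bar g_k = \frac{1}{|S_k|} \sum_{s \in S_k} \nabla f_s(x_k), \qquad
D = \kappa + \operatorname{diag}\!\big(\nabla^2 R(x)\big) + 10^{-9},
```

where the `f_s` are the subset objective functions (each carrying a `1/num_subsets` share of the prior), `S_k` is the list of subsets visited at iteration `k`, the division by `D` is elementwise, and the outer brackets denote projection onto nonnegative images. `D` is computed once from the OSEM image and refreshed at the iterations listed in `update_rdp_diag_hess_iter`.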
`cursed_BSREM.py` (new file, +127 lines):

```python
#
# SPDX-License-Identifier: Apache-2.0
#
# Classes implementing the BSREM algorithm in sirf.STIR
#
# Authors: Kris Thielemans
#
# Copyright 2024 University College London

import sirf.STIR as STIR
from sirf.Utilities import examples_data_path  # unused in this file

from cil.optimisation.algorithms import Algorithm
from herman_meyer import herman_meyer_order
import numpy as np
import torch


class RDPDiagHessTorch:
    """Diagonal of the relative difference prior's Hessian, computed with torch on the GPU."""

    def __init__(self, rdp_diag_hess, prior):
        self.epsilon = prior.get_epsilon()
        self.gamma = prior.get_gamma()
        self.penalty_strength = prior.get_penalisation_factor()
        self.rdp_diag_hess = rdp_diag_hess  # sirf.STIR image, reused as the output template
        self.kappa = torch.tensor(prior.get_kappa().as_array()).cuda()
        self.kappa_padded = torch.nn.functional.pad(self.kappa[None], pad=(1, 1, 1, 1, 1, 1), mode='replicate')[0]
        voxel_sizes = rdp_diag_hess.voxel_sizes()
        self.z_dim, self.y_dim, self.x_dim = rdp_diag_hess.shape
        # inverse-distance neighbour weights over the 3x3x3 neighbourhood
        self.weights = torch.zeros([3, 3, 3]).cuda()
        for i in range(3):
            for j in range(3):
                for k in range(3):
                    self.weights[i, j, k] = voxel_sizes[2] / np.sqrt(((i - 1) * voxel_sizes[0])**2
                                                                     + ((j - 1) * voxel_sizes[1])**2
                                                                     + ((k - 1) * voxel_sizes[2])**2)
        # the centre offset divides by zero above; zero its weight explicitly
        self.weights[1, 1, 1] = 0
```
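The neighbour weights are the usual RDP inverse-distance weights, here normalised by the x voxel size: for each offset delta in {-1, 0, 1}^3 and voxel sizes (Delta_z, Delta_y, Delta_x),

```math
w_\delta = \frac{\Delta_x}{\sqrt{(\delta_z \Delta_z)^2 + (\delta_y \Delta_y)^2 + (\delta_x \Delta_x)^2}}, \qquad w_{(0,0,0)} = 0.
```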
```python
    # RDPDiagHessTorch, continued
    def compute(self, x):
        """Fill and return self.rdp_diag_hess with the diagonal Hessian of the RDP at image x."""
        x = torch.tensor(x.as_array(), dtype=torch.float32).cuda()
        x_padded = torch.nn.functional.pad(x[None], pad=(1, 1, 1, 1, 1, 1), mode='replicate')[0]
        x_rdp_diag_hess = torch.zeros_like(x)
        # loop over the 3x3x3 neighbour offsets (the centre contributes nothing via its zero weight)
        for dz in range(3):
            for dy in range(3):
                for dx in range(3):
                    x_neighbour = x_padded[dz:dz + self.z_dim, dy:dy + self.y_dim, dx:dx + self.x_dim]
                    kappa_neighbour = self.kappa_padded[dz:dz + self.z_dim, dy:dy + self.y_dim, dx:dx + self.x_dim]
                    kappa_val = self.kappa * kappa_neighbour
                    numerator = 4 * (2 * x_neighbour + self.epsilon) ** 2
                    denominator = (x + x_neighbour + self.gamma * torch.abs(x - x_neighbour) + self.epsilon) ** 3
                    x_rdp_diag_hess += self.weights[dz, dy, dx] * self.penalty_strength * kappa_val * numerator / denominator
        return self.rdp_diag_hess.fill(x_rdp_diag_hess.cpu().numpy())
```
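This is a closed-form diagonal Hessian for STIR's relative difference prior,

```math
R(x) = \sum_j \sum_{k \in N_j} w_{jk}\, \kappa_j \kappa_k\, \frac{(x_j - x_k)^2}{x_j + x_k + \gamma\,|x_j - x_k| + \epsilon}.
```

Differentiating one pairwise term twice with respect to `x_j` gives `2 (2 x_k + eps)^2 / s^3` with `s = x_j + x_k + gamma |x_j - x_k| + eps`, and since each unordered pair {j, k} appears twice in the double sum, the diagonal entry is

```math
\big[\nabla^2 R(x)\big]_{jj} = \sum_{k \in N_j} w_{jk}\, \kappa_j \kappa_k\, \frac{4\,(2 x_k + \epsilon)^2}{(x_j + x_k + \gamma\,|x_j - x_k| + \epsilon)^3},
```

which is exactly the `numerator / denominator` expression in the loop.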
```python
class BSREMSkeleton(Algorithm):
    """Common machinery: initial image, preconditioner and subset bookkeeping.

    NB: relies on the subclass having set self.dataset (a petric Dataset) before
    this constructor runs.
    """

    def __init__(self, data,
                 update_filter=STIR.TruncateToCylinderProcessor(), **kwargs):
        super().__init__(**kwargs)
        initial = self.dataset.OSEM_image
        self.x = initial.copy()
        self.data = data
        self.num_subsets = len(data)
        self.g = initial.get_uniform_copy(0)
        # preconditioner: kappa + diagonal RDP Hessian (see the maths above)
        self.rdp_diag_hess_obj = RDPDiagHessTorch(self.dataset.OSEM_image.copy(), self.dataset.prior)
        self.rdp_diag_hess = self.rdp_diag_hess_obj.compute(self.x)
        self.precond = self.dataset.kappa + self.rdp_diag_hess + 1e-9
        self.x_update = initial.get_uniform_copy(0)
        self.eps = initial.max() / 1e3
        self.subset = 0
        self.update_filter = update_filter
        self.subset_order = herman_meyer_order(self.num_subsets)
        self.configured = True
        print("Configured cursed_BSREM")
```
```python
class cursed_BSREM(BSREMSkeleton):
    def __init__(self, data, obj_funs, accumulate_gradient_iter, accumulate_gradient_num,
                 update_rdp_diag_hess_iter, **kwargs):
        self.obj_funs = obj_funs
        super().__init__(data, **kwargs)
        self.update_rdp_diag_hess_iter = update_rdp_diag_hess_iter
        # e.g. accumulate_gradient_iter=[10, 15, 20] with accumulate_gradient_num=[1, 10, 20]
        self.accumulate_gradient_iter = accumulate_gradient_iter
        self.accumulate_gradient_num = accumulate_gradient_num
        # the iteration boundaries must be strictly increasing ...
        assert all(self.accumulate_gradient_iter[i] < self.accumulate_gradient_iter[i + 1]
                   for i in range(len(self.accumulate_gradient_iter) - 1))
        # ... and each boundary needs a matching gradient count
        assert len(self.accumulate_gradient_iter) == len(self.accumulate_gradient_num)

    def get_number_of_subsets_to_accumulate_gradient(self):
        """How many subset gradients to average at the current iteration."""
        for index, boundary in enumerate(self.accumulate_gradient_iter):
            if self.iteration < boundary:
                return self.accumulate_gradient_num[index]
        return self.num_subsets
```
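With the defaults used by `Submission` below, the schedule averages more and more subset gradients as the iterations progress. A standalone worked example (the names here are illustrative, not part of the commit):

```python
def subsets_to_accumulate(iteration, boundaries=(10, 15, 20), nums=(1, 10, 20), num_subsets=7):
    # mirrors get_number_of_subsets_to_accumulate_gradient above
    for boundary, num in zip(boundaries, nums):
        if iteration < boundary:
            return num
    return num_subsets

# iterations 0-9 average 1 gradient, 10-14 average 10, 15-19 average 20, then all 7 subsets
assert [subsets_to_accumulate(i) for i in (0, 9, 10, 14, 15, 19, 20)] == [1, 1, 10, 10, 20, 20, 7]
```

Note that the count may exceed `num_subsets` (20 > 7 here), in which case `update` below wraps around and visits some subsets more than once within a single update.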
```python
    # cursed_BSREM, continued
    def update(self):
        num_to_accumulate = self.get_number_of_subsets_to_accumulate_gradient()
        # accumulate, then average, the gradients of the next num_to_accumulate subsets
        for i in range(num_to_accumulate):
            if i == 0:
                self.g = self.obj_funs[self.subset_order[self.subset]].gradient(self.x)
            else:
                self.g += self.obj_funs[self.subset_order[self.subset]].gradient(self.x)
            self.subset = (self.subset + 1) % self.num_subsets
        self.g /= num_to_accumulate
        # optionally refresh the RDP diagonal-Hessian part of the preconditioner
        if self.iteration in self.update_rdp_diag_hess_iter:
            self.rdp_diag_hess = self.rdp_diag_hess_obj.compute(self.x)
            self.precond = self.dataset.kappa + self.rdp_diag_hess + 1e-9
        self.x_update = self.g / self.precond
        if self.update_filter is not None:
            self.update_filter.apply(self.x_update)
        self.x += self.x_update
        self.x.maximum(0, out=self.x)  # project onto nonnegative images

    def update_objective(self):
        # required by current CIL (needs to append to self.loss)
        self.loss.append(self.objective_function(self.x))

    def objective_function(self, x):
        '''value of the objective function summed over all subsets'''
        v = 0
        for s in range(len(self.data)):
            v += self.obj_funs[s](x)
        return v
```
The second file (+50 lines) wires `cursed_BSREM` into the PETRIC submission interface:

```python
from cil.optimisation.algorithms import Algorithm
from cil.optimisation.utilities import callbacks
from petric import Dataset
from cursed_BSREM import cursed_BSREM
from sirf.contrib.partitioner import partitioner

assert issubclass(cursed_BSREM, Algorithm)


class MaxIteration(callbacks.Callback):
    """Terminate the run once a fixed iteration count is reached (CIL stops on StopIteration)."""

    def __init__(self, max_iteration: int, verbose: int = 1):
        super().__init__(verbose)
        self.max_iteration = max_iteration

    def __call__(self, algorithm: Algorithm):
        if algorithm.iteration >= self.max_iteration:
            raise StopIteration


class Submission(cursed_BSREM):
    # note that `issubclass(Submission, Algorithm) == True`
    def __init__(self, data: Dataset,
                 num_subsets: int = 7,
                 update_objective_interval: int = 10,
                 **kwargs):
        mode = kwargs.get("mode", "staggered")
        accumulate_gradient_iter = kwargs.get("accumulate_gradient_iter", [10, 15, 20])
        accumulate_gradient_num = kwargs.get("accumulate_gradient_num", [1, 10, 20])
        update_rdp_diag_hess_iter = kwargs.get("update_rdp_diag_hess_iter", [10])

        data_sub, _, obj_funs = partitioner.data_partition(data.acquired_data, data.additive_term,
                                                           data.mult_factors, num_subsets,
                                                           initial_image=data.OSEM_image,
                                                           mode=mode)
        self.dataset = data  # BSREMSkeleton reads self.dataset, so set it before super().__init__
        # WARNING: modifies prior strength with 1/num_subsets (as currently needed for BSREM implementations)
        data.prior.set_penalisation_factor(data.prior.get_penalisation_factor() / len(obj_funs))
        data.prior.set_up(data.OSEM_image)
        for f in obj_funs:  # add prior evenly to every objective function
            f.set_prior(data.prior)
```
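The 1/num_subsets rescaling keeps the total objective unchanged: every subset function carries the full prior object, so with the original penalisation factor beta, S subsets, and STIR's convention of subtracting the penalty from the log-likelihood,

```math
\sum_{s=1}^{S} \Big( f_s(x) - \tfrac{\beta}{S}\, R(x) \Big) = f(x) - \beta\, R(x),
```

i.e. summing the subset objectives (as `objective_function` does) recovers the likelihood minus the originally requested penalty.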
```python
        # Submission.__init__, continued
        super().__init__(data_sub,
                         obj_funs,
                         accumulate_gradient_iter,
                         accumulate_gradient_num,
                         update_rdp_diag_hess_iter,
                         update_objective_interval=update_objective_interval)


submission_callbacks = []
```
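For context, a hypothetical driver using CIL's `Algorithm.run` (not part of the commit; the PETRIC harness normally constructs and runs `Submission` itself, and `load_dataset` below is a stand-in for however the `Dataset` is obtained):

```python
from petric import Dataset

data: Dataset = load_dataset()  # placeholder: supplied by the challenge harness

algo = Submission(data, num_subsets=7, update_objective_interval=10)
algo.run(30, callbacks=submission_callbacks + [MaxIteration(30)])
print(algo.loss[-1])  # latest objective value, summed over all subsets
```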