cursed BSREM with accumulated gradient and weird preconditioner
Imraj-Singh committed Aug 9, 2024
1 parent 1df9cac commit 9451925
Showing 2 changed files with 177 additions and 0 deletions.
127 changes: 127 additions & 0 deletions cursed_BSREM.py
@@ -0,0 +1,127 @@
#
# SPDX-License-Identifier: Apache-2.0
#
# Classes implementing the BSREM algorithm in sirf.STIR
#
# Authors: Kris Thielemans
#
# Copyright 2024 University College London


import sirf.STIR as STIR
from sirf.Utilities import examples_data_path

from cil.optimisation.algorithms import Algorithm
from herman_meyer import herman_meyer_order
import numpy as np
import torch


class RDPDiagHessTorch:
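    """Diagonal of the relative difference prior (RDP) Hessian, evaluated with PyTorch on the GPU.

    Precomputes the 3x3x3 neighbourhood weights and caches kappa once, so that
    `compute` only has to slide over the 27 neighbour offsets.
    """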
def __init__(self, rdp_diag_hess, prior):
self.epsilon = prior.get_epsilon()
self.gamma = prior.get_gamma()
self.penalty_strength = prior.get_penalisation_factor()
self.rdp_diag_hess = rdp_diag_hess
self.weights = torch.zeros([3,3,3]).cuda()
self.kappa = torch.tensor(prior.get_kappa().as_array()).cuda()
self.kappa_padded = torch.nn.functional.pad(self.kappa[None], pad=(1, 1, 1, 1, 1, 1), mode='replicate')[0]
voxel_sizes = rdp_diag_hess.voxel_sizes()
z_dim, y_dim, x_dim = rdp_diag_hess.shape
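        # Inverse-Euclidean-distance weights for the 26 neighbours, scaled by the
        # axial voxel size; the centre entry divides by zero and is zeroed just below.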
for i in range(3):
for j in range(3):
for k in range(3):
self.weights[i,j,k] = voxel_sizes[2]/np.sqrt(((i-1)*voxel_sizes[0])**2 + ((j-1)*voxel_sizes[1])**2 + ((k-1)*voxel_sizes[2])**2)
self.weights[1,1,1] = 0
self.z_dim = z_dim
self.y_dim = y_dim
self.x_dim = x_dim

def compute(self, x):
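        """Return the diagonal of the RDP Hessian at `x`, filled into a SIRF image.

        For a voxel pair the RDP potential is
            psi(x_j, x_k) = (x_j - x_k)^2 / (x_j + x_k + gamma*|x_j - x_k| + epsilon),
        and its second derivative w.r.t. x_j works out to
            2*(2*x_k + epsilon)^2 / (x_j + x_k + gamma*|x_j - x_k| + epsilon)^3;
        the factor 4 used below accounts for both ordered pairs (j, k) and (k, j)
        contributing to the diagonal.
        """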
x = torch.tensor(x.as_array(), dtype=torch.float32).cuda()
x_padded = torch.nn.functional.pad(x[None], pad=(1, 1, 1, 1, 1, 1), mode='replicate')[0]
x_rdp_diag_hess = torch.zeros_like(x)
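        # Slide over the 27 neighbour offsets; each slice of the padded tensors
        # aligns the (dz-1, dy-1, dx-1) neighbour with every voxel at once.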
for dz in range(3):
for dy in range(3):
for dx in range(3):
x_neighbour = x_padded[dz:dz+self.z_dim, dy:dy+self.y_dim, dx:dx+self.x_dim]
kappa_neighbour = self.kappa_padded[dz:dz+self.z_dim, dy:dy+self.y_dim, dx:dx+self.x_dim]
kappa_val = self.kappa * kappa_neighbour
numerator = 4 * (2 * x_neighbour + self.epsilon) ** 2
denominator = (x + x_neighbour + self.gamma * torch.abs(x - x_neighbour) + self.epsilon) ** 3
x_rdp_diag_hess += self.weights[dz, dy, dx] * self.penalty_strength * kappa_val * numerator / denominator
return self.rdp_diag_hess.fill(x_rdp_diag_hess.cpu().numpy())


class BSREMSkeleton(Algorithm):
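    """Common state for the BSREM variants: current image, subset bookkeeping and preconditioner.

    Subclasses must set `self.dataset` (a PETRIC Dataset) before calling this constructor.
    """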
def __init__(self, data,
update_filter=STIR.TruncateToCylinderProcessor(), **kwargs):
super().__init__(**kwargs)
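        # Relies on the subclass having set `self.dataset` before calling up.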
initial = self.dataset.OSEM_image
self.x = initial.copy()
self.data = data
self.num_subsets = len(data)
        self.g = initial.get_uniform_copy(0)
self.rdp_diag_hess_obj = RDPDiagHessTorch(self.dataset.OSEM_image.copy(), self.dataset.prior)
self.rdp_diag_hess = self.rdp_diag_hess_obj.compute(self.x)
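        # The "weird" preconditioner: kappa image plus the diagonal RDP Hessian,
        # with a small constant to keep the division well defined.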
self.precond = self.dataset.kappa + self.rdp_diag_hess + 1e-9
self.x_update = initial.get_uniform_copy(0)
self.eps = initial.max()/1e3
self.subset = 0
self.update_filter = update_filter
self.subset_order = herman_meyer_order(self.num_subsets)
self.configured = True
print("Configured cursed_BSREM")

class cursed_BSREM(BSREMSkeleton):
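    """BSREM variant with gradient accumulation and an RDP-Hessian-based preconditioner.

    `accumulate_gradient_iter` lists iteration boundaries, `accumulate_gradient_num`
    how many subset gradients to average before each boundary, and
    `update_rdp_diag_hess_iter` the iterations at which the diagonal RDP Hessian
    in the preconditioner is recomputed.
    """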
def __init__(self, data, obj_funs, accumulate_gradient_iter, accumulate_gradient_num, update_rdp_diag_hess_iter, **kwargs):

self.obj_funs = obj_funs

super().__init__(data, **kwargs)
self.update_rdp_diag_hess_iter = update_rdp_diag_hess_iter
self.accumulate_gradient_iter = accumulate_gradient_iter
self.accumulate_gradient_num = accumulate_gradient_num
        # e.g. accumulate_gradient_iter = [10, 15, 20] with
        # accumulate_gradient_num = [1, 10, 20]: average 1 subset gradient before
        # iteration 10, 10 before iteration 15, 20 before iteration 20, and all
        # subsets from iteration 20 onwards.
        # The iteration boundaries must be strictly increasing
        assert all(self.accumulate_gradient_iter[i] < self.accumulate_gradient_iter[i + 1]
                   for i in range(len(self.accumulate_gradient_iter) - 1)), \
            "accumulate_gradient_iter must be strictly increasing"
        # One accumulation count per boundary
        assert len(self.accumulate_gradient_iter) == len(self.accumulate_gradient_num), \
            "accumulate_gradient_iter and accumulate_gradient_num must have the same length"

def get_number_of_subsets_to_accumulate_gradient(self):
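        """Return how many subset gradients to average at the current iteration."""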
for index, boundary in enumerate(self.accumulate_gradient_iter):
if self.iteration < boundary:
return self.accumulate_gradient_num[index]
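        # past the last boundary: average over all subsets (a full gradient)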
return self.num_subsets

def update(self):
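        """One iteration: accumulated subset gradient, preconditioned step, non-negativity projection."""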
num_to_accumulate = self.get_number_of_subsets_to_accumulate_gradient()
for i in range(num_to_accumulate):
if i == 0:
self.g = self.obj_funs[self.subset_order[self.subset]].gradient(self.x)
else:
self.g += self.obj_funs[self.subset_order[self.subset]].gradient(self.x)
self.subset = (self.subset + 1) % self.num_subsets
#print(f"\n Added subset {i+1} (i.e. {self.subset}) of {num_to_accumulate}\n")
self.g /= num_to_accumulate
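        # Occasionally refresh the diagonal RDP Hessian used in the preconditioner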
if self.iteration in self.update_rdp_diag_hess_iter:
self.rdp_diag_hess = self.rdp_diag_hess_obj.compute(self.x)
self.precond = self.dataset.kappa + self.rdp_diag_hess + 1e-9
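        # Preconditioned ascent step, filtered and projected onto x >= 0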
self.x_update = self.g / self.precond
if self.update_filter is not None:
self.update_filter.apply(self.x_update)
self.x += self.x_update
self.x.maximum(0, out=self.x)

def update_objective(self):
# required for current CIL (needs to set self.loss)
self.loss.append(self.objective_function(self.x))

def objective_function(self, x):
''' value of objective function summed over all subsets '''
v = 0
for s in range(len(self.data)):
v += self.obj_funs[s](x)
return v
50 changes: 50 additions & 0 deletions main_cursed_BSREM.py
@@ -0,0 +1,50 @@

from cil.optimisation.algorithms import Algorithm
from cil.optimisation.utilities import callbacks
from petric import Dataset
from cursed_BSREM import cursed_BSREM
from sirf.contrib.partitioner import partitioner

assert issubclass(cursed_BSREM, Algorithm)

class MaxIteration(callbacks.Callback):
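    """Stop the CIL run loop by raising StopIteration once `max_iteration` is reached."""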
def __init__(self, max_iteration: int, verbose: int = 1):
super().__init__(verbose)
self.max_iteration = max_iteration

def __call__(self, algorithm: Algorithm):
if algorithm.iteration >= self.max_iteration:
raise StopIteration


class Submission(cursed_BSREM):
    # note that `issubclass(cursed_BSREM, Algorithm) == True`
def __init__(self, data: Dataset,
num_subsets: int = 7,
update_objective_interval: int = 10,
**kwargs):

mode = kwargs.get("mode", "staggered")
accumulate_gradient_iter = kwargs.get("accumulate_gradient_iter", [10, 15, 20])
accumulate_gradient_num = kwargs.get("accumulate_gradient_num", [1, 10, 20])
update_rdp_diag_hess_iter = kwargs.get("update_rdp_diag_hess_iter", [10])

data_sub, _, obj_funs = partitioner.data_partition(data.acquired_data, data.additive_term,
data.mult_factors, num_subsets,
initial_image=data.OSEM_image,
mode=mode)
self.dataset = data
# WARNING: modifies prior strength with 1/num_subsets (as currently needed for BSREM implementations)
data.prior.set_penalisation_factor(data.prior.get_penalisation_factor() / len(obj_funs))
data.prior.set_up(data.OSEM_image)
for f in obj_funs: # add prior evenly to every objective function
f.set_prior(data.prior)

super().__init__(data_sub,
obj_funs,
accumulate_gradient_iter,
accumulate_gradient_num,
update_rdp_diag_hess_iter,
update_objective_interval=update_objective_interval)

submission_callbacks = []

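For reference, a minimal sketch of how this submission could be driven with CIL's run loop (not part of the commit; it assumes a PETRIC `Dataset` instance named `data` has already been loaded and that your CIL version accepts a `callbacks` argument to `run`):

    # hypothetical driver, for illustration only
    algo = Submission(data, num_subsets=7, update_objective_interval=10)
    algo.run(100, callbacks=[MaxIteration(100)] + submission_callbacks)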