Skip to content

Commit

Permalink
BSREM bb
Browse files Browse the repository at this point in the history
  • Loading branch information
alexdenker committed Sep 4, 2024
1 parent a7967fc commit d322121
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
23 changes: 20 additions & 3 deletions BSREM.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# Copyright 2024 University College London

import numpy
import numpy as np
import sirf.STIR as STIR
from sirf.Utilities import examples_data_path

Expand Down Expand Up @@ -59,6 +60,9 @@ def __init__(self, data, initial, initial_step_size, relaxation_eta,

self.alpha = None

self.x_prev = None
self.x_update_prev = None

def subset_sensitivity(self, subset_num):
raise NotImplementedError

Expand All @@ -74,12 +78,25 @@ def step_size(self):
def update(self):
g = self.subset_gradient(self.x, self.subset_order[self.subset])

self.alpha = self.step_size()
self.x_update = (self.x + self.eps) * g / self.average_sensitivity * self.alpha
self.x_update = (self.x + self.eps) * g / self.average_sensitivity

if self.iteration == 0:
self.alpha = self.initial_step_size
else:
delta_x = self.x - self.x_prev
delta_g = self.x_update_prev - self.x_update

self.alpha = delta_x.norm()**2 / (delta_x * delta_g).sum()

if self.update_filter is not None:
self.update_filter.apply(self.x_update)
self.x += self.x_update

self.x_prev = self.x.clone()
self.x_update_prev = self.x_update.clone()

#print("Use step size: ", self.alpha)
#print(self.x_update)
self.x += self.x_update * self.alpha
# threshold to non-negative
self.x.maximum(0, out=self.x)
self.subset = (self.subset + 1) % self.num_subsets
Expand Down
4 changes: 2 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ def apply(self, algorithm, gradient, out=None):
class Submission(BSREM):
# note that `issubclass(BSREM1, Algorithm) == True`
def __init__(self, data,
num_subsets: int = 7,
update_objective_interval: int = 10,
preconditioner = "osem",
**kwargs):
Expand All @@ -116,7 +115,8 @@ def __init__(self, data,
This is just an example. Try to modify and improve it!
"""
mode = kwargs.get("mode", "sequential")
num_subsets = 1
mode = kwargs.get("num_subsets", 1)

data_sub, acq_models, obj_funs = partitioner.data_partition(data.acquired_data, data.additive_term,
data.mult_factors, num_subsets,
initial_image=data.OSEM_image,
Expand Down

0 comments on commit d322121

Please sign in to comment.