Skip to content

Commit

Permalink
fix: asuper-resolution ow works when given multiple repeats
Browse files Browse the repository at this point in the history
  • Loading branch information
brudfors committed Aug 10, 2021
1 parent d552ba3 commit fdeee7b
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 20 deletions.
5 changes: 4 additions & 1 deletion unires/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,10 @@ def _resample_inplane(x, sett):
# make grid
D = I.clone()
for i in range(3):
D[i, i] = sett.vx[i] / vx_x[i]
if isinstance(sett.vx, (list, tuple)):
D[i, i] = sett.vx[i] / vx_x[i]
else:
D[i, i] = sett.vx / vx_x[i]
if D[i, i] < 1.0:
D[i, i] = 1
if float((I - D).abs().sum()) < 1e-4:
Expand Down
31 changes: 17 additions & 14 deletions unires/_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,30 +67,33 @@ def _proj(operator, dat, x, y, method='super-resolution', do=True,
diff (str, optional): Gradient difference operator, defaults to 'forward'.
Returns:
dat (torch.tensor()): Projected image data (dim_y|dim_x).
dat_p (torch.tensor()): Projected image data (dim_y|dim_x).
"""
if operator == 'AtA':
# AtA
# (dat = y)
if not do: # return dat
operator = 'none'
dat1 = rho * y.lam ** 2 * _DtD(dat, vx_y=vx_y, bound=bound, diff=diff)
operator = 'none'
dat = dat[None, None, ...]
dat = x[n].tau * _proj_apply(operator, dat, x[n].po, method=method,
# sum likelihood terms
dat_p = x[n].tau * _proj_apply(operator, dat, x[n].po, method=method,
bound=bound, interpolation=interpolation)
for n1 in range(1, len(x)):
dat = dat + x[n1].tau * _proj_apply(operator, dat, x[n1].po, method=method,
bound=bound, interpolation=interpolation)
dat = dat[0, 0, ...]
dat += dat1
else: # A, At
dat_p += x[n1].tau * _proj_apply(operator, dat, x[n1].po, method=method,
bound=bound, interpolation=interpolation)
dat_p = dat_p[0, 0, ...]
# add prior term
dat_p += rho * y.lam ** 2 * _DtD(dat[0, 0, ...], vx_y=vx_y, bound=bound, diff=diff)
else:
# A, At
# (dat = x or y)
if not do: # return dat
operator = 'none'
dat = dat[None, None, ...]
dat = _proj_apply(operator, dat, x[n].po, method=method,
bound=bound, interpolation=interpolation)
dat = dat[0, 0, ...]
dat_p = _proj_apply(operator, dat[None, None, ...], x[n].po, method=method,
bound=bound, interpolation=interpolation)[0, 0, ...]

return dat
return dat_p


def _proj_apply(operator, dat, po, method='super-resolution', bound='zero', interpolation='linear'):
Expand Down
8 changes: 4 additions & 4 deletions unires/_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _update_admm(x, y, z, w, rho, tmp, obj, n_iter, sett):
t0 = _print_info('fit-update', sett, 'y', n_iter) # PRINT
for c in range(len(x)): # Loop over channels
# RHS
tmp[:] = 0
tmp[:] = 0.0
for n in range(len(x[c])): # Loop over observations of channel 'c'
# _ = _print_info('int', sett, n) # PRINT
tmp += x[c][n].tau * _proj('At', x[c][n].dat, x[c], y[c], method=sett.method, do=sett.do_proj,
Expand Down Expand Up @@ -409,19 +409,19 @@ def _compute_nll(x, y, sett, rho, sum_dtype=torch.float64):
vx_y = voxel_size(y[0].mat).float()
nll_xy = torch.tensor(0, device=sett.device, dtype=torch.float64)
for c in range(len(x)):
# Neg. log-likelihood term
# Sum neg. log-likelihood term
for n in range(len(x[c])):
msk = x[c][n].dat != 0
Ay = _proj('A', y[c].dat, x[c], y[c],
n=n, method=sett.method, do=sett.do_proj, bound=sett.bound, interpolation=sett.interpolation)
nll_xy += 0.5 * x[c][n].tau * torch.sum((x[c][n].dat[msk] - Ay[msk]) ** 2, dtype=sum_dtype)
# Neg. log-prior term
# Sum neg. log-prior term
Dy = y[c].lam * im_gradient(y[c].dat, vx=vx_y, bound=sett.bound, which=sett.diff)
if c > 0:
nll_y += torch.sum(Dy ** 2, dim=0)
else:
nll_y = torch.sum(Dy ** 2, dim=0)

# Neg. log-prior term
nll_y = torch.sum(torch.sqrt(nll_y), dtype=sum_dtype)

return nll_xy + nll_y, nll_xy, nll_y
Expand Down
2 changes: 1 addition & 1 deletion unires/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(self):
self.do_print: int = 1 # Print progress to terminal (0, 1, 2, 3)
self.do_proj: bool = None # Use projection matrices, defined in format_output()
self.do_res_origin: bool = False # Resets origin, if CT data
self.force_inplane_res: bool = True # Force in-plane resolution of observed data to be greater or equal to recon vx
self.force_inplane_res: bool = False # Force in-plane resolution of observed data to be greater or equal to recon vx
self.fov: str = 'brain' # If crop=True, uses this field-of-view ('brain'|'head').
self.gap: float = 0.0 # Slice gap, between 0 and 1
self.interpolation: str = 'linear' # Interpolation order (see nitorch.spatial)
Expand Down

0 comments on commit fdeee7b

Please sign in to comment.