Skip to content

Commit

Permalink
REFAC: removed warnings, work with latest nitorch
Browse files Browse the repository at this point in the history
  • Loading branch information
brudfors committed Dec 21, 2021
1 parent bfe2905 commit 39c70c4
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 23 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
description='UniRes: Unified Super-Resolution of Medical Imaging Data',
entry_points={'console_scripts': ['unires=unires._cli:run']},
install_requires=[
"nitorch[all]@git+https://github.com/balbasty/nitorch@0.1#egg=nitorch",
"nitorch[all]@git+https://github.com/balbasty/nitorch#ff6ab05c888325735ea344d9d924256653541700",
],
name='unires',
packages=find_packages(),
python_requires='>=3.6',
url='https://github.com/brudfors/UniRes',
version='0.0.1a',
version='0.1',
)
20 changes: 12 additions & 8 deletions unires/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _crop_y(y, sett):
mat_mu = mat_mu.mm(mat_vx)
dim_mu = mat_vx[:3, :3].inverse().mm(dim_mu[:, None]).floor().squeeze()
# Make output grid
M = mat_mu.solve(mat_y)[0].type(y[0].dat.dtype)
M = torch.linalg.solve(mat_y, mat_mu).type(y[0].dat.dtype)
grid = affine_grid(M, dim_mu)[None, ...]
# Crop
for c in range(len(y)):
Expand Down Expand Up @@ -122,9 +122,13 @@ def _estimate_hyperpar(x, sett):
mu_fg = torch.tensor(4096, device=dat.device, dtype=dat.dtype)
else:
# Get noise and foreground statistics
sd_bg, sd_fg, mu_bg, mu_fg = estimate_noise(dat, num_class=2, show_fit=sett.show_hyperpar,
fig_num=100 + cnt)
mu_bg = torch.tensor(0.0, device=dat.device, dtype=dat.dtype)
prm_noise, prm_not_noise = estimate_noise(
dat, num_class=2, show_fit=sett.show_hyperpar,fig_num=100 + cnt
)
sd_bg = prm_noise['sd']
sd_fg = prm_not_noise['sd']
mu_bg = prm_noise['mean']
mu_fg = prm_not_noise['mean']
# Set values
x[c][n].sd = sd_bg.float()
x[c][n].tau = 1 / sd_bg.float() ** 2
Expand Down Expand Up @@ -328,7 +332,7 @@ def _init_reg(x, sett):
i = 0
for c in range(len(x)):
for n in range(len(x[c])):
imgs[i][1] = imgs[i][1].solve(mat_a[i, ...])[0]
imgs[i][1] = torch.linalg.solve(mat_a[i, ...], imgs[i][1])
i += 1
_print_info('init-reg', sett, 'co', 'finished', N, t0)

Expand All @@ -343,7 +347,7 @@ def _init_reg(x, sett):
i = 0
for c in range(len(x)):
for n in range(len(x[c])):
imgs[i][1] = imgs[i][1].solve(mat_a)[0]
imgs[i][1] = torch.linalg.solve(mat_a, imgs[i][1])
i += 1

# Modify image affine (label uses the same as the image, so no need to modify that one)
Expand Down Expand Up @@ -376,7 +380,7 @@ def _init_y_dat(x, y, sett):
# Get image data
dat = x[c][n].dat[None, None, ...]
# Make output grid
mat = mat_y.solve(x[c][n].mat)[0] # mat_x\mat_y
mat = torch.linalg.solve(x[c][n].mat, mat_y) # mat_x\mat_y
grid = affine_grid(mat.type(dat.dtype), dim_y)
# Do resampling
mn = torch.min(dat)
Expand All @@ -402,7 +406,7 @@ def _init_y_label(x, y, sett):
n = 0
if x[c][n].label is not None:
# Make output grid
mat = mat_y.solve(x[c][n].mat)[0] # mat_x\mat_y
mat = torch.linalg.solve(x[c][n].mat, mat_y) # mat_x\mat_y
grid = affine_grid(mat.type(x[c][n].dat.dtype), dim_y)
# Do resampling
y[c].label = _warp_label(x[c][n].label[0], grid)
Expand Down
8 changes: 4 additions & 4 deletions unires/_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,10 @@ def _proj_apply(operator, dat, po, method='super-resolution', bound='zero', inte
dim_thick = po.dim_thick
if method == 'super-resolution':
dim = dim_yx
mat = rigid.mm(mat_yx).solve(mat_y)[0] # mat_y\rigid*mat_yx
mat = torch.linalg.solve(mat_y, rigid.mm(mat_yx)) # mat_y\rigid*mat_yx
elif method == 'denoising':
dim = dim_x
mat = rigid.mm(mat_x).solve(mat_y)[0] # mat_y\rigid*mat_x
mat = torch.linalg.solve(mat_y, rigid.mm(mat_x)) # mat_y\rigid*mat_x
# Smoothing operator
if len(ratio) == 3: # 3D
conv = lambda x: F.conv3d(x, smo_ker, stride=ratio)
Expand Down Expand Up @@ -263,7 +263,7 @@ def _proj_info(dim_y, mat_y, dim_x, mat_x, rigid=None,
po.dim_y = D_y.inverse()[:ndim, :ndim].mm(po.dim_y[..., None]).floor().squeeze()
po.vx_x = voxel_size(po.mat_x)
# Make intermediate
ratio = torch.solve(po.mat_x, po.mat_y)[0] # mat_y\mat_x
ratio = torch.linalg.solve(po.mat_y, po.mat_x) # mat_y\mat_x
ratio = (ratio[:ndim, :ndim] ** 2).sum(0).sqrt()
ratio = ratio.ceil().clamp(1) # ratio low/high >= 1
mat_yx = torch.cat((ratio, torch.ones(1, device=device, dtype=dtype))).diag()
Expand All @@ -278,7 +278,7 @@ def _proj_info(dim_y, mat_y, dim_x, mat_x, rigid=None,
po.smo_ker = smo_ker
# Add offset to intermediate space
off = torch.tensor(smo_ker.shape[-ndim:], dtype=dtype, device=device)
off = -(off - 1) // 2 # set offset
off = torch.div(-(off - 1), 2, rounding_mode='floor') # set offset
mat_off = torch.eye(ndim + 1, dtype=torch.float64, device=device)
mat_off[:ndim, -1] = off
po.dim_yx = po.dim_yx + 2 * torch.abs(off)
Expand Down
10 changes: 5 additions & 5 deletions unires/_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _update_admm(x, y, z, w, rho, tmp, obj, n_iter, sett):
cg(A=lhs, b=tmp, x=y[c].dat,
verbose=sett.cgs_verbose,
max_iter=sett.cgs_max_iter,
stop='residuals',
stop='max_gain',
inplace=True,
precond=precond,
tolerance=sett.cgs_tol) # OBS: y[c].dat is here updated in-place
Expand Down Expand Up @@ -297,7 +297,7 @@ def _update_scaling(x, y, sett, max_niter_gn=1, num_linesearch=4, verbose=0):
mat_yx = x[c][n_x].po.mat_yx
mat_y = x[c][n_x].po.mat_y
rigid = _expm(x[c][n_x].rigid_q, sett.rigid_basis)
mat = rigid.mm(mat_yx).solve(mat_y)[0] # mat_y\rigid*mat_yx
mat = torch.linalg.solve(mat_y, rigid.mm(mat_yx)) # mat_y\rigid*mat_yx
# Observed data
dat_x = x[c][n_x].dat
msk = dat_x != 0
Expand Down Expand Up @@ -494,7 +494,7 @@ def _rigid_match(dat_x, dat_y, po, tau, rigid, sett, CtC=None, diff=False, verbo
mat = mat_x

# Get grid
mat = rigid.mm(mat).solve(mat_y)[0] # mat_y\rigid*mat
mat = torch.linalg.solve(mat_y, rigid.mm(mat)) # mat_y\rigid*mat
grid = affine_grid(mat.type(torch.float32), dim, jitter=False)

# Warp y and compute spatial derivatives
Expand Down Expand Up @@ -619,7 +619,7 @@ def _update_rigid_channel(xc, yc, sett, max_niter_gn=1, num_linesearch=4,
d_rigid = d_rigid.permute((1, 2, 0)) # make compatible with old affine_basis
d_rigid_q = torch.zeros(4, 4, num_q, device=device, dtype=torch.float64)
for i in range(num_q):
d_rigid_q[:, :, i] = d_rigid[:, :, i].mm(mat).solve(po.mat_y)[0] # mat_y\d_rigid*mat
d_rigid_q[:, :, i] = torch.linalg.solve(po.mat_y, d_rigid[:, :, i].mm(mat)) # mat_y\d_rigid*mat

# Compute gradient and Hessian
gr = torch.zeros(num_q, 1, device=device, dtype=torch.float64)
Expand Down Expand Up @@ -661,7 +661,7 @@ def _update_rigid_channel(xc, yc, sett, max_niter_gn=1, num_linesearch=4,
# Hes += 1e-5*Hes.diag().max()*torch.eye(num_q, dtype=Hes.dtype, device=device)

# Compute Gauss-Newton update step
Update = gr.solve(Hes)[0][:, 0]
Update = torch.linalg.solve(Hes, gr)[:, 0]

# Do update..
old_ll = ll.clone()
Expand Down
7 changes: 4 additions & 3 deletions unires/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,6 @@ def _read_image(data, device='cpu', could_be_ct=False):
dat = dat.float()
dat = dat.to(device)
dat[~torch.isfinite(dat)] = 0
# Add some random noise
torch.manual_seed(0)
dat[dat > 0] += torch.rand_like(dat[dat > 0]) - 1 / 2
# Affine matrix
mat = data[1]
if not isinstance(mat, torch.Tensor):
Expand All @@ -184,7 +181,11 @@ def _read_image(data, device='cpu', could_be_ct=False):
direc = None
nam = None
# Get dimensions
dat = dat.squeeze()
dim = tuple(dat.shape)
if len(dim) != 3:
raise ValueError("Input image dimension required to be 3D, recieved {:}D!". \
format(len(dim)))
# CT?
if could_be_ct and _is_ct(dat):
ct = True
Expand Down
2 changes: 1 addition & 1 deletion unires/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def fit(x, y, sett):
msk_fov = torch.ones(y[c].dim, dtype=torch.bool, device=sett.device)
for n in range(len(x[c])):
# Map to voxels in low-res image
M = x[c][n].po.rigid.mm(x[c][n].mat).solve(y[c].mat)[0].inverse()
M = torch.linalg.solve(y[c].mat, x[c][n].po.rigid.mm(x[c][n].mat)).inverse()
grid = affine_grid(M.type(x[c][n].dat.dtype), y[c].dim)[None, ...]
# Mask of low-res image FOV projected into high-res space
msk_fov = msk_fov & \
Expand Down

0 comments on commit 39c70c4

Please sign in to comment.