From 39c70c474f5e94175d358dafb8693f7224a9c51d Mon Sep 17 00:00:00 2001 From: Mikael Brudfors Date: Tue, 21 Dec 2021 15:03:33 +0000 Subject: [PATCH] REFAC: removed warnings, work with latest nitorch --- setup.py | 4 ++-- unires/_core.py | 20 ++++++++++++-------- unires/_project.py | 8 ++++---- unires/_update.py | 10 +++++----- unires/_util.py | 7 ++++--- unires/run.py | 2 +- 6 files changed, 28 insertions(+), 23 deletions(-) diff --git a/setup.py b/setup.py index d123eda..2e129e0 100644 --- a/setup.py +++ b/setup.py @@ -7,11 +7,11 @@ description='UniRes: Unified Super-Resolution of Medical Imaging Data', entry_points={'console_scripts': ['unires=unires._cli:run']}, install_requires=[ - "nitorch[all]@git+https://github.com/balbasty/nitorch@0.1#egg=nitorch", + "nitorch[all]@git+https://github.com/balbasty/nitorch#ff6ab05c888325735ea344d9d924256653541700", ], name='unires', packages=find_packages(), python_requires='>=3.6', url='https://github.com/brudfors/UniRes', - version='0.0.1a', + version='0.1', ) diff --git a/unires/_core.py b/unires/_core.py index b675f9c..71fda72 100644 --- a/unires/_core.py +++ b/unires/_core.py @@ -75,7 +75,7 @@ def _crop_y(y, sett): mat_mu = mat_mu.mm(mat_vx) dim_mu = mat_vx[:3, :3].inverse().mm(dim_mu[:, None]).floor().squeeze() # Make output grid - M = mat_mu.solve(mat_y)[0].type(y[0].dat.dtype) + M = torch.linalg.solve(mat_y, mat_mu).type(y[0].dat.dtype) grid = affine_grid(M, dim_mu)[None, ...] # Crop for c in range(len(y)): @@ -122,9 +122,13 @@ def _estimate_hyperpar(x, sett): mu_fg = torch.tensor(4096, device=dat.device, dtype=dat.dtype) else: # Get noise and foreground statistics - sd_bg, sd_fg, mu_bg, mu_fg = estimate_noise(dat, num_class=2, show_fit=sett.show_hyperpar, - fig_num=100 + cnt) - mu_bg = torch.tensor(0.0, device=dat.device, dtype=dat.dtype) + prm_noise, prm_not_noise = estimate_noise( + dat, num_class=2, show_fit=sett.show_hyperpar,fig_num=100 + cnt + ) + sd_bg = prm_noise['sd'] + sd_fg = prm_not_noise['sd'] + mu_bg = prm_noise['mean'] + mu_fg = prm_not_noise['mean'] # Set values x[c][n].sd = sd_bg.float() x[c][n].tau = 1 / sd_bg.float() ** 2 @@ -328,7 +332,7 @@ def _init_reg(x, sett): i = 0 for c in range(len(x)): for n in range(len(x[c])): - imgs[i][1] = imgs[i][1].solve(mat_a[i, ...])[0] + imgs[i][1] = torch.linalg.solve(mat_a[i, ...], imgs[i][1]) i += 1 _print_info('init-reg', sett, 'co', 'finished', N, t0) @@ -343,7 +347,7 @@ def _init_reg(x, sett): i = 0 for c in range(len(x)): for n in range(len(x[c])): - imgs[i][1] = imgs[i][1].solve(mat_a)[0] + imgs[i][1] = torch.linalg.solve(mat_a, imgs[i][1]) i += 1 # Modify image affine (label uses the same as the image, so no need to modify that one) @@ -376,7 +380,7 @@ def _init_y_dat(x, y, sett): # Get image data dat = x[c][n].dat[None, None, ...] # Make output grid - mat = mat_y.solve(x[c][n].mat)[0] # mat_x\mat_y + mat = torch.linalg.solve(x[c][n].mat, mat_y) # mat_x\mat_y grid = affine_grid(mat.type(dat.dtype), dim_y) # Do resampling mn = torch.min(dat) @@ -402,7 +406,7 @@ def _init_y_label(x, y, sett): n = 0 if x[c][n].label is not None: # Make output grid - mat = mat_y.solve(x[c][n].mat)[0] # mat_x\mat_y + mat = torch.linalg.solve(x[c][n].mat, mat_y) # mat_x\mat_y grid = affine_grid(mat.type(x[c][n].dat.dtype), dim_y) # Do resampling y[c].label = _warp_label(x[c][n].label[0], grid) diff --git a/unires/_project.py b/unires/_project.py index e77f958..f26945a 100644 --- a/unires/_project.py +++ b/unires/_project.py @@ -144,10 +144,10 @@ def _proj_apply(operator, dat, po, method='super-resolution', bound='zero', inte dim_thick = po.dim_thick if method == 'super-resolution': dim = dim_yx - mat = rigid.mm(mat_yx).solve(mat_y)[0] # mat_y\rigid*mat_yx + mat = torch.linalg.solve(mat_y, rigid.mm(mat_yx)) # mat_y\rigid*mat_yx elif method == 'denoising': dim = dim_x - mat = rigid.mm(mat_x).solve(mat_y)[0] # mat_y\rigid*mat_x + mat = torch.linalg.solve(mat_y, rigid.mm(mat_x)) # mat_y\rigid*mat_x # Smoothing operator if len(ratio) == 3: # 3D conv = lambda x: F.conv3d(x, smo_ker, stride=ratio) @@ -263,7 +263,7 @@ def _proj_info(dim_y, mat_y, dim_x, mat_x, rigid=None, po.dim_y = D_y.inverse()[:ndim, :ndim].mm(po.dim_y[..., None]).floor().squeeze() po.vx_x = voxel_size(po.mat_x) # Make intermediate - ratio = torch.solve(po.mat_x, po.mat_y)[0] # mat_y\mat_x + ratio = torch.linalg.solve(po.mat_y, po.mat_x) # mat_y\mat_x ratio = (ratio[:ndim, :ndim] ** 2).sum(0).sqrt() ratio = ratio.ceil().clamp(1) # ratio low/high >= 1 mat_yx = torch.cat((ratio, torch.ones(1, device=device, dtype=dtype))).diag() @@ -278,7 +278,7 @@ def _proj_info(dim_y, mat_y, dim_x, mat_x, rigid=None, po.smo_ker = smo_ker # Add offset to intermediate space off = torch.tensor(smo_ker.shape[-ndim:], dtype=dtype, device=device) - off = -(off - 1) // 2 # set offset + off = torch.div(-(off - 1), 2, rounding_mode='floor') # set offset mat_off = torch.eye(ndim + 1, dtype=torch.float64, device=device) mat_off[:ndim, -1] = off po.dim_yx = po.dim_yx + 2 * torch.abs(off) diff --git a/unires/_update.py b/unires/_update.py index 8474751..35a6fee 100644 --- a/unires/_update.py +++ b/unires/_update.py @@ -142,7 +142,7 @@ def _update_admm(x, y, z, w, rho, tmp, obj, n_iter, sett): cg(A=lhs, b=tmp, x=y[c].dat, verbose=sett.cgs_verbose, max_iter=sett.cgs_max_iter, - stop='residuals', + stop='max_gain', inplace=True, precond=precond, tolerance=sett.cgs_tol) # OBS: y[c].dat is here updated in-place @@ -297,7 +297,7 @@ def _update_scaling(x, y, sett, max_niter_gn=1, num_linesearch=4, verbose=0): mat_yx = x[c][n_x].po.mat_yx mat_y = x[c][n_x].po.mat_y rigid = _expm(x[c][n_x].rigid_q, sett.rigid_basis) - mat = rigid.mm(mat_yx).solve(mat_y)[0] # mat_y\rigid*mat_yx + mat = torch.linalg.solve(mat_y, rigid.mm(mat_yx)) # mat_y\rigid*mat_yx # Observed data dat_x = x[c][n_x].dat msk = dat_x != 0 @@ -494,7 +494,7 @@ def _rigid_match(dat_x, dat_y, po, tau, rigid, sett, CtC=None, diff=False, verbo mat = mat_x # Get grid - mat = rigid.mm(mat).solve(mat_y)[0] # mat_y\rigid*mat + mat = torch.linalg.solve(mat_y, rigid.mm(mat)) # mat_y\rigid*mat grid = affine_grid(mat.type(torch.float32), dim, jitter=False) # Warp y and compute spatial derivatives @@ -619,7 +619,7 @@ def _update_rigid_channel(xc, yc, sett, max_niter_gn=1, num_linesearch=4, d_rigid = d_rigid.permute((1, 2, 0)) # make compatible with old affine_basis d_rigid_q = torch.zeros(4, 4, num_q, device=device, dtype=torch.float64) for i in range(num_q): - d_rigid_q[:, :, i] = d_rigid[:, :, i].mm(mat).solve(po.mat_y)[0] # mat_y\d_rigid*mat + d_rigid_q[:, :, i] = torch.linalg.solve(po.mat_y, d_rigid[:, :, i].mm(mat)) # mat_y\d_rigid*mat # Compute gradient and Hessian gr = torch.zeros(num_q, 1, device=device, dtype=torch.float64) @@ -661,7 +661,7 @@ def _update_rigid_channel(xc, yc, sett, max_niter_gn=1, num_linesearch=4, # Hes += 1e-5*Hes.diag().max()*torch.eye(num_q, dtype=Hes.dtype, device=device) # Compute Gauss-Newton update step - Update = gr.solve(Hes)[0][:, 0] + Update = torch.linalg.solve(Hes, gr)[:, 0] # Do update.. old_ll = ll.clone() diff --git a/unires/_util.py b/unires/_util.py index a84a702..8cebe23 100644 --- a/unires/_util.py +++ b/unires/_util.py @@ -171,9 +171,6 @@ def _read_image(data, device='cpu', could_be_ct=False): dat = dat.float() dat = dat.to(device) dat[~torch.isfinite(dat)] = 0 - # Add some random noise - torch.manual_seed(0) - dat[dat > 0] += torch.rand_like(dat[dat > 0]) - 1 / 2 # Affine matrix mat = data[1] if not isinstance(mat, torch.Tensor): @@ -184,7 +181,11 @@ def _read_image(data, device='cpu', could_be_ct=False): direc = None nam = None # Get dimensions + dat = dat.squeeze() dim = tuple(dat.shape) + if len(dim) != 3: + raise ValueError("Input image dimension required to be 3D, recieved {:}D!". \ + format(len(dim))) # CT? if could_be_ct and _is_ct(dat): ct = True diff --git a/unires/run.py b/unires/run.py index 2d57664..93d73f8 100644 --- a/unires/run.py +++ b/unires/run.py @@ -162,7 +162,7 @@ def fit(x, y, sett): msk_fov = torch.ones(y[c].dim, dtype=torch.bool, device=sett.device) for n in range(len(x[c])): # Map to voxels in low-res image - M = x[c][n].po.rigid.mm(x[c][n].mat).solve(y[c].mat)[0].inverse() + M = torch.linalg.solve(y[c].mat, x[c][n].po.rigid.mm(x[c][n].mat)).inverse() grid = affine_grid(M.type(x[c][n].dat.dtype), y[c].dim)[None, ...] # Mask of low-res image FOV projected into high-res space msk_fov = msk_fov & \