Commit
try to learn a single filter
alexdenker committed Sep 10, 2024
1 parent 7bb2ec9 commit 082f4e5
Showing 2 changed files with 194 additions and 34 deletions.
121 changes: 121 additions & 0 deletions test_postprocessing.py
@@ -0,0 +1,121 @@

import numpy as np
import torch
from torch.utils.data import ConcatDataset, DataLoader
from tqdm import tqdm
import datetime
import os
from model.unet import get_unet_model
from model.normalisation import Normalisation

from pathlib import Path

import sirf.STIR as STIR
from skimage.metrics import mean_squared_error as mse


class OSEMDataset(torch.utils.data.Dataset):
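    # Pairs of OSEM reconstructions and ground-truth images stored as .npy
    # stacks; __getitem__ resizes both to im_size x im_size.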
def __init__(self, osem_file, gt_file, im_size=256):

self.osem = np.load(osem_file)
self.gt = np.load(gt_file)

self.im_size = im_size

def __len__(self):
return self.osem.shape[0]

def __getitem__(self, idx):

gt = torch.from_numpy(self.gt[idx]).float().unsqueeze(0)
osem = torch.from_numpy(self.osem[idx]).float().unsqueeze(0)
gt = torch.nn.functional.interpolate(gt.unsqueeze(0), size=[self.im_size, self.im_size], mode='bilinear')
osem = torch.nn.functional.interpolate(osem.unsqueeze(0), size=[self.im_size, self.im_size], mode='bilinear')

return gt.squeeze(0), osem.squeeze(0)

def evaluate_quality_metrics(reference, prediction, whole_object_mask, background_mask, voi_masks):
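    # Quality metrics: whole-object and background RMSE plus the absolute error
    # of the mean (AEM) per VOI, all normalised by the mean reference value
    # inside the background mask.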
whole_object_indices = np.where(whole_object_mask.as_array())
background_indices = np.where(background_mask.as_array())

norm = reference[background_indices].mean()

voi_indices = {}
for key, value in voi_masks.items():
voi_indices[key] = np.where(value.as_array())

whole = {
"RMSE_whole_object": np.sqrt(
mse(reference[whole_object_indices], prediction[whole_object_indices])) / norm,
"RMSE_background": np.sqrt(
mse(reference[background_indices], prediction[background_indices])) / norm}
local = {
f"AEM_VOI_{voi_name}": np.abs(prediction[voi_indices].mean() - reference[voi_indices].mean()) /
norm for voi_name, voi_indices in sorted(voi_indices.items())}
return {**whole, **local}



def testing() -> None:
device = "cuda"
test_on = "Siemens_Vision600_thorax"

#model = get_unet_model(in_ch=1,
# out_ch=1,
# scales=5,
# skip=16,
# im_size=256,
# channels=[16, 32, 64, 128, 256],
# use_sigmoid=False,
# use_norm=True)
#
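    # Instead of the U-Net above, use a single bias-free 15x15 convolution
    # filter; padding=7 keeps the spatial size unchanged.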
    model = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 15, bias=False, padding=7))
model.to(device)
model.load_state_dict(torch.load(os.path.join(f"postprocessing_unet/{test_on}/2024-09-06_12-51-16", "model.pt"), weights_only=False))
model.eval()
print("Number of Parameters: ", sum([p.numel() for p in model.parameters()]))

test_on = "Siemens_Vision600_thorax"


if not (srcdir := Path("/mnt/share/petric")).is_dir():
srcdir = Path("./data")
def get_image(fname):
if (source := srcdir / test_on / 'PETRIC' / fname).is_file():
return STIR.ImageData(str(source))
return None # explicit to suppress linter warnings

OSEM_image = STIR.ImageData(str(srcdir / test_on / 'OSEM_image.hv'))
reference_image = get_image('reference_image.hv')
whole_object_mask = get_image('VOI_whole_object.hv')
background_mask = get_image('VOI_background.hv')
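    # Every other VOI_*.hv mask in the PETRIC folder, keyed by the name after
    # the "VOI_" prefix.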
voi_masks = {
voi.stem[4:]: STIR.ImageData(str(voi))
for voi in (srcdir / test_on / 'PETRIC').glob("VOI_*.hv") if voi.stem[4:] not in ('background', 'whole_object')}

# reference, osem, measurements, contamination_factor, attn_factors
get_norm = Normalisation("osem_mean")

osem = torch.from_numpy(OSEM_image.as_array()).float().to(device).unsqueeze(1)
norm = get_norm(osem, measurements=None, contamination_factor=None)
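    # NOTE: norm is computed but not passed to the single-filter model below
    # (the ", norm" argument is commented out).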

with torch.no_grad():
        x_pred = model(osem)  # , norm)
pred = x_pred.cpu().squeeze().numpy()
pred[pred < 0] = 0
    print()
print("OSEM: ")
print(evaluate_quality_metrics(reference_image.as_array(),
OSEM_image.as_array(),
whole_object_mask,
background_mask,
voi_masks))
print("Prediction: ")
print(evaluate_quality_metrics(reference_image.as_array(),
pred,
whole_object_mask,
background_mask,
voi_masks))

if __name__ == '__main__':
testing()
107 changes: 73 additions & 34 deletions train_postprocessing.py
@@ -30,45 +30,54 @@ def __getitem__(self, idx):

def training() -> None:
device = "cuda"
epochs = 100

model = get_unet_model(in_ch=1,
out_ch=1,
scales=5,
skip=16,
im_size=256,
channels=[16, 32, 64, 128, 256],
use_sigmoid=False,
use_norm=True)

epochs = 300
test_on = "Siemens_Vision600_thorax"

#model = get_unet_model(in_ch=1,
# out_ch=1,
# scales=5,
# skip=16,
# im_size=256,
# channels=[16, 32, 64, 128, 256],
# use_sigmoid=False,
# use_norm=True)
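    # The "single filter" of this commit: one learnable bias-free 15x15
    # convolution kernel applied to the OSEM image.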

    model = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 15, bias=False, padding=7))
print(model[0])

model.to(device)
print("Number of Parameters: ", sum([p.numel() for p in model.parameters()]))

###### SET LOGGING ######
current_time = datetime.datetime.now().strftime('%b%d_%H-%M-%S')
log_dir = os.path.join('postprocessing_unet', current_time)
current_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
log_dir = os.path.join('postprocessing_unet', test_on, current_time)
if not os.path.exists(log_dir):
os.makedirs(log_dir)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

dataset_neuro = OSEMDataset(
osem_file="data/NeuroLF_Hoffman_Dataset_osem.npy",
gt_file="data/NeuroLF_Hoffman_Dataset_gt.npy"
)
dataset_nema = OSEMDataset(
osem_file="data/Siemens_mMR_NEMA_IQ_osem.npy",
gt_file="data/Siemens_mMR_NEMA_IQ_gt.npy"
)
dataset = ConcatDataset([dataset_neuro, dataset_nema])
print("LENGTH OF FULL DATASET: ", len(dataset))
### cross validation: train on A,B,C - test on D
datasets = ["Siemens_mMR_ACR", "NeuroLF_Hoffman_Dataset", "Siemens_mMR_NEMA_IQ", "Siemens_Vision600_thorax"]

train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
#datasets.remove(test_on)
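    # NOTE: with the remove() call commented out, the held-out scanner
    # (test_on) is still included in the training data.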

train_dataset = []
for data in datasets:
train_dataset.append(OSEMDataset(
osem_file=f"data/{data}_osem.npy",
gt_file=f"data/{data}_gt.npy"
))


train_dl = DataLoader(dataset, batch_size=8, shuffle=True)
train_dataset = ConcatDataset(train_dataset)
print("LENGTH OF FULL DATASET: ", len(train_dataset))
train_dl = DataLoader(train_dataset, batch_size=8, shuffle=True)

val_dataset = OSEMDataset(
osem_file=f"data/{test_on}_osem.npy",
gt_file=f"data/{test_on}_gt.npy"
)
val_dl = DataLoader(val_dataset, batch_size=8, shuffle=False)
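    # Validation runs on the held-out test_on scanner (Siemens_Vision600_thorax).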

# reference, osem, measurements, contamination_factor, attn_factors
get_norm = Normalisation("osem_mean")
@@ -77,7 +86,7 @@ def training() -> None:
model.train()
print(f"Epoch: {epoch}")
mean_loss = []
for idx, batch in tqdm(enumerate(train_dl), total=len(train_dl)):
for _, batch in tqdm(enumerate(train_dl), total=len(train_dl)):
# reference, scale_factor, osem, norm, measurements, contamination_factor, attn_factors
optimizer.zero_grad()

@@ -87,20 +96,50 @@
osem = batch[1]
osem = osem.to(device)

norm = get_norm(osem, measurements=None, contamination_factor=None)
#norm = get_norm(osem, measurements=None, contamination_factor=None)

x_pred = model(osem, norm)

loss = torch.mean((x_pred - reference)**2)
x_pred = model(osem) #model(osem, norm)
loss = torch.sum((x_pred - reference)**2)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
optimizer.step()

mean_loss.append(loss.item())

print("Mean loss: ", np.mean(mean_loss))
print("Train loss: ", np.mean(mean_loss))

model.eval()
with torch.no_grad():
mean_loss = []
for _, batch in tqdm(enumerate(val_dl), total=len(val_dl)):
# reference, scale_factor, osem, norm, measurements, contamination_factor, attn_factors

reference = batch[0]
reference = reference.to(device)

osem = batch[1]
osem = osem.to(device)

#norm = get_norm(osem, measurements=None, contamination_factor=None)

x_pred = model(osem)#, norm)

loss = torch.sum((x_pred - reference)**2)

mean_loss.append(loss.item())

print("Val loss: ", np.mean(mean_loss))

#import matplotlib.pyplot as plt

torch.save(model.state_dict(), os.path.join(log_dir, "model.pt"))
#fig, (ax1, ax2, ax3, ax4) = plt.subplots(1,4)
#ax1.imshow(reference[4][0].cpu().numpy(), cmap="gray")
#ax2.imshow(x_pred[4][0].cpu().numpy(), cmap="gray")
#ax3.imshow(osem[4][0].cpu().numpy(), cmap="gray")
#im = ax4.imshow(model[0].weight[0][0].cpu().detach().numpy())
#fig.colorbar(im, ax=ax4)
#plt.show()
torch.save(model.state_dict(), os.path.join(log_dir, "model.pt"))



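For reference, the commented-out plotting lines at the end of train_postprocessing.py suggest inspecting the learned kernel via model[0].weight; a minimal sketch along those lines (the checkpoint path is just the example run directory already used in test_postprocessing.py) could look like:

import matplotlib.pyplot as plt
import torch

# Rebuild the single-filter model and load a trained checkpoint
# (example path from test_postprocessing.py; adjust to your own run directory).
model = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 15, bias=False, padding=7))
state = torch.load("postprocessing_unet/Siemens_Vision600_thorax/2024-09-06_12-51-16/model.pt",
                   map_location="cpu")
model.load_state_dict(state)

# Visualise the learned 15x15 kernel.
kernel = model[0].weight[0, 0].detach().numpy()
plt.imshow(kernel)
plt.colorbar()
plt.show()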
