diff --git a/test_postprocessing.py b/test_postprocessing.py new file mode 100644 index 0000000..2c70cc5 --- /dev/null +++ b/test_postprocessing.py @@ -0,0 +1,121 @@ + +import numpy as np +import torch +from torch.utils.data import ConcatDataset, DataLoader +from tqdm import tqdm +import datetime +import os +from model.unet import get_unet_model +from model.normalisation import Normalisation + +from pathlib import Path + +import sirf.STIR as STIR +from skimage.metrics import mean_squared_error as mse + + +class OSEMDataset(torch.utils.data.Dataset): + def __init__(self, osem_file, gt_file, im_size=256): + + self.osem = np.load(osem_file) + self.gt = np.load(gt_file) + + self.im_size = im_size + + def __len__(self): + return self.osem.shape[0] + + def __getitem__(self, idx): + + gt = torch.from_numpy(self.gt[idx]).float().unsqueeze(0) + osem = torch.from_numpy(self.osem[idx]).float().unsqueeze(0) + gt = torch.nn.functional.interpolate(gt.unsqueeze(0), size=[self.im_size, self.im_size], mode='bilinear') + osem = torch.nn.functional.interpolate(osem.unsqueeze(0), size=[self.im_size, self.im_size], mode='bilinear') + + return gt.squeeze(0), osem.squeeze(0) + +def evaluate_quality_metrics(reference, prediction, whole_object_mask, background_mask, voi_masks): + whole_object_indices = np.where(whole_object_mask.as_array()) + background_indices = np.where(background_mask.as_array()) + + norm = reference[background_indices].mean() + + voi_indices = {} + for key, value in voi_masks.items(): + voi_indices[key] = np.where(value.as_array()) + + whole = { + "RMSE_whole_object": np.sqrt( + mse(reference[whole_object_indices], prediction[whole_object_indices])) / norm, + "RMSE_background": np.sqrt( + mse(reference[background_indices], prediction[background_indices])) / norm} + local = { + f"AEM_VOI_{voi_name}": np.abs(prediction[voi_indices].mean() - reference[voi_indices].mean()) / + norm for voi_name, voi_indices in sorted(voi_indices.items())} + return {**whole, **local} + + + +def testing() -> None: + device = "cuda" + test_on = "Siemens_Vision600_thorax" + + #model = get_unet_model(in_ch=1, + # out_ch=1, + # scales=5, + # skip=16, + # im_size=256, + # channels=[16, 32, 64, 128, 256], + # use_sigmoid=False, + # use_norm=True) + # + model = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 15, bias=False,padding=7)) + model.to(device) + model.load_state_dict(torch.load(os.path.join(f"postprocessing_unet/{test_on}/2024-09-06_12-51-16", "model.pt"), weights_only=False)) + model.eval() + print("Number of Parameters: ", sum([p.numel() for p in model.parameters()])) + + test_on = "Siemens_Vision600_thorax" + + + if not (srcdir := Path("/mnt/share/petric")).is_dir(): + srcdir = Path("./data") + def get_image(fname): + if (source := srcdir / test_on / 'PETRIC' / fname).is_file(): + return STIR.ImageData(str(source)) + return None # explicit to suppress linter warnings + + OSEM_image = STIR.ImageData(str(srcdir / test_on / 'OSEM_image.hv')) + reference_image = get_image('reference_image.hv') + whole_object_mask = get_image('VOI_whole_object.hv') + background_mask = get_image('VOI_background.hv') + voi_masks = { + voi.stem[4:]: STIR.ImageData(str(voi)) + for voi in (srcdir / test_on / 'PETRIC').glob("VOI_*.hv") if voi.stem[4:] not in ('background', 'whole_object')} + + # reference, osem, measurements, contamination_factor, attn_factors + get_norm = Normalisation("osem_mean") + + osem = torch.from_numpy(OSEM_image.as_array()).float().to(device).unsqueeze(1) + norm = get_norm(osem, measurements=None, contamination_factor=None) + + with torch.no_grad(): + x_pred = model(osem )# , norm) + pred = x_pred.cpu().squeeze().numpy() + pred[pred < 0] = 0 + print + print("OSEM: ") + print(evaluate_quality_metrics(reference_image.as_array(), + OSEM_image.as_array(), + whole_object_mask, + background_mask, + voi_masks)) + print("Prediction: ") + print(evaluate_quality_metrics(reference_image.as_array(), + pred, + whole_object_mask, + background_mask, + voi_masks)) + +if __name__ == '__main__': + testing() \ No newline at end of file diff --git a/train_postprocessing.py b/train_postprocessing.py index 50b89f7..a8e5c59 100644 --- a/train_postprocessing.py +++ b/train_postprocessing.py @@ -30,45 +30,54 @@ def __getitem__(self, idx): def training() -> None: device = "cuda" - epochs = 100 - - model = get_unet_model(in_ch=1, - out_ch=1, - scales=5, - skip=16, - im_size=256, - channels=[16, 32, 64, 128, 256], - use_sigmoid=False, - use_norm=True) - + epochs = 300 + test_on = "Siemens_Vision600_thorax" + + #model = get_unet_model(in_ch=1, + # out_ch=1, + # scales=5, + # skip=16, + # im_size=256, + # channels=[16, 32, 64, 128, 256], + # use_sigmoid=False, + # use_norm=True) + + model = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 15, bias=False,padding=7)) + print(model[0]) + model.to(device) print("Number of Parameters: ", sum([p.numel() for p in model.parameters()])) ###### SET LOGGING ###### - current_time = datetime.datetime.now().strftime('%b%d_%H-%M-%S') - log_dir = os.path.join('postprocessing_unet', current_time) + current_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + log_dir = os.path.join('postprocessing_unet', test_on, current_time) if not os.path.exists(log_dir): os.makedirs(log_dir) - optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) + optimizer = torch.optim.Adam(model.parameters(), lr=2e-4) - dataset_neuro = OSEMDataset( - osem_file="data/NeuroLF_Hoffman_Dataset_osem.npy", - gt_file="data/NeuroLF_Hoffman_Dataset_gt.npy" - ) - dataset_nema = OSEMDataset( - osem_file="data/Siemens_mMR_NEMA_IQ_osem.npy", - gt_file="data/Siemens_mMR_NEMA_IQ_gt.npy" - ) - dataset = ConcatDataset([dataset_neuro, dataset_nema]) - print("LENGTH OF FULL DATASET: ", len(dataset)) + ### cross validation: train on A,B,C - test on D + datasets = ["Siemens_mMR_ACR", "NeuroLF_Hoffman_Dataset", "Siemens_mMR_NEMA_IQ", "Siemens_Vision600_thorax"] - train_size = int(0.9 * len(dataset)) - val_size = len(dataset) - train_size + #datasets.remove(test_on) + train_dataset = [] + for data in datasets: + train_dataset.append(OSEMDataset( + osem_file=f"data/{data}_osem.npy", + gt_file=f"data/{data}_gt.npy" + )) + - train_dl = DataLoader(dataset, batch_size=8, shuffle=True) + train_dataset = ConcatDataset(train_dataset) + print("LENGTH OF FULL DATASET: ", len(train_dataset)) + train_dl = DataLoader(train_dataset, batch_size=8, shuffle=True) + val_dataset = OSEMDataset( + osem_file=f"data/{test_on}_osem.npy", + gt_file=f"data/{test_on}_gt.npy" + ) + val_dl = DataLoader(val_dataset, batch_size=8, shuffle=False) # reference, osem, measurements, contamination_factor, attn_factors get_norm = Normalisation("osem_mean") @@ -77,7 +86,7 @@ def training() -> None: model.train() print(f"Epoch: {epoch}") mean_loss = [] - for idx, batch in tqdm(enumerate(train_dl), total=len(train_dl)): + for _, batch in tqdm(enumerate(train_dl), total=len(train_dl)): # reference, scale_factor, osem, norm, measurements, contamination_factor, attn_factors optimizer.zero_grad() @@ -87,20 +96,50 @@ def training() -> None: osem = batch[1] osem = osem.to(device) - norm = get_norm(osem, measurements=None, contamination_factor=None) + #norm = get_norm(osem, measurements=None, contamination_factor=None) - x_pred = model(osem, norm) - - loss = torch.mean((x_pred - reference)**2) + x_pred = model(osem) #model(osem, norm) + loss = torch.sum((x_pred - reference)**2) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 10) optimizer.step() mean_loss.append(loss.item()) - print("Mean loss: ", np.mean(mean_loss)) + print("Train loss: ", np.mean(mean_loss)) + + model.eval() + with torch.no_grad(): + mean_loss = [] + for _, batch in tqdm(enumerate(val_dl), total=len(val_dl)): + # reference, scale_factor, osem, norm, measurements, contamination_factor, attn_factors + + reference = batch[0] + reference = reference.to(device) + + osem = batch[1] + osem = osem.to(device) + + #norm = get_norm(osem, measurements=None, contamination_factor=None) + + x_pred = model(osem)#, norm) + + loss = torch.sum((x_pred - reference)**2) + + mean_loss.append(loss.item()) + + print("Val loss: ", np.mean(mean_loss)) + + #import matplotlib.pyplot as plt - torch.save(model.state_dict(), os.path.join(log_dir, "model.pt")) + #fig, (ax1, ax2, ax3, ax4) = plt.subplots(1,4) + #ax1.imshow(reference[4][0].cpu().numpy(), cmap="gray") + #ax2.imshow(x_pred[4][0].cpu().numpy(), cmap="gray") + #ax3.imshow(osem[4][0].cpu().numpy(), cmap="gray") + #im = ax4.imshow(model[0].weight[0][0].cpu().detach().numpy()) + #fig.colorbar(im, ax=ax4) + #plt.show() + torch.save(model.state_dict(), os.path.join(log_dir, "model.pt"))