From bc5e29bd390e5518bace3cf1fa49cae8f9bbdc6d Mon Sep 17 00:00:00 2001 From: brianreicher Date: Wed, 29 Nov 2023 11:40:06 -0500 Subject: [PATCH] Update toml and use setup unet --- pyproject.toml | 2 +- src/autoseg/train/ACLSDTrain.py | 4 ++-- src/autoseg/train/MTLSDTrain.py | 3 +-- src/autoseg/train/STELARRTrain.py | 12 ++---------- 4 files changed, 6 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b4d35b4..efa310c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,3 @@ [build-system] -requires = ["setuptools", "wheel", "Cython", "numpy"] +requires = ["setuptools", "wheel", "cython", "numpy"] build-backend = "setuptools.build_meta" diff --git a/src/autoseg/train/ACLSDTrain.py b/src/autoseg/train/ACLSDTrain.py index d7d698d..ef16f8f 100644 --- a/src/autoseg/train/ACLSDTrain.py +++ b/src/autoseg/train/ACLSDTrain.py @@ -46,14 +46,14 @@ def aclsd_train( unet = setup_unet(downsample_factors=[(2, 2, 2), (2, 2, 2)], padding="same") unet_ac = setup_unet(downsample_factors=[(2, 2, 2), (2, 2, 2)], padding="same", num_heads=1) - mtlsd_model = MTLSDModel(unet=unet, num_fmaps=unet.output_nc) + mtlsd_model = MTLSDModel(unet=unet, num_fmaps=unet.out_channels) mtlsd_loss = Weighted_MSELoss() # aff_lambda=0) mtlsd_optimizer = torch.optim.Adam( params=mtlsd_model.parameters(), lr=0.5e-4, betas=(0.95, 0.999) ) # second ACLSD UNet - aclsd_model = ACLSDModel(unet=unet_ac, num_fmaps=unet_ac.output_nc) + aclsd_model = ACLSDModel(unet=unet_ac, num_fmaps=unet_ac.out_channels) aclsd_loss = WeightedACLSD_MSELoss() # aff_lambda=0) aclsd_optimizer = torch.optim.Adam( aclsd_model.parameters(), lr=0.5e-4, betas=(0.95, 0.999) diff --git a/src/autoseg/train/MTLSDTrain.py b/src/autoseg/train/MTLSDTrain.py index 93905f3..ca6a2d1 100644 --- a/src/autoseg/train/MTLSDTrain.py +++ b/src/autoseg/train/MTLSDTrain.py @@ -34,8 +34,7 @@ def mtlsd_train(iterations: int, raw_file: str, voxel_size: int = 33): # initial MTLSD UNet unet = setup_unet() - - mtlsd_model = MTLSDModel(unet=unet, num_fmaps=unet.output_nc) + mtlsd_model = MTLSDModel(unet=unet, num_fmaps=unet.out_channels) mtlsd_loss = Weighted_MSELoss() # aff_lambda=0) mtlsd_optimizer = torch.optim.Adam( params=mtlsd_model.parameters(), lr=0.5e-4, betas=(0.95, 0.999) diff --git a/src/autoseg/train/STELARRTrain.py b/src/autoseg/train/STELARRTrain.py index 6623883..b6edbac 100644 --- a/src/autoseg/train/STELARRTrain.py +++ b/src/autoseg/train/STELARRTrain.py @@ -13,7 +13,7 @@ from ..models.STELARRModel import STELARRModel from ..postprocess.segment_skel_correct import get_skel_correct_segmentation from ..networks.NLayerDiscriminator import NLayerDiscriminator, NLayerDiscriminator3D -from ..networks.UNet import UNet +from ..networks.FLibUNet import setup_unet from ..losses.GANLoss import GANLoss from ..losses.MSELoss import Weighted_MSELoss from ..gp_filters.random_noise import RandomNoiseAugment @@ -44,15 +44,7 @@ def stelarr_train( fake_pred = gp.ArrayKey("FAKE_PRED") real_pred = gp.ArrayKey("REAL_PRED") - unet: UNet = UNet( - input_nc=1, - ngf=12, - fmap_inc_factor=3, - downsample_factors=[(2, 2, 2), (2, 2, 2)], - constant_upsample=True, - num_heads=3, - padding_type="valid", - ) + unet = se model: STELARRModel = STELARRModel(unet=unet, num_fmaps=unet.output_nc) discriminator: NLayerDiscriminator3D = NLayerDiscriminator( ndims=3,