Skip to content

Commit

Permalink
Update toml and use setup unet
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Nov 29, 2023
1 parent 33f2641 commit bc5e29b
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 15 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[build-system]
requires = ["setuptools", "wheel", "Cython", "numpy"]
requires = ["setuptools", "wheel", "cython", "numpy"]
build-backend = "setuptools.build_meta"
4 changes: 2 additions & 2 deletions src/autoseg/train/ACLSDTrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ def aclsd_train(
unet = setup_unet(downsample_factors=[(2, 2, 2), (2, 2, 2)], padding="same")
unet_ac = setup_unet(downsample_factors=[(2, 2, 2), (2, 2, 2)], padding="same", num_heads=1)

mtlsd_model = MTLSDModel(unet=unet, num_fmaps=unet.output_nc)
mtlsd_model = MTLSDModel(unet=unet, num_fmaps=unet.out_channels)
mtlsd_loss = Weighted_MSELoss() # aff_lambda=0)
mtlsd_optimizer = torch.optim.Adam(
params=mtlsd_model.parameters(), lr=0.5e-4, betas=(0.95, 0.999)
)

# second ACLSD UNet
aclsd_model = ACLSDModel(unet=unet_ac, num_fmaps=unet_ac.output_nc)
aclsd_model = ACLSDModel(unet=unet_ac, num_fmaps=unet_ac.out_channels)
aclsd_loss = WeightedACLSD_MSELoss() # aff_lambda=0)
aclsd_optimizer = torch.optim.Adam(
aclsd_model.parameters(), lr=0.5e-4, betas=(0.95, 0.999)
Expand Down
3 changes: 1 addition & 2 deletions src/autoseg/train/MTLSDTrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ def mtlsd_train(iterations: int, raw_file: str, voxel_size: int = 33):
# initial MTLSD UNet
unet = setup_unet()


mtlsd_model = MTLSDModel(unet=unet, num_fmaps=unet.output_nc)
mtlsd_model = MTLSDModel(unet=unet, num_fmaps=unet.out_channels)
mtlsd_loss = Weighted_MSELoss() # aff_lambda=0)
mtlsd_optimizer = torch.optim.Adam(
params=mtlsd_model.parameters(), lr=0.5e-4, betas=(0.95, 0.999)
Expand Down
12 changes: 2 additions & 10 deletions src/autoseg/train/STELARRTrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ..models.STELARRModel import STELARRModel
from ..postprocess.segment_skel_correct import get_skel_correct_segmentation
from ..networks.NLayerDiscriminator import NLayerDiscriminator, NLayerDiscriminator3D
from ..networks.UNet import UNet
from ..networks.FLibUNet import setup_unet
from ..losses.GANLoss import GANLoss
from ..losses.MSELoss import Weighted_MSELoss
from ..gp_filters.random_noise import RandomNoiseAugment
Expand Down Expand Up @@ -44,15 +44,7 @@ def stelarr_train(
fake_pred = gp.ArrayKey("FAKE_PRED")
real_pred = gp.ArrayKey("REAL_PRED")

unet: UNet = UNet(
input_nc=1,
ngf=12,
fmap_inc_factor=3,
downsample_factors=[(2, 2, 2), (2, 2, 2)],
constant_upsample=True,
num_heads=3,
padding_type="valid",
)
unet = se
model: STELARRModel = STELARRModel(unet=unet, num_fmaps=unet.output_nc)
discriminator: NLayerDiscriminator3D = NLayerDiscriminator(
ndims=3,
Expand Down

0 comments on commit bc5e29b

Please sign in to comment.