From e25ceea3f35ff482e07da373b77d08b0e5a5248a Mon Sep 17 00:00:00 2001 From: brianreicher Date: Wed, 29 Nov 2023 11:45:38 -0500 Subject: [PATCH] Switch AC UNet channels --- log/log_aclsd/events.out.tfevents.1701276331.lee-htem-gpu0 | 0 log/log_mtlsd/events.out.tfevents.1701276331.lee-htem-gpu0 | 0 src/autoseg/train/ACLSDTrain.py | 2 +- 3 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 log/log_aclsd/events.out.tfevents.1701276331.lee-htem-gpu0 create mode 100644 log/log_mtlsd/events.out.tfevents.1701276331.lee-htem-gpu0 diff --git a/log/log_aclsd/events.out.tfevents.1701276331.lee-htem-gpu0 b/log/log_aclsd/events.out.tfevents.1701276331.lee-htem-gpu0 new file mode 100644 index 0000000..e69de29 diff --git a/log/log_mtlsd/events.out.tfevents.1701276331.lee-htem-gpu0 b/log/log_mtlsd/events.out.tfevents.1701276331.lee-htem-gpu0 new file mode 100644 index 0000000..e69de29 diff --git a/src/autoseg/train/ACLSDTrain.py b/src/autoseg/train/ACLSDTrain.py index ef16f8f..6f9784b 100644 --- a/src/autoseg/train/ACLSDTrain.py +++ b/src/autoseg/train/ACLSDTrain.py @@ -44,7 +44,7 @@ def aclsd_train( # initial MTLSD UNet unet = setup_unet(downsample_factors=[(2, 2, 2), (2, 2, 2)], padding="same") - unet_ac = setup_unet(downsample_factors=[(2, 2, 2), (2, 2, 2)], padding="same", num_heads=1) + unet_ac = setup_unet(in_channels=10, downsample_factors=[(2, 2, 2), (2, 2, 2)], padding="same", num_heads=1) mtlsd_model = MTLSDModel(unet=unet, num_fmaps=unet.out_channels) mtlsd_loss = Weighted_MSELoss() # aff_lambda=0)