Skip to content

Commit

Permalink
Update swav
Browse files Browse the repository at this point in the history
  • Loading branch information
surajpaib committed Sep 8, 2023
1 parent 475b221 commit 9c852b5
Showing 1 changed file with 106 additions and 108 deletions.
214 changes: 106 additions & 108 deletions experiments/pretraining/swav_pretrain.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,117 +29,115 @@ trainer:
max_samples: 10

system:
_target_: torch.compile
model:
_target_: lighter.LighterSystem
batch_size: 64 # Change to lower batch size if GPU memory is smaller.
pin_memory: True
drop_last_batch: True # Used in SSL cases because of negatives
num_workers: 12
_target_: lighter.LighterSystem
batch_size: 64 # Change to lower batch size if GPU memory is smaller.
pin_memory: True
drop_last_batch: True # Used in SSL cases because of negatives
num_workers: 12

model:
_target_: project.ssl.models.swav.SwaV
num_ftrs: 4096
out_dim: 128
n_prototypes: 100
n_queues: 0
n_steps_frozen_prototypes: 50
queue_length: "$(@system#model#batch_size * @trainer#devices) * 2" # 15 * number of effective batches
start_queue_at_epoch: 15 # SwaV starts at the 15th *epoch* in the paper
backbone:
_target_: monai.networks.nets.resnet.resnet50
pretrained: False
n_input_channels: 1
widen_factor: 2
conv1_t_stride: 2
feed_forward: False
model:
_target_: project.ssl.models.swav.SwaV
num_ftrs: 4096
out_dim: 128
n_prototypes: 100
n_queues: 0
n_steps_frozen_prototypes: 50
queue_length: "$(@system#batch_size * @trainer#devices) * 2" # 15 * number of effective batches
start_queue_at_epoch: 15 # SwaV starts at the 15th *epoch* in the paper
backbone:
_target_: monai.networks.nets.resnet.resnet50
pretrained: False
n_input_channels: 1
widen_factor: 2
conv1_t_stride: 2
feed_forward: False

criterion:
_target_: project.ssl.losses.SwaVLoss
temperature: 0.1
sinkhorn_gather_distributed: True
sinkhorn_epsilon: 0.03
criterion:
_target_: project.ssl.losses.SwaVLoss
temperature: 0.1
sinkhorn_gather_distributed: True
sinkhorn_epsilon: 0.03

optimizer:
_target_: torch.optim.SGD
params: "$@system#model#model.parameters()"
lr: "$((@system#model#batch_size * @trainer#devices)/256) * 0.6" # Compute LR dynamically for different batch sizes
weight_decay: 1.0e-6
momentum: 0.9

optimizer:
_target_: torch.optim.SGD
params: "$@system#model.parameters()"
lr: "$((@system#batch_size * @trainer#devices)/256) * 0.6" # Compute LR dynamically for different batch sizes
weight_decay: 1.0e-6
momentum: 0.9

scheduler:
scheduler:
scheduler:
_target_: torch.optim.lr_scheduler.CosineAnnealingLR
optimizer: "@system#model#optimizer"
eta_min: "$((@system#model#batch_size * @trainer#devices)/256) * 0.0006"
T_max: "$(@trainer#max_epochs) * len(@system#model#datasets#train)//(@system#model#batch_size * @trainer#devices)" # Compute total steps
interval: "step"
_target_: torch.optim.lr_scheduler.CosineAnnealingLR
optimizer: "@system#optimizer"
eta_min: "$((@system#batch_size * @trainer#devices)/256) * 0.0006"
T_max: "$(@trainer#max_epochs) * len(@system#datasets#train)//(@system#batch_size * @trainer#devices)" # Compute total steps
interval: "step"

metrics:
train: null
val: "%#train"
test: "%#train"
datasets:
train:
_target_: project.datasets.SSLRadiomicsDataset
path: "data/preprocessing/deeplesion/annotations/deeplesion_annotations_training.csv"
orient: True
enable_negatives: False
resample_spacing: [1, 1, 1]
transform:
_target_: torchvision.transforms.Compose
transforms:
- _target_: monai.transforms.ToTensord
keys: ["image"]
- _target_: monai.transforms.AddChanneld
keys: ["image"]
- _target_: monai.transforms.ScaleIntensityRanged
keys: ["image"]
a_min: -1024
a_max: 3072
b_min: 0
b_max: 1
clip: True
- _target_: project.ssl.transforms.Replicated
keys: ["image"]
transforms:
- _target_: monai.transforms.Compose
transforms:
# Random Transforms begin
- _target_: project.ssl.transforms.RandomResizedCrop3Dd
keys: ["image"]
size: 50
- _target_: monai.transforms.RandFlipd
keys: ["image"]
prob: 0.5
spatial_axis: [1, 2]
prob: 0.5
- _target_: monai.transforms.RandAffined
keys: ["image"]
prob: 0.5
rotate_range: $((22/7)/180)*10
shear_range: 0.1
padding_mode: zeros
- _target_: monai.transforms.RandHistogramShiftd
keys: ["image"]
prob: 0.5
- _target_: monai.transforms.RandGaussianSmoothd
keys: ["image"]
prob: 0.5
- _target_: monai.transforms.SpatialPadd
keys: ["image"]
spatial_size: [50, 50, 50]
- "%#0"
- "%#0"
- "%#0"
- _target_: monai.transforms.SelectItemsd
keys: ["image"]
- _target_: torchvision.transforms.Lambda
lambd: "$lambda x: x['image']"
val: null
test: null
metrics:
train: null
val: "%#train"
test: "%#train"

datasets:
train:
_target_: project.datasets.SSLRadiomicsDataset
path: "data/preprocessing/deeplesion/annotations/deeplesion_annotations_training.csv"
orient: True
enable_negatives: False
resample_spacing: [1, 1, 1]
transform:
_target_: torchvision.transforms.Compose
transforms:
- _target_: monai.transforms.ToTensord
keys: ["image"]
- _target_: monai.transforms.AddChanneld
keys: ["image"]
- _target_: monai.transforms.ScaleIntensityRanged
keys: ["image"]
a_min: -1024
a_max: 3072
b_min: 0
b_max: 1
clip: True
- _target_: project.ssl.transforms.Replicated
keys: ["image"]
transforms:
- _target_: monai.transforms.Compose
transforms:
# Random Transforms begin
- _target_: project.ssl.transforms.RandomResizedCrop3Dd
keys: ["image"]
size: 50
- _target_: monai.transforms.RandFlipd
keys: ["image"]
prob: 0.5
spatial_axis: [1, 2]
prob: 0.5
- _target_: monai.transforms.RandAffined
keys: ["image"]
prob: 0.5
rotate_range: $((22/7)/180)*10
shear_range: 0.1
padding_mode: zeros
- _target_: monai.transforms.RandHistogramShiftd
keys: ["image"]
prob: 0.5
- _target_: monai.transforms.RandGaussianSmoothd
keys: ["image"]
prob: 0.5
- _target_: monai.transforms.SpatialPadd
keys: ["image"]
spatial_size: [50, 50, 50]
- "%#0"
- "%#0"
- "%#0"
- _target_: monai.transforms.SelectItemsd
keys: ["image"]
- _target_: torchvision.transforms.Lambda
lambd: "$lambda x: x['image']"
val: null
test: null

postprocessing:
logging:
pred: "$lambda x: torch.argmax(x[0][0], dim=-1)"
postprocessing:
logging:
pred: "$lambda x: torch.argmax(x[0][0], dim=-1)"

0 comments on commit 9c852b5

Please sign in to comment.