diff --git a/nnunetv2/dataset_conversion/Dataset224_AbdomenAtlas1.0.py b/nnunetv2/dataset_conversion/Dataset224_AbdomenAtlas1.0.py index 6a0a9ad59..e8f878088 100644 --- a/nnunetv2/dataset_conversion/Dataset224_AbdomenAtlas1.0.py +++ b/nnunetv2/dataset_conversion/Dataset224_AbdomenAtlas1.0.py @@ -4,6 +4,20 @@ if __name__ == '__main__': + """ + How to train our submission to the JHU benchmark + + 1. Execute this script here to convert the dataset into nnU-Net format. Adapt the paths to your system! + 2. Run planning and preprocessing: `nnUNetv2_plan_and_preprocess -d 224 -npfp 64 -np 64 -c 3d_fullres -pl + nnUNetPlannerResEncL_torchres`. Adapt the number of processes to your system (-np; -npfp)! Note that each process + will again spawn 4 threads for resampling. This custom planner replaces the nnU-Net default resampling scheme with + a torch-based implementation which is faster but less accurate. This is needed to satisfy the inference speed + constraints. + 3. Run training with `nnUNetv2_train 224 3d_fullres all -p nnUNetResEncUNetLPlans_torchres`. 24GB VRAM required, + training will take ~28-30h. 
+ """ + + base = '/home/isensee/Downloads/AbdomenAtlas1.0Mini' cases = subdirs(base, join=False, prefix='BDMAP') diff --git a/nnunetv2/inference/JHU_inference.py b/nnunetv2/inference/JHU_inference.py index d57c2a606..0933600a9 100644 --- a/nnunetv2/inference/JHU_inference.py +++ b/nnunetv2/inference/JHU_inference.py @@ -176,7 +176,7 @@ def predict_from_data_iterator(self, predictor.initialize_from_trained_model_folder( args.model, ('all', ), - 'checkpoint_latest.pth' + 'checkpoint_final.pth' ) # we need to create list of list of input files diff --git a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py index d3803f391..b23847cb2 100644 --- a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py +++ b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py @@ -238,13 +238,24 @@ def initialize(self): def _do_i_compile(self): # new default: compile is enabled! + # compile does not work on mps + if self.device == torch.device('mps'): + if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'): + self.print_to_log_file("INFO: torch.compile disabled because of unsupported mps device") + return False + # CPU compile crashes for 2D models. Not sure if we even want to support CPU compile!? Better disable if self.device == torch.device('cpu'): + if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'): + self.print_to_log_file("INFO: torch.compile disabled because device is CPU") return False # default torch.compile doesn't work on windows because there are apparently no triton wheels for it # https://discuss.pytorch.org/t/windows-support-timeline-for-torch-compile/182268/2 if os.name == 'nt': + if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'): + self.print_to_log_file("INFO: torch.compile disabled because Windows is not natively supported. 
If " + "you know what you are doing, check https://discuss.pytorch.org/t/windows-support-timeline-for-torch-compile/182268/2") return False if 'nnUNet_compile' not in os.environ.keys():