PreMix: Boosting Multiple Instance Learning in Digital Histopathology through Pre-training with Intra-Batch Slide Mixing
Currently under review at the journal Computerized Medical Imaging and Graphics
- PreMix extends the general MIL framework by pre-training the MIL aggregator with an intra-batch slide mixing approach
- Specifically, PreMix incorporates Barlow Twins Slide Mixing during pre-training, enhancing its ability to handle diverse WSI sizes and maximizing the utility of unlabeled WSIs
- Integrated with Mixup and Manifold Mixup during fine-tuning, PreMix achieves a 4.7% performance improvement over the baseline MIL framework, the hierarchical image pyramid transformer (HIPT), on the Camelyon16 dataset
- To assess the robustness of both the original MIL framework and the proposed PreMix framework, we conducted experiments under random sampling and active learning settings with a limited set of labeled WSIs
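The pre-training stage relies on the Barlow Twins objective. The sketch below is a NumPy version of the standard Barlow Twins loss, not code taken from this repository (which uses PyTorch). It compares two batches of slide-level embeddings, for example embeddings of two augmented or mixed views of the same slides. The trade-off weight lambd = 5e-3 is the default from the original Barlow Twins paper and is an assumption here, not a documented PreMix setting.

```python
import numpy as np


def barlow_twins_loss(z1: np.ndarray, z2: np.ndarray, lambd: float = 5e-3) -> float:
    """Barlow Twins loss for two (N, D) embedding batches of the same N slides."""
    n = z1.shape[0]
    # Standardize every feature dimension across the batch (zero mean, unit variance).
    z1 = (z1 - z1.mean(axis=0)) / (z1.std(axis=0) + 1e-6)
    z2 = (z2 - z2.mean(axis=0)) / (z2.std(axis=0) + 1e-6)
    # Cross-correlation matrix between the two views, shape (D, D).
    c = z1.T @ z2 / n
    # Invariance term: push diagonal entries towards 1.
    on_diag = np.sum((np.diagonal(c) - 1.0) ** 2)
    # Redundancy-reduction term: push off-diagonal entries towards 0.
    off_diag = np.sum(c ** 2) - np.sum(np.diagonal(c) ** 2)
    return float(on_diag + lambd * off_diag)
```

Two nearly identical views give a loss close to zero. Unrelated embeddings give a large loss, because their diagonal correlations fall towards 0.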
Development environment:
- Windows 10 Enterprise
- 2 NVIDIA RTX 2080 Ti GPUs (11GB each)
- CUDA version: 11.7
- Python version: 3.8.16
Install Anaconda
Create a new environment and activate it
conda create --name premix python=3.8.16
conda activate premix
Install all required packages, including the CUDA 11.7 build of PyTorch
pip install -r requirements.txt
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
Extract square regions from each WSI to create a pre-training dataset using HS2P, which tiles the tissue and extracts relevant regions at a given pixel spacing
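The pixel-spacing arithmetic behind region extraction can be illustrated with a short sketch. region_grid is a hypothetical helper, not part of HS2P, and it ignores the tissue segmentation and filtering that HS2P performs. It assumes slide_spacing is the level-0 spacing in µm/px and that 20x corresponds to roughly 0.5 µm/px, a common convention. Under those assumptions, a 4096 px region at 20x covers 8192 px at level 0 on a 40x (0.25 µm/px) scan.

```python
from typing import List, Tuple


def region_grid(width: int, height: int, slide_spacing: float,
                target_spacing: float, region_size: int = 4096) -> List[Tuple[int, int]]:
    """Top-left level-0 coordinates of non-overlapping regions that fit inside the slide."""
    # A region of region_size pixels at the target spacing covers more level-0 pixels
    # when the slide was scanned at a finer spacing.
    scale = target_spacing / slide_spacing
    step = int(round(region_size * scale))
    # Partial regions at the right and bottom borders are dropped.
    return [(x, y)
            for y in range(0, height - step + 1, step)
            for x in range(0, width - step + 1, step)]
```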
The results from HS2P should be structured as follows:
<data_dir>/
├── pretrain/
    ├── hs2p_20x_4096/
        ├── debug/
        ├── patches/
            ├── slide_1/
                ├── 4096
                    ├── jpg
                        ├── slide_1_x1_y1.png
                        ├── slide_1_x2_y2.png
                        ├── ...
            ├── slide_2/
                ├── 4096
                    ├── jpg
                        ├── slide_2_x1_y1.png
                        ├── slide_2_x2_y2.png
                        ├── ...
            ├── ...
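Region file names follow the pattern <slide_id>_<x>_<y>. The hypothetical helper below recovers the slide id and coordinates from such a name. It assumes x and y are integer pixel coordinates, as the x1_y1 placeholders suggest, and it also works on the .pt feature files produced later.

```python
from pathlib import Path
from typing import Tuple


def parse_region_name(path: str) -> Tuple[str, int, int]:
    """Split '<slide_id>_<x>_<y>.<ext>' into (slide_id, x, y)."""
    stem = Path(path).stem  # drop the directory and the file extension
    # rsplit from the right keeps underscores inside the slide id (e.g. 'slide_1') intact.
    slide_id, x, y = stem.rsplit("_", 2)
    return slide_id, int(x), int(y)
```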
Download HIPT pre-trained weights using the following commands:
mkdir checkpoints
cd checkpoints
gdown 1Qm-_XrTMYhu9Hl-4FClaOMuroyWlOAxw
gdown 1A2eHTT0dedHgdCvy6t3d9HwluF8p5yjz
Create a configuration file under config/feature_extraction/, using the existing files as templates
To extract region-level features, set level: 'global' in the config (refer to config/feature_extraction/global.yaml)
Ensure that slides_list.txt lists the names of all slides, one per line:
slide_1
slide_2
...
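If every slide has its own sub-folder in the HS2P patches directory, slides_list.txt can be generated from those folder names rather than written by hand. The sketch below assumes that layout. write_slides_list is a hypothetical helper, and the patches path you pass in (for example <data_dir>/pretrain/hs2p_20x_4096/patches) is up to you.

```python
from pathlib import Path
from typing import List


def write_slides_list(patch_dir: str, out_file: str = "slides_list.txt") -> List[str]:
    """Write one slide name per line, taken from the sub-folders of the HS2P patch directory."""
    # Only directories count as slides; stray files are ignored. Sorting keeps the output stable.
    names = sorted(p.name for p in Path(patch_dir).iterdir() if p.is_dir())
    Path(out_file).write_text("\n".join(names) + "\n")
    return names
```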
Run the following command to initiate feature extraction:
python extract_features.py --config-name global
The results should be structured as follows:
outputs/
├── pretrain/
    ├── features/
        ├── hipt/
            ├── global/
                ├── region/
                    ├── slide_1_x1_y1.pt
                    ├── slide_2_x2_y2.pt
                    ├── ...
                ├── slide/
                    ├── slide_1.pt
                    ├── slide_2.pt
                    ├── ...
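Slides contain different numbers of regions, so their feature tensors differ in length. Suppose each file in slide/ holds an (M, D) stack of region features that can be loaded with torch.load. Batching several slides for a MIL aggregator then requires padding them to a common length and keeping a mask. The NumPy sketch below shows one way to do this. pad_bags is a hypothetical helper, and the repository's own batching logic may differ.

```python
from typing import List, Tuple

import numpy as np


def pad_bags(bags: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
    """Stack (M_i, D) region-feature bags into a (B, M_max, D) array plus a (B, M_max) mask."""
    m_max = max(b.shape[0] for b in bags)
    dim = bags[0].shape[1]
    batch = np.zeros((len(bags), m_max, dim), dtype=bags[0].dtype)
    mask = np.zeros((len(bags), m_max), dtype=bool)
    for i, bag in enumerate(bags):
        batch[i, : bag.shape[0]] = bag
        mask[i, : bag.shape[0]] = True  # True marks real regions, False marks padding
    return batch, mask
```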
1. Prepare a CSV file inside data/pretrain/ that lists all the slides used for pre-training (refer to data/pretrain/camelyon16_cptac_ucec.csv):
slide_id
slide_1
slide_2
...
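Before pre-training, it can help to confirm that every slide in the CSV has an extracted feature file. The sketch below makes two assumptions: slide-level features are stored as <slide_id>.pt in a single folder (such as outputs/pretrain/features/hipt/global/slide/), and the CSV has the slide_id header shown above. missing_features is a hypothetical helper.

```python
import csv
from pathlib import Path
from typing import List


def missing_features(csv_path: str, feature_dir: str) -> List[str]:
    """Return slide_ids listed in the CSV that have no <slide_id>.pt file in feature_dir."""
    with open(csv_path, newline="") as f:
        slide_ids = [row["slide_id"] for row in csv.DictReader(f)]
    return [s for s in slide_ids if not (Path(feature_dir) / f"{s}.pt").exists()]
```

An empty result means every listed slide has features.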
2. Create a configuration file under config/training/, using the existing files as templates (refer to config/training/pretrain.yaml)
3. Pre-train with Barlow Twins Slide Mixing
Run the following command to initiate Barlow Twins Slide Mixing:
python barlow_twins_slide_mixing.py --config-name pretrain
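To illustrate the idea of intra-batch slide mixing, here is a hedged sketch of one way to mix bags of different sizes. Each slide is paired with a randomly chosen partner from the same batch (possibly itself). The mixed bag takes a Beta-distributed share of the slide's regions and fills the rest with regions from the partner. This is an illustration under those assumptions, not necessarily the exact PreMix procedure. mix_bags_in_batch and its alpha parameter are hypothetical.

```python
from typing import List, Optional

import numpy as np


def mix_bags_in_batch(bags: List[np.ndarray], alpha: float = 1.0,
                      rng: Optional[np.random.Generator] = None) -> List[np.ndarray]:
    """Mix each (M_i, D) bag with a randomly chosen partner bag from the same batch."""
    rng = rng if rng is not None else np.random.default_rng()
    partners = rng.permutation(len(bags))  # partner j for slide i (may equal i)
    mixed = []
    for i, j in enumerate(partners):
        lam = rng.beta(alpha, alpha)  # share of regions kept from slide i
        a, b = bags[i], bags[j]
        n_a = max(1, int(round(lam * a.shape[0])))  # keep at least one region of slide i
        n_b = int(round((1 - lam) * b.shape[0]))
        keep_a = a[rng.choice(a.shape[0], n_a, replace=False)]
        keep_b = b[rng.choice(b.shape[0], n_b, replace=False)]
        mixed.append(np.concatenate([keep_a, keep_b], axis=0))
    return mixed
```

Because regions are subsampled rather than interpolated element-wise, slides of different sizes can be mixed without padding. Two such mixed views of a batch could then be compared with a Barlow Twins loss.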
The results should be structured as follows:
outputs/
├── pretrain/
    ├── checkpoints/
        ├── global
            ├── <model_name>_<epoch>.pth
            ├── <model_name>_<epoch>.pth
            ├── ...
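The fine-tuning config expects the full checkpoint name <model_name>_<epoch>. The sketch below selects the checkpoint with the highest epoch, assuming epochs are integers. Comparing epochs numerically avoids the string-sorting pitfall where "9" sorts after "10". latest_checkpoint is a hypothetical helper, and "premix" in the usage check is only an example model name.

```python
import re
from pathlib import Path
from typing import Optional


def latest_checkpoint(ckpt_dir: str, model_name: str) -> Optional[str]:
    """Return '<model_name>_<epoch>' for the highest epoch found in ckpt_dir, or None."""
    pattern = re.compile(rf"^{re.escape(model_name)}_(\d+)\.pth$")
    epochs = []
    for p in Path(ckpt_dir).glob("*.pth"):
        m = pattern.match(p.name)
        if m:
            epochs.append(int(m.group(1)))  # numeric comparison, not string comparison
    return f"{model_name}_{max(epochs)}" if epochs else None
```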
1. Prepare pool and test CSV files for downstream classification (refer to data/camelyon16/pool.csv and data/camelyon16/test.csv)
These two files list all the slides and their labels for downstream classification:
slide_id,label
slide_1,0
slide_2,1
...
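In the random sampling setting, a labeled training subset is drawn from the pool. The sketch below shows that step for a given labeling budget. It assumes the slide_id,label format above. sample_labeled_set, budget and seed are hypothetical names, and the repository's sampling and active learning code may work differently.

```python
import csv
import random
from typing import List, Tuple


def sample_labeled_set(pool_csv: str, budget: int, seed: int = 0) -> List[Tuple[str, int]]:
    """Randomly draw `budget` (slide_id, label) pairs from the pool CSV without replacement."""
    with open(pool_csv, newline="") as f:
        rows = [(r["slide_id"], int(r["label"])) for r in csv.DictReader(f)]
    rng = random.Random(seed)  # fixed seed so the same subset is drawn on every run
    return rng.sample(rows, budget)
```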
2. Create a configuration file under config/training/, using the existing files as templates (refer to config/training/global.yaml)
Note that <model_name> in the config file must be the full checkpoint name, <model_name>_<epoch>
To integrate slide mixing strategies during fine-tuning, include the following in the config file, and comment out any options that are not needed:
mixing:
  mixup: True
  manifold_mixup: True
  manifold_mixup_transformer: True
  mixup_alpha: 1
  mixup_alpha_per_sample: False
  mixup_type: random # [random, cosine_sim, class_aware]
...
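As a hedged illustration of what these options control, here is a generic Mixup sketch on slide embeddings with one-hot labels. It assumes mixup_alpha is the Beta(α, α) parameter and that mixup_alpha_per_sample switches between one mixing coefficient per batch and one per sample. It implements only the random pairing; the cosine_sim and class_aware strategies are not shown. Manifold Mixup applies the same interpolation to intermediate hidden representations instead of inputs, which is why the function also accepts arrays with extra dimensions. The function name and signature are hypothetical.

```python
from typing import Optional, Tuple

import numpy as np


def mixup(x: np.ndarray, y: np.ndarray, alpha: float = 1.0, per_sample: bool = False,
          rng: Optional[np.random.Generator] = None) -> Tuple[np.ndarray, np.ndarray]:
    """Mixup on a batch x of shape (B, ...) with one-hot labels y of shape (B, C)."""
    rng = rng if rng is not None else np.random.default_rng()
    b = x.shape[0]
    perm = rng.permutation(b)  # 'random' pairing: each sample is mixed with a random partner
    # One coefficient for the whole batch, or one coefficient per sample.
    lam = rng.beta(alpha, alpha, size=b if per_sample else 1)
    lam_x = lam.reshape(-1, *([1] * (x.ndim - 1)))  # broadcast over feature dimensions
    x_mix = lam_x * x + (1 - lam_x) * x[perm]
    y_mix = lam[:, None] * y + (1 - lam[:, None]) * y[perm]
    return x_mix, y_mix
```

The mixed labels remain valid probability vectors, so they can be used directly with a soft-label cross-entropy loss.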
3. Fine-tune the MIL aggregator with Mixup and Manifold Mixup
Run the following command to initiate the fine-tuning process:
python main.py --config-name global
The results should be structured as follows:
outputs/
├── <downstream_dataset>/
    ├── checkpoints/
        ├── <all_settings_folder>/
            ├── <AL_strategy>/
                ├── best_model.pth
    ├── results/
        ├── <all_settings_folder>/
            ├── <AL_strategy>/
                ├── train_0.csv
                ├── train_1.csv
                ├── ...
                ├── test.csv
                ├── test_results.csv
                ├── roc_auc_curve.png
    ├── scripts/
        ├── <all_settings_folder>/
            ├── log.txt
- best_model.pth is the best model for the specified settings and AL strategy
- train_0.csv, train_1.csv, etc. contain the predicted class probabilities
- log.txt contains the model's performance metrics (ACC, AUC, Precision, Recall) for all AL strategies and labeled training budgets
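For reference, the four metrics reported in log.txt can be computed for a binary task such as Camelyon16 (normal vs. tumor) as in the sketch below. The inputs are true labels and positive-class probabilities. The 0.5 decision threshold and the function name are assumptions, and the repository may compute these metrics differently, for example with a library. AUC is computed as the probability that a randomly chosen positive slide scores higher than a randomly chosen negative one, with ties counting one half.

```python
from typing import Dict, Sequence


def binary_metrics(labels: Sequence[int], probs: Sequence[float],
                   threshold: float = 0.5) -> Dict[str, float]:
    """ACC, AUC, Precision and Recall from binary labels and positive-class probabilities."""
    preds = [int(p >= threshold) for p in probs]
    tp = sum(1 for t, p in zip(labels, preds) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(labels, preds) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(labels, preds) if t == 1 and p == 0)
    acc = sum(1 for t, p in zip(labels, preds) if t == p) / len(labels)
    pos = [s for s, t in zip(probs, labels) if t == 1]
    neg = [s for s, t in zip(probs, labels) if t == 0]
    # Pairwise ranking form of AUC (Mann-Whitney statistic); ties count as half a win.
    wins = sum(1.0 if sp > sn else 0.5 if sp == sn else 0.0 for sp in pos for sn in neg)
    auc = wins / (len(pos) * len(neg))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return {"ACC": acc, "AUC": auc, "Precision": precision, "Recall": recall}
```

For labels [0, 0, 1, 1] and probabilities [0.1, 0.4, 0.35, 0.8], this gives ACC 0.75, AUC 0.75, Precision 1.0 and Recall 0.5.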
This codebase builds upon HIPT and the Re-Implementation HIPT repository