Skip to content

Commit

Permalink
Merge pull request #20 from CielAl/dev
Browse files Browse the repository at this point in the history
1. fix a problem to draw random number for augmentation. 2. Add least…
  • Loading branch information
CielAl authored Jan 12, 2024
2 parents d2426c2 + ff371b9 commit c696fde
Show file tree
Hide file tree
Showing 16 changed files with 315 additions and 97 deletions.
129 changes: 87 additions & 42 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,89 @@
## Documentation
Detail documentation regarding the code base can be found in the [GitPages](https://cielal.github.io/torch-staintools/).

## Citation
If this toolkit helps you in your publication, please feel free to cite with the following bibtex entry:
```bibtex
@software{zhou_2024_10453807,
author = {Zhou, Yufei},
title = {CielAl/torch-staintools: V1.0.3 Release},
month = jan,
year = 2024,
publisher = {Zenodo},
version = {v1.0.3},
doi = {10.5281/zenodo.10453807},
url = {https://doi.org/10.5281/zenodo.10453807}
}
```

## Description
* Stain Normalization (Reinhard, Macenko, and Vahadane) for pytorch. Input tensors (fit and transform) must be in shape of `NxCxHxW`, with value scaled to [0, 1] in format of torch.float32.
* Stain Augmentation using Macenko and Vahadane as stain extraction.
* Fast normalization/augmentation on GPU with stain matrices caching.
* Simulate the workflow in [StainTools library](https://github.com/Peter554/StainTools) but use the Iterative Shrinkage Thresholding Algorithm (ISTA), or optionally, the coordinate descent (CD) to solve the dictionary learning for stain matrix/concentration computation in Vahadane or Macenko (stain concentration only) algorithm. The implementation of ISTA and CD are derived from Cédric Walker's [torchvahadane](https://github.com/cwlkr/torchvahadane)
* Simulate the workflow in [StainTools library](https://github.com/Peter554/StainTools) but use the Iterative Shrinkage Thresholding Algorithm (ISTA), or optionally, the coordinate descent (CD) to solve the dictionary learning for stain matrix computation in Vahadane or Macenko (stain concentration only) algorithm. The implementation of ISTA and CD are derived from Cédric Walker's [torchvahadane](https://github.com/cwlkr/torchvahadane)
* Stain Concentration is solved via factorization of `Stain_Matrix x Concentration = Optical_Density`. For efficient sparse solution and more robust outcomes, ISTA can be applied. Alternatively, Least Square solver (LS) from `torch.linalg.lstsq` might be applied for faster non-sparse solution.
* No SPAMS requirement (which is a dependency in StainTools).

<br />

#### Sample Output of Torch StainTools
#### Sample Output of Torch-StainTools Normalization
![Screenshot](https://raw.githubusercontent.com/CielAl/torch-staintools/main/showcases/sample_out.png)

#### Sample Output of StainTools
![Screenshot](https://raw.githubusercontent.com/CielAl/torch-staintools/main/showcases/sample_out_staintools.png)

## Use case
#### Sample Output of Torch-StainTools Augmentation (Repeat 3 times)
![Screenshot](https://raw.githubusercontent.com/CielAl/torch-staintools/main/showcases/sample_out_augmentation.png)

#### Sample Output of StainTools Augmentation (Repeat 3 times)
![Screenshot](https://raw.githubusercontent.com/CielAl/torch-staintools/main/showcases/sample_out_augmentation_staintools.png)

## Benchmark (No Stain Matrices Caching)
* Use the sample images under ./test_images (size `2500x2500x3`). Mean was computed from 7 runs (1 loop per run) using
timeit. Comparison between torch_stain_tools in CPU/GPU mode, as well as that of the StainTools Implementation.
* For consistency, use ISTA to compute the concentration.

### Transformation

| Method | CPU[s] | GPU[s] | StainTool[s] |
|:---------|:-------|:-------|:-------------|
| Vahadane | 119 | 7.5 | 20.9 |
| Macenko | 5.57 | 0.479 | 20.7 |
| Reinhard | 0.840 | 0.024 | 0.414 |

### Fitting
| Method | CPU[s] | GPU[s] | StainTool[s] |
|:---------|:-------|:-------|:-------------|
| Vahadane | 132 | 8.40 | 19.1 |
| Macenko | 6.99 | 0.064 | 20.0 |
| Reinhard | 0.422 | 0.011 | 0.076 |

### Batchified Concentration Computation
* Split the sample images under ./test_images (size `2500x2500x3`) into 81 non-overlapping `256x256x3` tiles as a batch.
* For the StainTools baseline, a for-loop is implemented to get the individual concentration of each of the numpy array of the 81 tiles.
*
| Method | CPU[s] | GPU[s] |
|:-------------------------------------|:-------|:----------|
| ISTA (`concentration_method='ista'`) | 3.12 | 1.24 |
| CD (`concentration_method='cd'`) | 29.3s | 4.87 |
| LS (`concentration_method='ls'`) | 0.221 | **0.097** |
| StainTools (SPAMS) | 16.6 | N/A |


## Use Cases and Tips
* For details, follow the example in demo.py
* Normalizers are wrapped as `torch.nn.Module`, working similarly to a standalone neural network. This means that for a workflow involving dataloader with multiprocessing, the normalizer
(Note that CUDA has poor support in multiprocessing, and therefore it may not be the best practice to perform GPU-accelerated on-the-fly stain transformation in pytorch's dataset/dataloader)

* `concentration_method='ls'` (i.e., `torch.linalg.lstsq`) can be efficient for batches of many smaller input (e.g., `256x256`) in terms of width and height. However, it may fail on GPU for a single larger input image (width and height). This happens even if the
the total number of pixels of the image is fewer than the aforementioned batch of multiple smaller input. Therefore, `concentration_method='ls'` could be suitable to deal with huge amount of small images in batches on the fly.

```python
import cv2
import torch
from torchvision.transforms import ToTensor
from torchvision.transforms.functional import convert_image_dtype
from torch_staintools.normalizer.factory import NormalizerBuilder
from torch_staintools.augmentor.factory import AugmentorBuilder
from torch_staintools.normalizer import NormalizerBuilder
from torch_staintools.augmentor import AugmentorBuilder
import os
seed = 0
torch.manual_seed(seed)
Expand Down Expand Up @@ -71,7 +126,15 @@ norm_tensor = ToTensor()(norm).unsqueeze(0).to(device)

# ######## Normalization
# create the normalizer - using vahadane. Alternatively can use 'macenko' or 'reinhard'.
normalizer_vahadane = NormalizerBuilder.build('vahadane')
# note this is equivalent to:
# from torch_staintools.normalizer.separation import StainSeparation
# normalizer_vahadane = StainSeparation.build('vahadane', **arguments)

# we use the 'ista' (ISTA algorithm) to get the sparse solution of the factorization: STAIN_MATRIX * Concentration = OD
# alternatively, 'cd' (coordinate descent) and 'ls' (least square from torch.linalg) is available.
# Note that 'ls' does not can be much faster on batches of smaller input, but may fail on GPU for individual large input
# in terms of width and height, regardless of the batch size
normalizer_vahadane = NormalizerBuilder.build('vahadane', concentration_method='ista')
# move the normalizer to the device (CPU or GPU)
normalizer_vahadane = normalizer_vahadane.to(device)
# fit. For macenko and vahadane this step will compute the stain matrix and concentration
Expand All @@ -89,7 +152,8 @@ augmentor = AugmentorBuilder.build('vahadane',
# the luminosity threshold to find the tissue region to augment
# if set to None means all pixels are treated as tissue
luminosity_threshold=0.8,

# herein we use 'ista' to compute the concentration
concentration_method='ista',
sigma_alpha=0.2,
sigma_beta=0.2, target_stain_idx=(0, 1),
# this allows to cache the stain matrix if it's too time-consuming to recompute.
Expand Down Expand Up @@ -117,6 +181,21 @@ for _ in range(num_augment):

# dump the cache of stain matrices for future usage
augmentor.dump_cache('./cache.pickle')

# fast batch operation
tile_size = 512
tiles: torch.Tensor = norm_tensor.unfold(2, tile_size, tile_size)\
.unfold(3, tile_size, tile_size).reshape(1, 3, -1, tile_size, tile_size).squeeze(0).permute(1, 0, 2, 3).contiguous()
print(tiles.shape)
# use macenko normalization as example
normalizer_macenko = NormalizerBuilder.build('macenko', use_cache=True,
# use least square solver, along with cache, to perform
# normalization on-the-fly
concentration_method='ls')
normalizer_macenko = normalizer_macenko.to(device)
normalizer_macenko.fit(target_tensor)
normalizer_macenko(tiles)

```
## Stain Matrix Caching
As elaborated in the below in the running time benchmark of fitting, computation of stain matrix could be time-consuming.
Expand All @@ -133,40 +212,6 @@ augmentor(input_batch, cache_keys=list_of_keys_corresponding_to_input_batch)
The next time `Normalizer` or `Augmentor` process the images, the corresponding stain matrices will be queried and fetched from cache if they are stored already, rather than recomputing from scratch.


## Benchmark
* Use the sample images under ./test_images (size `2500x2500x3`). Mean was computed from 7 runs (1 loop per run) using
timeit. Comparison between torch_stain_tools in CPU/GPU mode, as well as that of the StainTools Implementation.

### Transformation

| Method | CPU[s] | GPU[s] | StainTool[s] |
|:---------|:-------|:-------|:-------------|
| Vahadane | 119 | 7.5 | 20.9 |
| Macenko | 5.57 | 0.479 | 20.7 |
| Reinhard | 0.840 | 0.024 | 0.414 |

### Fitting
| Method | CPU[s] | GPU[s] | StainTool[s] |
|:---------|:-------|:-------|:-------------|
| Vahadane | 132 | 8.40 | 19.1 |
| Macenko | 6.99 | 0.064 | 20.0 |
| Reinhard | 0.422 | 0.011 | 0.076 |

## Citation
If this toolkit helps you in your publication, please feel free to cite with the following bibtex entry:
```bibtex
@software{zhou_2024_10453807,
author = {Zhou, Yufei},
title = {CielAl/torch-staintools: V1.0.3 Release},
month = jan,
year = 2024,
publisher = {Zenodo},
version = {v1.0.3},
doi = {10.5281/zenodo.10453807},
url = {https://doi.org/10.5281/zenodo.10453807}
}
```

## Acknowledgments
* Some codes are derived from [torchvahadane](https://github.com/cwlkr/torchvahadane), [torchstain](https://github.com/EIDOSLAB/torchstain), and [StainTools](https://github.com/Peter554/StainTools)
* Sample images in the demo and ReadMe.md are selected from [The Cancer Genome Atlas Program(TCGA)](https://www.cancer.gov/ccg/research/genome-sequencing/tcga) dataset.
58 changes: 54 additions & 4 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import torch
from torchvision.transforms import ToTensor
from torchvision.transforms.functional import convert_image_dtype
from torch_staintools.normalizer.factory import NormalizerBuilder
from torch_staintools.augmentor.factory import AugmentorBuilder
from torch_staintools.normalizer import NormalizerBuilder
from torch_staintools.augmentor import AugmentorBuilder
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import random
import os
seed = 0
torch.manual_seed(seed)
Expand Down Expand Up @@ -52,7 +53,7 @@ def postprocess(image_tensor): return convert_image_dtype(image_tensor, torch.ui


# ######### Vahadane
normalizer_vahadane = NormalizerBuilder.build('vahadane', reconst_method='ista', use_cache=True,
normalizer_vahadane = NormalizerBuilder.build('vahadane', concentration_method='ista', use_cache=True,
rng=1,
)
normalizer_vahadane = normalizer_vahadane.to(device)
Expand All @@ -75,7 +76,7 @@ def postprocess(image_tensor): return convert_image_dtype(image_tensor, torch.ui
# #################### Macenko


normalizer_macenko = NormalizerBuilder.build('macenko')
normalizer_macenko = NormalizerBuilder.build('macenko', use_cache=True, concentration_method='ls')
normalizer_macenko = normalizer_macenko.to(device)
normalizer_macenko.fit(target_tensor)

Expand Down Expand Up @@ -202,3 +203,52 @@ def postprocess(image_tensor): return convert_image_dtype(image_tensor, torch.ui
ax.axis('off')
plt.savefig(os.path.join('.', 'showcases', 'sample_out_staintools.png'), bbox_inches='tight')
plt.show()

algorithms = ['Vahadane', 'Macenko']
num_repeat = 3

# # sample aug output
fig, axs = plt.subplots(2, num_repeat + 1, figsize=(15, 8), dpi=300)
for i, ax_alg in enumerate(axs):
alg = algorithms[i].lower()
augmentor = AugmentorBuilder.build(alg, concentration_method='ista',
sigma_alpha=0.5,
sigma_beta=0.5,
luminosity_threshold=0.8,
rng=314159, use_cache=True).to(device)
ax_alg[0].imshow(norm)
ax_alg[0].set_title("Augmentation Original")
ax_alg[0].axis('off')
for j in range(1, len(ax_alg)):
aug_out = augmentor(norm_tensor, cache_keys=[0])
ax_alg[j].imshow(postprocess(aug_out))
ax_alg[j].set_title(f"{alg} :{j}")
ax_alg[j].axis('off')
plt.savefig(os.path.join('.', 'showcases', 'sample_out_augmentation.png'), bbox_inches='tight')
plt.show()


# #### sample aug output
np.random.seed(314159)
random.seed(314159)
from staintools import StainAugmentor
from staintools.preprocessing.luminosity_standardizer import LuminosityStandardizer
algorithms = ['Vahadane', 'Macenko']
fig, axs = plt.subplots(2, num_repeat + 1, figsize=(15, 8), dpi=300)
for i, ax_alg in enumerate(axs):
alg = algorithms[i].lower()
augmentor = StainAugmentor(method=alg, sigma1=0.5, sigma2=0.5, augment_background=False)
standardized_norm = LuminosityStandardizer.standardize(norm)
augmentor.fit(standardized_norm)
ax_alg[0].imshow(standardized_norm)
ax_alg[0].set_title("Augmentation Original")
ax_alg[0].axis('off')
for j in range(1, len(ax_alg)):
aug_out = augmentor.pop().astype(np.uint8)
ax_alg[j].imshow(aug_out)
ax_alg[j].set_title(f"{alg} - StainTools: {j}")
ax_alg[j].axis('off')
plt.savefig(os.path.join('.', 'showcases', 'sample_out_augmentation_staintools.png'), bbox_inches='tight')
plt.show()


Binary file added showcases/sample_out_augmentation.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added showcases/sample_out_augmentation_staintools.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion tests/images/test_functionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def extract_eval_helper(tester, get_stain_mat, luminosity_threshold,
def eval_wrapper(self, extractor):

# all pixel
algorithms = ['ista', 'cd']
algorithms = ['ista', 'cd', 'ls']
for alg in algorithms:
TestFunctional.extract_eval_helper(self, extractor, luminosity_threshold=None,
num_stains=2, regularizer=0.1, dict_algorithm=alg)
Expand Down
1 change: 1 addition & 0 deletions torch_staintools/augmentor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .factory import *
Loading

0 comments on commit c696fde

Please sign in to comment.