From 185b892a082e0ecc906e0bd926a8a8a445fe31d6 Mon Sep 17 00:00:00 2001 From: Sergey Kolesnikov Date: Mon, 27 Sep 2021 09:00:08 +0300 Subject: [PATCH] Cleanup (#1303) * old BalanceBatchSampler removed * extra classification report added * fix * fix * fix * fix * fix * fix --- .github/PULL_REQUEST_TEMPLATE.md | 36 +-- README.md | 6 +- catalyst/callbacks/metrics/cmc_score.py | 8 +- catalyst/callbacks/mixup.py | 32 +- catalyst/callbacks/sklearn_model.py | 18 +- catalyst/contrib/utils/__init__.py | 2 + catalyst/contrib/utils/report.py | 125 ++++++++ catalyst/data/__init__.py | 2 +- catalyst/data/sampler.py | 199 ++++++------ catalyst/data/sampler_inbatch.py | 2 +- catalyst/metrics/_cmc_score.py | 6 +- docs/api/utils.rst | 7 + .../customizing_what_happens_in_train.ipynb | 286 +++++++++--------- examples/self_supervised/barlow_twins.py | 8 +- examples/self_supervised/common.py | 25 +- tests/catalyst/callbacks/test_metric.py | 14 +- tests/pipelines/test_metric_learning.py | 6 +- 17 files changed, 443 insertions(+), 339 deletions(-) create mode 100644 catalyst/contrib/utils/report.py diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index f4bd0d50f3..c16d3f3ce4 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,19 +1,9 @@ -## Before submitting (checklist) - -- [ ] Was this discussed/approved via a Github issue? (no need for typos and docs improvements) -- [ ] Did you read the [contribution guide](https://github.com/catalyst-team/catalyst/blob/master/CONTRIBUTING.md)? -- [ ] Did you check the code style? `catalyst-make-codestyle -l 99 && catalyst-check-codestyle -l 99 ` (`pip install -U catalyst-codestyle`). -- [ ] Did you make sure to update the docs? We use Google format for all the methods and classes. -- [ ] Did you check the docs with `make check-docs`? -- [ ] Did you write any new necessary tests? -- [ ] Did you check that your code passes the unit tests `pytest .` ? -- [ ] Did you add your new functionality to the docs? -- [ ] Did you update the [CHANGELOG](https://github.com/catalyst-team/catalyst/blob/master/CHANGELOG.md)? -- [ ] Did you run [colab minimal CI/CD](https://colab.research.google.com/github/catalyst-team/catalyst/blob/master/examples/notebooks/colab_ci_cd.ipynb) with `latest` and `minimal` requirements? -- [ ] Did you check XLA integration with [single](https://colab.research.google.com/github/catalyst-team/catalyst/blob/master/examples/notebooks/Catalyst_XLA_single_process.ipynb) and [multiple](https://colab.research.google.com/github/catalyst-team/catalyst/blob/master/examples/notebooks/Catalyst_XLA_multi_process.ipynb) processes? - - - +### Pull Request FAQ +- [documentation](https://catalyst-team.github.io/catalyst/) +- [contribution guide](https://github.com/catalyst-team/catalyst/blob/master/CONTRIBUTING.md) +- [minimal examples section](https://github.com/catalyst-team/catalyst#minimal-examples) +- [changelog](https://github.com/catalyst-team/catalyst/blob/master/CHANGELOG.md) for main framework updates +- [Catalyst slack (#__questions channel)](https://join.slack.com/t/catalyst-team-core/shared_invite/zt-d9miirnn-z86oKDzFMKlMG4fgFdZafw) for issue discussion ## Description @@ -43,11 +33,11 @@ If we didn't discuss your PR in Github issues there's a high chance it will not +### Checklist +- [ ] Have you updated tests for the new functionality? +- [ ] Have you added your new classes/functions to the docs? 
+- [ ] Have you updated the [CHANGELOG](https://github.com/catalyst-team/catalyst/blob/master/CHANGELOG.md)? +- [ ] Have you run [colab minimal CI/CD](https://colab.research.google.com/github/catalyst-team/catalyst/blob/master/examples/notebooks/colab_ci_cd.ipynb) with `latest` and `minimal` requirements? +- [ ] Have you checked XLA integration with [single](https://colab.research.google.com/github/catalyst-team/catalyst/blob/master/examples/notebooks/Catalyst_XLA_single_process.ipynb) and [multiple](https://colab.research.google.com/github/catalyst-team/catalyst/blob/master/examples/notebooks/Catalyst_XLA_multi_process.ipynb) processes? -### FAQ -Please review the FAQ before submitting an issue: -- [ ] I have read the [documentation and FAQ](https://catalyst-team.github.io/catalyst/) -- [ ] I have reviewed the [minimal examples section](https://github.com/catalyst-team/catalyst#minimal-examples) -- [ ] I have checked the [changelog](https://github.com/catalyst-team/catalyst/blob/master/CHANGELOG.md) for main framework updates -- [ ] I have read the [contribution guide](https://github.com/catalyst-team/catalyst/blob/master/CONTRIBUTING.md) -- [ ] I have joined [Catalyst slack (#__questions channel)](https://join.slack.com/t/catalyst-team-core/shared_invite/zt-d9miirnn-z86oKDzFMKlMG4fgFdZafw) for issue discussion + \ No newline at end of file diff --git a/README.md b/README.md index 4edc515a91..8d4995fbbb 100644 --- a/README.md +++ b/README.md @@ -812,8 +812,10 @@ from catalyst.data.transforms import Compose, Normalize, ToTensor transforms = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]) train_dataset = datasets.MnistMLDataset(root=os.getcwd(), download=True, transform=transforms) -sampler = data.BalanceBatchSampler(labels=train_dataset.get_labels(), p=5, k=10) -train_loader = DataLoader(dataset=train_dataset, sampler=sampler, batch_size=sampler.batch_size) +sampler = data.BatchBalanceClassSampler( + labels=train_dataset.get_labels(), num_classes=5, num_samples=10, num_batches=10 +) +train_loader = DataLoader(dataset=train_dataset, batch_sampler=sampler) valid_dataset = datasets.MnistQGDataset(root=os.getcwd(), transform=transforms, gallery_fraq=0.2) valid_loader = DataLoader(dataset=valid_dataset, batch_size=1024) diff --git a/catalyst/callbacks/metrics/cmc_score.py b/catalyst/callbacks/metrics/cmc_score.py index 5b148d057e..1217587c00 100644 --- a/catalyst/callbacks/metrics/cmc_score.py +++ b/catalyst/callbacks/metrics/cmc_score.py @@ -52,10 +52,10 @@ class CMCScoreCallback(LoaderMetricCallback): train_dataset = datasets.MnistMLDataset( root=os.getcwd(), download=True, transform=transforms ) - sampler = data.BalanceBatchSampler(labels=train_dataset.get_labels(), p=5, k=10) - train_loader = DataLoader( - dataset=train_dataset, sampler=sampler, batch_size=sampler.batch_size - ) + sampler = data.BatchBalanceClassSampler( + labels=train_dataset.get_labels(), num_classes=5, num_samples=10 + ) + train_loader = DataLoader(dataset=train_dataset, batch_sampler=sampler) valid_dataset = datasets.MnistQGDataset( root=os.getcwd(), transform=transforms, gallery_fraq=0.2 diff --git a/catalyst/callbacks/mixup.py b/catalyst/callbacks/mixup.py index b1d69417f8..28fc54bdf3 100644 --- a/catalyst/callbacks/mixup.py +++ b/catalyst/callbacks/mixup.py @@ -10,6 +10,18 @@ class MixupCallback(Callback): Callback to do mixup augmentation. More details about mixin can be found in the paper `mixup: Beyond Empirical Risk Minimization`: https://arxiv.org/abs/1710.09412 . 
+    Args:
+        keys: batch keys to which you want to apply augmentation
+        alpha: beta-distribution parameters (a=b). Must be >=0. The closer alpha is to
+            zero, the weaker the mixup effect.
+        mode: determines how the mixed batch is used. Must be one of ["replace", "add"].
+            If "replace", the batch is replaced with its mixed version, so the batch size
+            does not change. If "add", the mixed examples are concatenated to the current
+            ones, doubling the batch size.
+        on_train_only: apply mixup on the train loader only. Since mixup feeds the model
+            proxy inputs, the targets become proxies as well. So, if ``on_train_only``
+            is ``True``, standard outputs/metrics are used for validation.
+
     Examples:
 
     .. code-block:: python
@@ -107,24 +119,8 @@ def handle_batch(self, batch):
         use ControlFlowCallback in order to evaluate model(see example)
     """
 
-    def __init__(
-        self, keys: Union[str, List[str]], alpha=0.2, mode="replace", on_train_only=True, **kwargs
-    ):
-        """
-
-        Args:
-            keys: batch keys to which you want to apply augmentation
-            alpha: beta distribution a=b parameters. Must be >=0. The more alpha closer to zero the
-                less effect of the mixup.
-            mode: mode determines the method of use. Must be in ["replace", "add"]. If "replace"
-                then replaces the batch with a mixed one, while the batch size is not changed
-                If "add", concatenates mixed examples to the current ones, the batch size increases
-                by 2 times.
-            on_train_only: apply to train only. As the mixup use the proxy inputs, the targets are
-                also proxy. We are not interested in them, are we? So, if ``on_train_only``
-                is ``True`` use a standard output/metric for validation.
-            **kwargs:
-        """
+    def __init__(self, keys: Union[str, List[str]], alpha=0.2, mode="replace", on_train_only=True):
+        """Init."""
         assert isinstance(keys, (str, list, tuple)), (
             f"keys must be str of list[str]," f" get: {type(keys)}"
         )
diff --git a/catalyst/callbacks/sklearn_model.py b/catalyst/callbacks/sklearn_model.py
index f0814ca8bf..9550c7d4eb 100644
--- a/catalyst/callbacks/sklearn_model.py
+++ b/catalyst/callbacks/sklearn_model.py
@@ -48,11 +48,10 @@ class SklearnModelCallback(Callback):
                 download=True,
                 transform=transforms
             )
-            sampler = data.BalanceBatchSampler(labels=train_dataset.get_labels(), p=5, k=10)
-            train_loader = DataLoader(
-                dataset=train_dataset,
-                sampler=sampler,
-                batch_size=sampler.batch_size)
+            sampler = data.BatchBalanceClassSampler(
+                labels=train_dataset.get_labels(), num_classes=5, num_samples=10
+            )
+            train_loader = DataLoader(dataset=train_dataset, batch_sampler=sampler)
 
             valid_dataset = datasets.MNIST(root=os.getcwd(), transform=transforms, train=False)
             valid_loader = DataLoader(dataset=valid_dataset, batch_size=1024)
@@ -137,11 +136,10 @@ def handle_batch(self, batch) -> None:
                 download=True,
                 transform=transforms
             )
-            sampler = data.BalanceBatchSampler(labels=train_dataset.get_labels(), p=5, k=10)
-            train_loader = DataLoader(
-                dataset=train_dataset,
-                sampler=sampler,
-                batch_size=sampler.batch_size)
+            sampler = data.BatchBalanceClassSampler(
+                labels=train_dataset.get_labels(), num_classes=5, num_samples=10
+            )
+            train_loader = DataLoader(dataset=train_dataset, batch_sampler=sampler)
 
             valid_dataset = datasets.MNIST(root=os.getcwd(), transform=transforms, train=False)
             valid_loader = DataLoader(dataset=valid_dataset, batch_size=1024)
diff --git a/catalyst/contrib/utils/__init__.py b/catalyst/contrib/utils/__init__.py
index 76bf25fb34..88a58a8f05 100644
--- a/catalyst/contrib/utils/__init__.py
+++ b/catalyst/contrib/utils/__init__.py
@@ -43,6 +43,8 @@
     get_pool,
 )
 
+if SETTINGS.ml_required:
+    from catalyst.contrib.utils.report import get_classification_report
 
 from catalyst.contrib.utils.serialization import deserialize, serialize
 
diff --git a/catalyst/contrib/utils/report.py b/catalyst/contrib/utils/report.py
new file mode 100644
index 0000000000..41aab42ffb
--- /dev/null
+++ b/catalyst/contrib/utils/report.py
@@ -0,0 +1,125 @@
+from collections import defaultdict
+
+import numpy as np
+import pandas as pd
+from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
+
+
+def get_classification_report(
+    y_true: np.ndarray, y_pred: np.ndarray, y_scores: np.ndarray = None, beta: float = None
+) -> pd.DataFrame:
+    """Generates pandas-based per-class and aggregated classification metrics.
+
+    Args:
+        y_true (np.ndarray): ground truth labels
+        y_pred (np.ndarray): predicted model labels
+        y_scores (np.ndarray, optional): predicted model scores. Defaults to None.
+        beta (float, optional): Beta parameter for custom Fbeta score computation.
+            Defaults to None.
+
+    Returns:
+        pd.DataFrame: pandas dataframe with main classification metrics.
+
+    Examples:
+
+    .. code-block:: python
+
+        from sklearn import datasets, linear_model
+        from sklearn.model_selection import train_test_split
+        from catalyst import utils
+
+        digits = datasets.load_digits()
+
+        # flatten the images
+        n_samples = len(digits.images)
+        data = digits.images.reshape((n_samples, -1))
+
+        # Create a classifier
+        clf = linear_model.LogisticRegression(multi_class="ovr")
+
+        # Split data into 50% train and 50% test subsets
+        X_train, X_test, y_train, y_test = train_test_split(
+            data, digits.target, test_size=0.5, shuffle=False)
+
+        # Learn the digits on the train subset
+        clf.fit(X_train, y_train)
+
+        # Predict the value of the digit on the test subset
+        y_scores = clf.predict_proba(X_test)
+        y_pred = clf.predict(X_test)
+
+        utils.get_classification_report(
+            y_true=y_test,
+            y_pred=y_pred,
+            y_scores=y_scores,
+            beta=0.5
+        )
+    """
+    metrics = defaultdict(lambda: {})
+    metrics_names = [
+        "precision",
+        "recall",
+        "f1-score",
+        "auc",
+        "support",
+        "support (%)",
+    ]
+    avg_names = ["macro", "micro", "weighted"]
+    labels = sorted(set(y_true).union(y_pred))
+    auc = np.zeros(len(labels))
+    if y_scores is not None:
+        for i, label in enumerate(labels):
+            auc[i] = roc_auc_score((y_true == label).astype(int), y_scores[:, i])
+
+    accuracy = accuracy_score(y_true=y_true, y_pred=y_pred)
+    precision, recall, f1, support = precision_recall_fscore_support(
+        y_true=y_true, y_pred=y_pred, average=None, labels=labels
+    )
+
+    r_support = support / support.sum()
+    for average in avg_names:
+        avg_precision, avg_recall, avg_f1, _ = precision_recall_fscore_support(
+            y_true=y_true, y_pred=y_pred, average=average, labels=labels
+        )
+
+        avg_metrics = avg_precision, avg_recall, avg_f1
+        for k, v in zip(metrics_names[:3], avg_metrics):  # auc is aggregated below
+            metrics[k][average] = v
+
+    report = pd.DataFrame(
+        [precision, recall, f1, auc, support, r_support], columns=labels, index=metrics_names
+    ).T
+
+    if beta is not None:
+        _, _, fbeta, _ = precision_recall_fscore_support(
+            y_true=y_true, y_pred=y_pred, average=None, beta=beta, labels=labels
+        )
+        avg_fbeta = np.zeros(len(avg_names))
+        for i, average in enumerate(avg_names):
+            _, _, avg_beta, _ = precision_recall_fscore_support(
+                y_true=y_true, y_pred=y_pred, average=average, beta=beta, labels=labels
+            )
+            avg_fbeta[i] = avg_beta
+        report.insert(3, "f-beta", fbeta, True)
+
+    metrics["support"]["macro"] = support.sum()
+    metrics["precision"]["accuracy"] = accuracy
+    if y_scores is not None:
+        metrics["auc"]["macro"] = roc_auc_score(
+            y_true, y_scores, multi_class="ovr", average="macro"
+        )
+        metrics["auc"]["weighted"] = roc_auc_score(
+            y_true, y_scores, multi_class="ovr", average="weighted"
+        )
+    metrics = pd.DataFrame(metrics, index=avg_names + ["accuracy"])
+
+    result = pd.concat((report, metrics)).fillna("")
+
+    if beta is not None:
+        result.loc["macro", "f-beta"] = avg_fbeta[0]
+        result.loc["micro", "f-beta"] = avg_fbeta[1]
+        result.loc["weighted", "f-beta"] = avg_fbeta[2]
+    return result
+
+
+__all__ = ["get_classification_report"]
diff --git a/catalyst/data/__init__.py b/catalyst/data/__init__.py
index 2d44558a02..eeb2c16c17 100644
--- a/catalyst/data/__init__.py
+++ b/catalyst/data/__init__.py
@@ -16,8 +16,8 @@
     BatchPrefetchLoaderWrapper,
 )
 from catalyst.data.sampler import (
-    BalanceClassSampler,
     BalanceBatchSampler,
+    BalanceClassSampler,
     BatchBalanceClassSampler,
     DistributedSamplerWrapper,
     DynamicLenBatchSampler,
diff --git a/catalyst/data/sampler.py b/catalyst/data/sampler.py
index a8d20c569d..2451cb93d9 100644
--- a/catalyst/data/sampler.py
+++ b/catalyst/data/sampler.py
@@ -14,107 +14,15 @@
 from catalyst.utils.misc import find_value_ids
 
 
-class BalanceClassSampler(Sampler):
-    """Allows you to create stratified sample on unbalanced classes.
-
-    Args:
-        labels: list of class label for each elem in the dataset
-        mode: Strategy to balance classes.
-            Must be one of [downsampling, upsampling]
-
-    Python API examples:
-
-    .. code-block:: python
-
-        import os
-        from torch import nn, optim
-        from torch.utils.data import DataLoader
-        from catalyst import dl
-        from catalyst.data import ToTensor, BalanceClassSampler
-        from catalyst.contrib.datasets import MNIST
-
-        train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
-        train_labels = train_data.targets.cpu().numpy().tolist()
-        train_sampler = BalanceClassSampler(train_labels, mode=5000)
-        valid_data = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())
-
-        loaders = {
-            "train": DataLoader(train_data, sampler=train_sampler, batch_size=32),
-            "valid": DataLoader(valid_data, batch_size=32),
-        }
-
-        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
-        criterion = nn.CrossEntropyLoss()
-        optimizer = optim.Adam(model.parameters(), lr=0.02)
-
-        runner = dl.SupervisedRunner()
-        # model training
-        runner.train(
-            model=model,
-            criterion=criterion,
-            optimizer=optimizer,
-            loaders=loaders,
-            num_epochs=1,
-            logdir="./logs",
-            valid_loader="valid",
-            valid_metric="loss",
-            minimize_valid_metric=True,
-            verbose=True,
-        )
-    """
-
-    def __init__(self, labels: List[int], mode: Union[str, int] = "downsampling"):
-        """Sampler initialisation."""
-        super().__init__(labels)
-
-        labels = np.array(labels)
-        samples_per_class = {label: (labels == label).sum() for label in set(labels)}
-
-        self.lbl2idx = {
-            label: np.arange(len(labels))[labels == label].tolist() for label in set(labels)
-        }
-
-        if isinstance(mode, str):
-            assert mode in ["downsampling", "upsampling"]
-
-        if isinstance(mode, int) or mode == "upsampling":
-            samples_per_class = mode if isinstance(mode, int) else max(samples_per_class.values())
-        else:
-            samples_per_class = min(samples_per_class.values())
-
-        self.labels = labels
-        self.samples_per_class = samples_per_class
-        self.length = self.samples_per_class * len(set(labels))
-
-    def __iter__(self) -> Iterator[int]:
-        """
-        Yields:
-            indices of stratified sample
-        """
-        indices = []
-        for key in sorted(self.lbl2idx):
-            replace_flag = self.samples_per_class > len(self.lbl2idx[key])
-            indices += np.random.choice(
-                self.lbl2idx[key], self.samples_per_class, replace=replace_flag
-            ).tolist()
-        assert len(indices) == self.length
-        np.random.shuffle(indices)
-
-        return iter(indices)
-
-    def __len__(self) -> int:
-        """
-        Returns:
-            length of result sample
-        """
-        return self.length
-
-
 class BalanceBatchSampler(Sampler):
     """
     This kind of sampler can be used for both metric learning and
     classification task.
 
+    .. warning::
+        Deprecated implementation, kept for backward compatibility.
+        Please use `BatchBalanceClassSampler` instead.
+
     Sampler with the given strategy for the C unique classes dataset:
     - Selection P of C classes for the 1st batch
     - Selection K instances for each class for the 1st batch
@@ -224,9 +132,104 @@ def __iter__(self) -> Iterator[int]:
         return iter(inds)
 
 
+class BalanceClassSampler(Sampler):
+    """Allows you to create a stratified sample on unbalanced classes.
+
+    Args:
+        labels: list of class labels for each element in the dataset
+        mode: Strategy to balance classes.
+            Must be one of [downsampling, upsampling]
+
+    Python API examples:
+
+    .. code-block:: python
+
+        import os
+        from torch import nn, optim
+        from torch.utils.data import DataLoader
+        from catalyst import dl
+        from catalyst.data import ToTensor, BalanceClassSampler
+        from catalyst.contrib.datasets import MNIST
+
+        train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
+        train_labels = train_data.targets.cpu().numpy().tolist()
+        train_sampler = BalanceClassSampler(train_labels, mode=5000)
+        valid_data = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())
+
+        loaders = {
+            "train": DataLoader(train_data, sampler=train_sampler, batch_size=32),
+            "valid": DataLoader(valid_data, batch_size=32),
+        }
+
+        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
+        criterion = nn.CrossEntropyLoss()
+        optimizer = optim.Adam(model.parameters(), lr=0.02)
+
+        runner = dl.SupervisedRunner()
+        # model training
+        runner.train(
+            model=model,
+            criterion=criterion,
+            optimizer=optimizer,
+            loaders=loaders,
+            num_epochs=1,
+            logdir="./logs",
+            valid_loader="valid",
+            valid_metric="loss",
+            minimize_valid_metric=True,
+            verbose=True,
+        )
+    """
+
+    def __init__(self, labels: List[int], mode: Union[str, int] = "downsampling"):
+        """Sampler initialisation."""
+        super().__init__(labels)
+
+        labels = np.array(labels)
+        samples_per_class = {label: (labels == label).sum() for label in set(labels)}
+
+        self.lbl2idx = {
+            label: np.arange(len(labels))[labels == label].tolist() for label in set(labels)
+        }
+
+        if isinstance(mode, str):
+            assert mode in ["downsampling", "upsampling"]
+
+        if isinstance(mode, int) or mode == "upsampling":
+            samples_per_class = mode if isinstance(mode, int) else max(samples_per_class.values())
+        else:
+            samples_per_class = min(samples_per_class.values())
+
+        self.labels = labels
+        self.samples_per_class = samples_per_class
+        self.length = self.samples_per_class * len(set(labels))
+
+    def __iter__(self) -> Iterator[int]:
+        """
+        Yields:
+            indices of stratified sample
+        """
+        indices = []
+        for key in sorted(self.lbl2idx):
+            replace_flag = self.samples_per_class > len(self.lbl2idx[key])
+            indices += np.random.choice(
+                self.lbl2idx[key], self.samples_per_class, replace=replace_flag
+            ).tolist()
+        assert len(indices) == self.length
+        np.random.shuffle(indices)
+
+        return iter(indices)
+
+    def __len__(self) -> int:
+        """
+        Returns:
+            length of result sample
+        """
+        return self.length
+
+
 class BatchBalanceClassSampler(Sampler):
     """
-    
BatchSampler version of BalanceBatchSampler. This kind of sampler can be used for both metric learning and classification task. BatchSampler with the given strategy for the C unique classes dataset: @@ -729,8 +732,8 @@ def __iter__(self) -> Iterator[int]: __all__ = [ - "BalanceClassSampler", "BalanceBatchSampler", + "BalanceClassSampler", "BatchBalanceClassSampler", "DistributedSamplerWrapper", "DynamicBalanceClassSampler", diff --git a/catalyst/data/sampler_inbatch.py b/catalyst/data/sampler_inbatch.py index 4db757a3e8..fd6ffb65a3 100644 --- a/catalyst/data/sampler_inbatch.py +++ b/catalyst/data/sampler_inbatch.py @@ -65,7 +65,7 @@ class InBatchTripletsSampler(IInbatchTripletSampler): The batches must contain at least 2 samples for each class and at least 2 different classes, such behaviour can be garantee via using - catalyst.data.sampler.BalanceBatchSampler + catalyst.data.sampler.BatchBalanceClassSampler But you are not limited to using it in any other way. """ diff --git a/catalyst/metrics/_cmc_score.py b/catalyst/metrics/_cmc_score.py index c603eb6341..d564f1eb44 100644 --- a/catalyst/metrics/_cmc_score.py +++ b/catalyst/metrics/_cmc_score.py @@ -74,10 +74,10 @@ class CMCMetric(AccumulativeMetric): train_dataset = datasets.MnistMLDataset( root=os.getcwd(), download=True, transform=transforms ) - sampler = data.BalanceBatchSampler(labels=train_dataset.get_labels(), p=5, k=10) - train_loader = DataLoader( - dataset=train_dataset, sampler=sampler, batch_size=sampler.batch_size + sampler = data.BatchBalanceClassSampler( + labels=train_dataset.get_labels(), num_classes=5, num_samples=10 ) + train_loader = DataLoader(dataset=train_dataset, batch_sampler=sampler) valid_dataset = datasets.MnistQGDataset( root=os.getcwd(), transform=transforms, gallery_fraq=0.2 diff --git a/docs/api/utils.rst b/docs/api/utils.rst index 09d7e6dbe8..71eb14e15a 100644 --- a/docs/api/utils.rst +++ b/docs/api/utils.rst @@ -141,6 +141,13 @@ Parallel :undoc-members: :show-inheritance: +Report +~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: catalyst.contrib.utils.report + :members: + :undoc-members: + :show-inheritance: + Serialization ~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: catalyst.contrib.utils.serialization diff --git a/examples/notebooks/customizing_what_happens_in_train.ipynb b/examples/notebooks/customizing_what_happens_in_train.ipynb index 6aa140eab3..403471541d 100644 --- a/examples/notebooks/customizing_what_happens_in_train.ipynb +++ b/examples/notebooks/customizing_what_happens_in_train.ipynb @@ -2,21 +2,17 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "poK9I4m9wSTW" - }, "source": [ "# Catalyst - customizing what happens in `train()`\n", "based on `Keras customizing what happens in fit`" - ] + ], + "metadata": { + "colab_type": "text", + "id": "poK9I4m9wSTW" + } }, { "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "NJb_eBKEwXS6" - }, "source": [ "## Introduction\n", "\n", @@ -29,21 +25,30 @@ "Note that this pattern does not prevent you from building models with the Functional API. You can do this with **any** PyTorch model.\n", "\n", "Let's see how that works." 
- ] + ], + "metadata": { + "colab_type": "text", + "id": "NJb_eBKEwXS6" + } }, { "cell_type": "markdown", + "source": [ + "## Setup" + ], "metadata": { "colab_type": "text", "id": "E6R34jh5xKkW" - }, - "source": [ - "## Setup" - ] + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "!pip install catalyst[ml]==21.8\n", + "# don't forget to restart runtime for correct `PIL` work with Colab" + ], + "outputs": [], "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -63,16 +68,17 @@ }, "id": "S1rkIIaKaG2O", "outputId": "5c4d2dc1-74a0-4b04-f3f5-9a08e8efc28c" - }, - "outputs": [], - "source": [ - "!pip install catalyst[ml]==21.4.2\n", - "# don't forget to restart runtime for correct `PIL` work with Colab" - ] + } }, { "cell_type": "code", "execution_count": null, + "source": [ + "import catalyst\n", + "from catalyst import dl, metrics, utils\n", + "catalyst.__version__" + ], + "outputs": [], "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -92,20 +98,10 @@ }, "id": "3eLP4fR6wCYc", "outputId": "1e36766e-6d62-46da-894f-8d2e4967544c" - }, - "outputs": [], - "source": [ - "import catalyst\n", - "from catalyst import dl, metrics, utils\n", - "catalyst.__version__" - ] + } }, { "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "5F8q4oByxt2T" - }, "source": [ "## A first simple example\n", "\n", @@ -120,17 +116,15 @@ "In the body of the `handle_batch` method, we implement a regular training update, similar to what you are already familiar with. Importantly, **we log batch-based metrics via `self.batch_metrics`**, which passes them to the loggers.\n", "\n", "Addiionally, we have to use [`AdditiveMetric`](https://catalyst-team.github.io/catalyst/api/metrics.html#AdditiveMetric) during `on_loader_start` and `on_loader_start` for correct metrics aggregation for the whole loader. Importantly, **we log loader-based metrics via `self.loader_metrics`**, which passes them to the loggers." 
- ] + ], + "metadata": { + "colab_type": "text", + "id": "5F8q4oByxt2T" + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "MbTkRLQUxQmC" - }, - "outputs": [], "source": [ "import torch\n", "from torch.nn import functional as F\n", @@ -171,42 +165,27 @@ " for key in [\"loss\", \"mae\"]:\n", " self.loader_metrics[key] = self.meters[key].compute()[0]\n", " super().on_loader_end(runner)" - ] + ], + "outputs": [], + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "MbTkRLQUxQmC" + } }, { "cell_type": "markdown", + "source": [ + "Let's try this out:" + ], "metadata": { "colab_type": "text", "id": "nAEiVP4IzNj-" - }, - "source": [ - "Let's try this out:" - ] + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 562 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 17386, - "status": "error", - "timestamp": 1587015544733, - "user": { - "displayName": "Sergey Kolesnikov", - "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYXcxiGiIDQYhW2wTkdrNLwx68llP5BzH91oGlAQ=s64", - "userId": "07081474162282073276" - }, - "user_tz": -180 - }, - "id": "AlUHnIG6zPV9", - "outputId": "bcc53cac-174d-4a1e-c3e8-441c102609cb" - }, - "outputs": [], "source": [ "import numpy as np\n", "import torch\n", @@ -234,25 +213,44 @@ " verbose=True, # you can pass True for more precise training process logging\n", " timeit=False, # you can pass True to measure execution time of different parts of train process\n", ")" - ] + ], + "outputs": [], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 562 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 17386, + "status": "error", + "timestamp": 1587015544733, + "user": { + "displayName": "Sergey Kolesnikov", + "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYXcxiGiIDQYhW2wTkdrNLwx68llP5BzH91oGlAQ=s64", + "userId": "07081474162282073276" + }, + "user_tz": -180 + }, + "id": "AlUHnIG6zPV9", + "outputId": "bcc53cac-174d-4a1e-c3e8-441c102609cb" + } }, { "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "NGJgVd9lzkQc" - }, "source": [ "## Going high-level\n", "\n", "Naturally, you could skip a loss function backward in `handle_batch()`, and instead do everything with `Callbacks` in `train` params. Likewise for metrics. Here's a high-level example, that only uses `handle_batch()` for model forward pass and metrics computation:" - ] + ], + "metadata": { + "colab_type": "text", + "id": "NGJgVd9lzkQc" + } }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "import numpy as np\n", "import torch\n", @@ -329,14 +327,12 @@ " )\n", " }\n", ")" - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "vVZtc6P61icn" - }, "source": [ "## Metrics support through Callbacks\n", "\n", @@ -347,13 +343,15 @@ "- Add extra callbacks, that will use data from `runner.batch` during training.\n", "\n", "That's it. That's the list. 
Let's see the example:" - ] + ], + "metadata": { + "colab_type": "text", + "id": "vVZtc6P61icn" + } }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "import numpy as np\n", "import torch\n", @@ -426,24 +424,24 @@ " )\n", " }\n", ")" - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "markdown", - "metadata": {}, "source": [ "## Simplify it a bit - SupervisedRunner\n", "\n", "But can we simplify last example a bit?
\n", "What if we know, that we are going to train `supervised` model, that will take some `features` in and output some `logits` back?
\n", "Looks like commom case... could we automate it? Let's check it out!" - ] + ], + "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "import numpy as np\n", "import torch\n", @@ -508,27 +506,27 @@ "# )\n", " }\n", ")" - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "rn1q6NCP2dtR" - }, "source": [ "## Providing your own inference step\n", "\n", "But let's return to the basics.\n", "\n", "What if you want to do the same customization for calls to `runner.predict_*()`? Then you would override `predict_batch` in exactly the same way. Here's what it looks like:" - ] + ], + "metadata": { + "colab_type": "text", + "id": "rn1q6NCP2dtR" + } }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "import torch\n", "from torch.nn import functional as F\n", @@ -573,13 +571,13 @@ " for key in [\"loss\", \"mae\"]:\n", " self.loader_metrics[key] = self.meters[key].compute()[0]\n", " super().on_loader_end(runner)" - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "import numpy as np\n", "import torch\n", @@ -616,22 +614,22 @@ "# or `loader` prediction\n", "for prediction in runner.predict_loader(loader=loader):\n", " assert prediction.detach().cpu().numpy().shape[-1] == 1 # as we have 1-class regression" - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "markdown", - "metadata": {}, "source": [ "Finally, after model training and evaluation, it's time to prepare it for deployment. PyTorch upport model tracing for production-friendly Deep Leanring models deployment.\n", "\n", "Could we make it quick with Catalyst? Sure!" 
- ] + ], + "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "features_batch = next(iter(loaders[\"valid\"]))[0].to(runner.device)\n", "# model stochastic weight averaging\n", @@ -644,14 +642,12 @@ "utils.prune_model(model=runner.model, pruning_fn=\"l1_unstructured\", amount=0.8)\n", "# onnx export, catalyst[onnx] or catalyst[onnx-gpu] required\n", "# utils.onnx_export(model=runner.model, batch=features_batch, file=\"./logs/mnist.onnx\", verbose=True)" - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "9-CMpP5a3Wcp" - }, "source": [ "## Wrapping up: an end-to-end GAN example\n", "\n", @@ -662,13 +658,15 @@ "- A generator network meant to generate 28x28x1 images.\n", "- A discriminator network meant to classify 28x28x1 images into two classes (\"fake\" - 1 and \"real\" - 0).\n", "\n" - ] + ], + "metadata": { + "colab_type": "text", + "id": "9-CMpP5a3Wcp" + } }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "import torch\n", "from torch import nn\n", @@ -708,27 +706,23 @@ " \"generator\": torch.optim.Adam(generator.parameters(), lr=0.0003, betas=(0.5, 0.999)),\n", " \"discriminator\": torch.optim.Adam(discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999)),\n", "}" - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "markdown", + "source": [ + "Here's a feature-complete `GANRunner`, overriding `predict_batch()` to use its own signature, and implementing the entire GAN algorithm in 16 lines in `handle_batch`:" + ], "metadata": { "colab_type": "text", "id": "POY42XRf5Jbd" - }, - "source": [ - "Here's a feature-complete `GANRunner`, overriding `predict_batch()` to use its own signature, and implementing the entire GAN algorithm in 16 lines in `handle_batch`:" - ] + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "iyKOtjfn5RL3" - }, - "outputs": [], "source": [ "class GANRunner(dl.Runner):\n", " \n", @@ -780,25 +774,27 @@ " \"generated_predictions\": generated_predictions,\n", " \"misleading_labels\": misleading_labels,\n", " }" - ] + ], + "outputs": [], + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "iyKOtjfn5RL3" + } }, { "cell_type": "markdown", + "source": [ + "Let's test-drive it:" + ], "metadata": { "colab_type": "text", "id": "zYGZRIJh6ZYu" - }, - "source": [ - "Let's test-drive it:" - ] + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, - "outputs": [], "source": [ "import os\n", "from torch.utils.data import DataLoader\n", @@ -848,23 +844,25 @@ " verbose=True,\n", " logdir=\"./logs_gan\",\n", ")" - ] + ], + "outputs": [], + "metadata": { + "scrolled": false + } }, { "cell_type": "markdown", + "source": [ + "The idea behind deep learning are simple, so why should their implementation be painful?" + ], "metadata": { "colab_type": "text", "id": "M9Fz5_u68FqW" - }, - "source": [ - "The idea behind deep learning are simple, so why should their implementation be painful?" 
- ] + } }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", @@ -872,24 +870,26 @@ "utils.set_global_seed(42)\n", "generated_image = runner.predict_batch(None)\n", "plt.imshow(generated_image[0, 0].detach().cpu().numpy())" - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [], "source": [ "%load_ext tensorboard\n", "%tensorboard --logdir ./logs_gan" - ] + ], + "outputs": [], + "metadata": {} }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "source": [], "outputs": [], - "source": [] + "metadata": {} } ], "metadata": { @@ -928,4 +928,4 @@ }, "nbformat": 4, "nbformat_minor": 1 -} +} \ No newline at end of file diff --git a/examples/self_supervised/barlow_twins.py b/examples/self_supervised/barlow_twins.py index ac61002b40..fe03da601e 100644 --- a/examples/self_supervised/barlow_twins.py +++ b/examples/self_supervised/barlow_twins.py @@ -81,9 +81,7 @@ def forward(self, x): callbacks = [ dl.ControlFlowCallback( dl.CriterionCallback( - input_key="projection_left", - target_key="projection_right", - metric_key="loss", + input_key="projection_left", target_key="projection_right", metric_key="loss" ), loaders="train", ), @@ -99,9 +97,7 @@ def forward(self, x): dl.OptimizerCallback(metric_key="loss"), dl.ControlFlowCallback( dl.AccuracyCallback( - target_key="target", - input_key="sklearn_predict", - topk_args=(1, 3), + target_key="target", input_key="sklearn_predict", topk_args=(1, 3) ), loaders="valid", ), diff --git a/examples/self_supervised/common.py b/examples/self_supervised/common.py index 5f5bfdd2ec..9badc62284 100644 --- a/examples/self_supervised/common.py +++ b/examples/self_supervised/common.py @@ -15,34 +15,19 @@ def add_arguments(parser) -> None: parser: argparser like object """ parser.add_argument( - "--feature_dim", - default=128, - type=int, - help="Feature dim for latent vector", + "--feature_dim", default=128, type=int, help="Feature dim for latent vector" ) parser.add_argument( - "--temperature", - default=0.5, - type=float, - help="Temperature used in softmax", + "--temperature", default=0.5, type=float, help="Temperature used in softmax" ) parser.add_argument( - "--batch_size", - default=512, - type=int, - help="Number of images in each mini-batch", + "--batch_size", default=512, type=int, help="Number of images in each mini-batch" ) parser.add_argument( - "--epochs", - default=1000, - type=int, - help="Number of sweeps over the dataset to train", + "--epochs", default=1000, type=int, help="Number of sweeps over the dataset to train" ) parser.add_argument( - "--num_workers", - default=8, - type=float, - help="Number of workers to process a dataloader", + "--num_workers", default=8, type=float, help="Number of workers to process a dataloader" ) parser.add_argument( "--logdir", diff --git a/tests/catalyst/callbacks/test_metric.py b/tests/catalyst/callbacks/test_metric.py index 6dd3c499fa..dec742a3cc 100644 --- a/tests/catalyst/callbacks/test_metric.py +++ b/tests/catalyst/callbacks/test_metric.py @@ -217,10 +217,10 @@ def test_metric_learning_pipeline(): """ with TemporaryDirectory() as tmp_dir: dataset_train = datasets.MnistMLDataset(root=tmp_dir, download=True) - sampler = data.BalanceBatchSampler(labels=dataset_train.get_labels(), p=5, k=10) - train_loader = DataLoader( - dataset=dataset_train, sampler=sampler, batch_size=sampler.batch_size + sampler = 
data.BatchBalanceClassSampler( + labels=dataset_train.get_labels(), num_classes=3, num_samples=10, num_batches=10 ) + train_loader = DataLoader(dataset=dataset_train, batch_sampler=sampler, num_workers=0) dataset_val = datasets.MnistQGDataset(root=tmp_dir, transform=None, gallery_fraq=0.2) val_loader = DataLoader(dataset=dataset_val, batch_size=1024) @@ -273,10 +273,10 @@ def test_reid_pipeline(): transforms = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]) train_dataset = MnistMLDataset(root=os.getcwd(), download=True, transform=transforms) - sampler = data.BalanceBatchSampler(labels=train_dataset.get_labels(), p=5, k=10) - train_loader = DataLoader( - dataset=train_dataset, sampler=sampler, batch_size=sampler.batch_size + sampler = data.BatchBalanceClassSampler( + labels=train_dataset.get_labels(), num_classes=3, num_samples=10, num_batches=20 ) + train_loader = DataLoader(dataset=train_dataset, batch_sampler=sampler, num_workers=0) valid_dataset = MnistReIDQGDataset( root=os.getcwd(), transform=transforms, gallery_fraq=0.2 @@ -326,7 +326,7 @@ def test_reid_pipeline(): valid_loader="valid", valid_metric="cmc01", minimize_valid_metric=False, - num_epochs=6, + num_epochs=10, ) assert "cmc01" in runner.loader_metrics assert runner.loader_metrics["cmc01"] > 0.7 diff --git a/tests/pipelines/test_metric_learning.py b/tests/pipelines/test_metric_learning.py index 951df3a8cf..aa7a376e16 100644 --- a/tests/pipelines/test_metric_learning.py +++ b/tests/pipelines/test_metric_learning.py @@ -46,10 +46,10 @@ def train_experiment(device, engine=None): train_dataset = datasets.MnistMLDataset( root=os.getcwd(), download=True, transform=transforms ) - sampler = data.BalanceBatchSampler(labels=train_dataset.get_labels(), p=5, k=10) - train_loader = DataLoader( - dataset=train_dataset, sampler=sampler, batch_size=sampler.batch_size + sampler = data.BatchBalanceClassSampler( + labels=train_dataset.get_labels(), num_classes=5, num_samples=10, num_batches=10 ) + train_loader = DataLoader(dataset=train_dataset, batch_sampler=sampler) valid_dataset = datasets.MnistQGDataset( root=os.getcwd(), transform=transforms, gallery_fraq=0.2
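
Upgrade sketch for the `BalanceBatchSampler` -> `BatchBalanceClassSampler` migration applied throughout this patch. A minimal, illustrative example, assuming `catalyst.contrib.datasets.MnistMLDataset` and its `get_labels()` method as used in the tests above; `num_batches=10` mirrors the test values (the argument is omitted in several examples above):

.. code-block:: python

    import os

    from torch.utils.data import DataLoader

    from catalyst import data
    from catalyst.contrib import datasets

    train_dataset = datasets.MnistMLDataset(root=os.getcwd(), download=True)

    # before (removed): p classes x k samples per batch, passed via `sampler=`
    # sampler = data.BalanceBatchSampler(labels=train_dataset.get_labels(), p=5, k=10)
    # loader = DataLoader(train_dataset, sampler=sampler, batch_size=sampler.batch_size)

    # after: p -> num_classes, k -> num_samples; the sampler now yields whole batches
    # of num_classes * num_samples indices, so it is passed via `batch_sampler=`
    sampler = data.BatchBalanceClassSampler(
        labels=train_dataset.get_labels(), num_classes=5, num_samples=10, num_batches=10
    )
    train_loader = DataLoader(dataset=train_dataset, batch_sampler=sampler)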