-
-
Notifications
You must be signed in to change notification settings - Fork 390
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* r2_score added * catalyst-make-codestyle _r2_score.py * r2 score LoaderMetric API is added * r2 score renamed to r2 squared * functional r2 metric name fix to r2_squared * test for functional r2 squared is added * compute key-value fix * args order in update fixed * args order fix * r2squared import is added to functional metrics init * r2squared callback is added * r2squared callback is added to metrics callbacks init * r2squared metric is added to metrics init * tests for r2squared is added * regression test update * metrics docs update * codestyle fix * torch.square to torch.pow fix) * codestyle update * spaces codestyle fix * codestyle fix * Update _r2_squared.py Co-authored-by: Sergey Kolesnikov <[email protected]>
- Loading branch information
Showing
10 changed files
with
307 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from catalyst.callbacks.metric import LoaderMetricCallback | ||
from catalyst.metrics._r2_squared import R2Squared | ||
|
||
|
||
class R2SquaredCallback(LoaderMetricCallback): | ||
"""R2 Squared metric callback. | ||
Args: | ||
input_key: input key to use for r2squared calculation, specifies our ``y_true``. | ||
target_key: output key to use for r2squared calculation, specifies our ``y_pred``. | ||
prefix: metric prefix | ||
suffix: metric suffix | ||
Examples: | ||
.. code-block:: python | ||
import torch | ||
from torch.utils.data import DataLoader, TensorDataset | ||
from catalyst import dl | ||
# data | ||
num_samples, num_features = int(1e4), int(1e1) | ||
X, y = torch.rand(num_samples, num_features), torch.rand(num_samples) | ||
dataset = TensorDataset(X, y) | ||
loader = DataLoader(dataset, batch_size=32, num_workers=1) | ||
loaders = {"train": loader, "valid": loader} | ||
# model, criterion, optimizer, scheduler | ||
model = torch.nn.Linear(num_features, 1) | ||
criterion = torch.nn.MSELoss() | ||
optimizer = torch.optim.Adam(model.parameters()) | ||
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6]) | ||
# model training | ||
runner = dl.SupervisedRunner() | ||
runner.train( | ||
model=model, | ||
criterion=criterion, | ||
optimizer=optimizer, | ||
scheduler=scheduler, | ||
loaders=loaders, | ||
logdir="./logdir", | ||
valid_loader="valid", | ||
valid_metric="loss", | ||
minimize_valid_metric=True, | ||
num_epochs=8, | ||
verbose=True, | ||
callbacks=[ | ||
dl.R2SquaredCallback(input_key="logits", target_key="targets") | ||
] | ||
) | ||
.. note:: | ||
Please follow the `minimal examples`_ sections for more use cases. | ||
.. _`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples | ||
""" | ||
|
||
def __init__( | ||
self, | ||
input_key: str, | ||
target_key: str, | ||
prefix: str = None, | ||
suffix: str = None, | ||
): | ||
"""Init.""" | ||
super().__init__( | ||
metric=R2Squared(prefix=prefix, suffix=suffix), | ||
input_key=input_key, | ||
target_key=target_key, | ||
) | ||
|
||
|
||
__all__ = ["R2SquaredCallback"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from typing import Optional | ||
|
||
import torch | ||
|
||
from catalyst.metrics._metric import ICallbackLoaderMetric | ||
|
||
|
||
class R2Squared(ICallbackLoaderMetric): | ||
"""This metric accumulates r2 score along loader | ||
Args: | ||
compute_on_call: if True, allows compute metric's value on call | ||
prefix: metric prefix | ||
suffix: metric suffix | ||
""" | ||
|
||
def __init__( | ||
self, | ||
compute_on_call: bool = True, | ||
prefix: Optional[str] = None, | ||
suffix: Optional[str] = None, | ||
) -> None: | ||
"""Init R2Squared""" | ||
super().__init__(compute_on_call=compute_on_call, prefix=prefix, suffix=suffix) | ||
self.metric_name = f"{self.prefix}r2squared{self.suffix}" | ||
self.num_examples = 0 | ||
self.delta_sum = 0 | ||
self.y_sum = 0 | ||
self.y_sq_sum = 0 | ||
|
||
def reset(self, num_batches: int, num_samples: int) -> None: | ||
""" | ||
Reset metrics fields | ||
""" | ||
self.num_examples = 0 | ||
self.delta_sum = 0 | ||
self.y_sum = 0 | ||
self.y_sq_sum = 0 | ||
|
||
def update(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> None: | ||
""" | ||
Update accumulated data with new batch | ||
""" | ||
self.num_examples += len(y_true) | ||
self.delta_sum += torch.sum(torch.pow(y_pred - y_true, 2)) | ||
self.y_sum += torch.sum(y_true) | ||
self.y_sq_sum += torch.sum(torch.pow(y_true, 2)) | ||
|
||
def compute(self) -> torch.Tensor: | ||
""" | ||
Return accumulated metric | ||
""" | ||
return 1 - self.delta_sum / (self.y_sq_sum - (self.y_sum ** 2) / self.num_examples) | ||
|
||
def compute_key_value(self) -> torch.Tensor: | ||
""" | ||
Return key-value | ||
""" | ||
r2squared = self.compute() | ||
output = {self.metric_name: r2squared} | ||
return output | ||
|
||
|
||
__all__ = ["R2Squared"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from typing import Sequence | ||
|
||
import torch | ||
|
||
|
||
def r2_squared(outputs: torch.Tensor, targets: torch.Tensor) -> Sequence[torch.Tensor]: | ||
""" | ||
Computes regression r2 squared. | ||
Args: | ||
outputs: model outputs | ||
with shape [bs; 1] | ||
targets: ground truth | ||
with shape [bs; 1] | ||
Returns: | ||
float of computed r2 squared | ||
Examples: | ||
.. code-block:: python | ||
import torch | ||
from catalyst import metrics | ||
metrics.r2_squared( | ||
outputs=torch.tensor([0, 1, 2]), | ||
targets=torch.tensor([0, 1, 2]), | ||
) | ||
# tensor([1.]) | ||
.. code-block:: python | ||
import torch | ||
from catalyst import metrics | ||
metrics.r2_squared( | ||
outputs=torch.tensor([2.5, 0.0, 2, 8]), | ||
targets=torch.tensor([3, -0.5, 2, 7]), | ||
) | ||
# tensor([0.9486]) | ||
""" | ||
total_sum_of_squares = torch.sum( | ||
torch.pow(targets.float() - torch.mean(targets.float()), 2) | ||
).view(-1) | ||
residual_sum_of_squares = torch.sum(torch.pow(targets.float() - outputs.float(), 2)).view(-1) | ||
output = 1 - residual_sum_of_squares / total_sum_of_squares | ||
return output | ||
|
||
|
||
__all__ = ["r2_squared"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# flake8: noqa | ||
import numpy as np | ||
|
||
import torch | ||
|
||
from catalyst.metrics.functional._r2_squared import r2_squared | ||
|
||
|
||
def test_r2_squared(): | ||
""" | ||
Tests for catalyst.metrics.r2_squared metric. | ||
""" | ||
y_true = torch.tensor([3, -0.5, 2, 7]) | ||
y_pred = torch.tensor([2.5, 0.0, 2, 8]) | ||
val = r2_squared(y_pred, y_true) | ||
assert torch.isclose(val, torch.Tensor([0.9486])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# flake8: noqa | ||
from typing import Dict, Iterable, Union | ||
|
||
import pytest | ||
|
||
import torch | ||
|
||
from catalyst.metrics._r2_squared import R2Squared | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"outputs,targets,true_values", | ||
( | ||
( | ||
torch.Tensor([2.5, 0.0, 2, 8]), | ||
torch.Tensor([3, -0.5, 2, 7]), | ||
{ | ||
"r2squared": torch.Tensor([0.9486]), | ||
}, | ||
), | ||
), | ||
) | ||
def test_r2_squared( | ||
outputs: torch.Tensor, | ||
targets: torch.Tensor, | ||
true_values: Dict[str, torch.Tensor], | ||
) -> None: | ||
""" | ||
Test r2 squared metric | ||
Args: | ||
outputs: tensor of outputs | ||
targets: tensor of targets | ||
true_values: true metric values | ||
""" | ||
metric = R2Squared() | ||
metric.update(y_pred=outputs, y_true=targets) | ||
metrics = metric.compute_key_value() | ||
for key in true_values.keys(): | ||
assert torch.isclose(true_values[key], metrics[key]) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"outputs_list,targets_list,true_values", | ||
( | ||
( | ||
( | ||
torch.Tensor([2.5, 0.0, 2, 8]), | ||
torch.Tensor([2.5, 0.0, 2, 8]), | ||
torch.Tensor([2.5, 0.0, 2, 8]), | ||
torch.Tensor([2.5, 0.0, 2, 8]), | ||
), | ||
( | ||
torch.Tensor([3, -0.5, 2, 7]), | ||
torch.Tensor([3, -0.5, 2, 7]), | ||
torch.Tensor([3, -0.5, 2, 7]), | ||
torch.Tensor([3, -0.5, 2, 7]), | ||
), | ||
{ | ||
"r2squared": torch.Tensor([0.9486]), | ||
}, | ||
), | ||
), | ||
) | ||
def test_r2_squared_update( | ||
outputs_list: Iterable[torch.Tensor], | ||
targets_list: Iterable[torch.Tensor], | ||
true_values: Dict[str, torch.Tensor], | ||
): | ||
""" | ||
Test r2 squared metric computation | ||
Args: | ||
outputs_list: list of outputs | ||
targets_list: list of targets | ||
true_values: true metric values | ||
""" | ||
metric = R2Squared() | ||
for outputs, targets in zip(outputs_list, targets_list): | ||
metric.update(y_pred=outputs, y_true=targets) | ||
metrics = metric.compute_key_value() | ||
for key in true_values.keys(): | ||
assert torch.isclose(true_values[key], metrics[key]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters