diff --git a/torchrec/metrics/cali_free_ne.py b/torchrec/metrics/cali_free_ne.py new file mode 100644 index 000000000..82983f611 --- /dev/null +++ b/torchrec/metrics/cali_free_ne.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, cast, Dict, List, Optional, Type + +import torch +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, + RecMetricException, +) +from torchrec.pt2.utils import pt2_compile_callable + + +def compute_cross_entropy( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + eta: float, +) -> torch.Tensor: + predictions = predictions.double() + predictions.clamp_(min=eta, max=1 - eta) + cross_entropy = -weights * labels * torch.log2(predictions) - weights * ( + 1.0 - labels + ) * torch.log2(1.0 - predictions) + return cross_entropy + + +def _compute_cross_entropy_norm( + mean_label: torch.Tensor, + pos_labels: torch.Tensor, + neg_labels: torch.Tensor, + eta: float, +) -> torch.Tensor: + mean_label = mean_label.double() + mean_label.clamp_(min=eta, max=1 - eta) + return -pos_labels * torch.log2(mean_label) - neg_labels * torch.log2( + 1.0 - mean_label + ) + + +@torch.fx.wrap +def _compute_ne( + ce_sum: torch.Tensor, + weighted_num_samples: torch.Tensor, + pos_labels: torch.Tensor, + neg_labels: torch.Tensor, + eta: float, +) -> torch.Tensor: + # Goes into this block if all elements in weighted_num_samples > 0 + weighted_num_samples = weighted_num_samples.double().clamp(min=eta) + mean_label = pos_labels / weighted_num_samples + ce_norm = _compute_cross_entropy_norm(mean_label, pos_labels, neg_labels, eta) + return ce_sum / ce_norm + + +def compute_cali_free_ne( + ce_sum: torch.Tensor, + weighted_num_samples: torch.Tensor, + pos_labels: torch.Tensor, + neg_labels: torch.Tensor, + weighted_sum_predictions: torch.Tensor, + eta: float, + allow_missing_label_with_zero_weight: bool = False, +) -> torch.Tensor: + if allow_missing_label_with_zero_weight and not weighted_num_samples.all(): + # If nan were to occur, return a dummy value instead of nan if + # allow_missing_label_with_zero_weight is True + return torch.tensor([eta]) + raw_ne = _compute_ne( + ce_sum=ce_sum, + weighted_num_samples=weighted_num_samples, + pos_labels=pos_labels, + neg_labels=neg_labels, + eta=eta, + ) + return raw_ne / ( + -pos_labels * torch.log2(weighted_sum_predictions / weighted_num_samples) + - (weighted_num_samples - pos_labels) + * torch.log2(1 - (weighted_sum_predictions / weighted_num_samples)) + ) + + +def get_cali_free_ne_states( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + eta: float, +) -> Dict[str, torch.Tensor]: + cross_entropy = compute_cross_entropy( + labels, + predictions, + weights, + eta, + ) + return { + "cross_entropy_sum": torch.sum(cross_entropy, dim=-1), + "weighted_num_samples": torch.sum(weights, dim=-1), + "pos_labels": torch.sum(weights * labels, dim=-1), + "neg_labels": torch.sum(weights * (1.0 - labels), dim=-1), + "weighted_sum_predictions": torch.sum(weights * predictions, dim=-1), + } + + +class CaliFreeNEMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for CaliFree NE, i.e. Normalized Entropy. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. + + Args: + allow_missing_label_with_zero_weight (bool): allow missing label to have weight 0, instead of throwing exception. + """ + + def __init__( + self, + *args: Any, + allow_missing_label_with_zero_weight: bool = False, + **kwargs: Any, + ) -> None: + self._allow_missing_label_with_zero_weight: bool = ( + allow_missing_label_with_zero_weight + ) + super().__init__(*args, **kwargs) + self._add_state( + "cross_entropy_sum", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "weighted_num_samples", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "pos_labels", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "neg_labels", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "weighted_sum_predictions", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self.eta = 1e-12 + + @pt2_compile_callable + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + if predictions is None or weights is None: + raise RecMetricException( + "Inputs 'predictions' and 'weights' should not be None for CaliFreeNEMetricComputation update" + ) + states = get_cali_free_ne_states(labels, predictions, weights, self.eta) + num_samples = predictions.shape[-1] + + for state_name, state_value in states.items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + def _compute(self) -> List[MetricComputationReport]: + reports = [ + MetricComputationReport( + name=MetricName.CALI_FREE_NE, + metric_prefix=MetricPrefix.LIFETIME, + value=compute_cali_free_ne( + cast(torch.Tensor, self.cross_entropy_sum), + cast(torch.Tensor, self.weighted_num_samples), + cast(torch.Tensor, self.pos_labels), + cast(torch.Tensor, self.neg_labels), + cast(torch.Tensor, self.weighted_sum_predictions), + self.eta, + self._allow_missing_label_with_zero_weight, + ), + ), + MetricComputationReport( + name=MetricName.CALI_FREE_NE, + metric_prefix=MetricPrefix.WINDOW, + value=compute_cali_free_ne( + self.get_window_state("cross_entropy_sum"), + self.get_window_state("weighted_num_samples"), + self.get_window_state("pos_labels"), + self.get_window_state("neg_labels"), + self.get_window_state("weighted_sum_predictions"), + self.eta, + self._allow_missing_label_with_zero_weight, + ), + ), + ] + return reports + + +class CaliFreeNEMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.CALI_FREE_NE + _computation_class: Type[RecMetricComputation] = CaliFreeNEMetricComputation diff --git a/torchrec/metrics/metric_module.py b/torchrec/metrics/metric_module.py index d04a4aa5e..e40b7a3cd 100644 --- a/torchrec/metrics/metric_module.py +++ b/torchrec/metrics/metric_module.py @@ -21,6 +21,7 @@ from torchrec.metrics.accuracy import AccuracyMetric from torchrec.metrics.auc import AUCMetric from torchrec.metrics.auprc import AUPRCMetric +from torchrec.metrics.cali_free_ne import CaliFreeNEMetric from torchrec.metrics.calibration import CalibrationMetric from torchrec.metrics.ctr import CTRMetric from torchrec.metrics.mae import MAEMetric @@ -57,6 +58,7 @@ from torchrec.metrics.tensor_weighted_avg import TensorWeightedAvgMetric from torchrec.metrics.throughput import ThroughputMetric from torchrec.metrics.tower_qps import TowerQPSMetric +from torchrec.metrics.unweighted_ne import UnweightedNEMetric from torchrec.metrics.weighted_avg import WeightedAvgMetric from torchrec.metrics.xauc import XAUCMetric @@ -88,6 +90,8 @@ RecMetricEnum.SERVING_CALIBRATION: ServingCalibrationMetric, RecMetricEnum.OUTPUT: OutputMetric, RecMetricEnum.TENSOR_WEIGHTED_AVG: TensorWeightedAvgMetric, + RecMetricEnum.CALI_FREE_NE: CaliFreeNEMetric, + RecMetricEnum.UNWEIGHTED_NE: UnweightedNEMetric, } diff --git a/torchrec/metrics/metrics_config.py b/torchrec/metrics/metrics_config.py index 7ff5af552..fe905f5a2 100644 --- a/torchrec/metrics/metrics_config.py +++ b/torchrec/metrics/metrics_config.py @@ -44,6 +44,8 @@ class RecMetricEnum(RecMetricEnumBase): SERVING_CALIBRATION = "serving_calibration" OUTPUT = "output" TENSOR_WEIGHTED_AVG = "tensor_weighted_avg" + CALI_FREE_NE = "cali_free_ne" + UNWEIGHTED_NE = "unweighted_ne" @dataclass(unsafe_hash=True, eq=True) diff --git a/torchrec/metrics/metrics_namespace.py b/torchrec/metrics/metrics_namespace.py index 63386b481..21b05f587 100644 --- a/torchrec/metrics/metrics_namespace.py +++ b/torchrec/metrics/metrics_namespace.py @@ -76,6 +76,9 @@ class MetricName(MetricNameBase): SERVING_CALIBRATION = "serving_calibration" TENSOR_WEIGHTED_AVG = "tensor_weighted_avg" + CALI_FREE_NE = "cali_free_ne" + UNWEIGHTED_NE = "unweighted_ne" + class MetricNamespaceBase(StrValueMixin, Enum): pass @@ -120,6 +123,9 @@ class MetricNamespace(MetricNamespaceBase): OUTPUT = "output" TENSOR_WEIGHTED_AVG = "tensor_weighted_avg" + CALI_FREE_NE = "cali_free_ne" + UNWEIGHTED_NE = "unweighted_ne" + class MetricPrefix(StrValueMixin, Enum): DEFAULT = "" diff --git a/torchrec/metrics/tests/test_cali_free_ne.py b/torchrec/metrics/tests/test_cali_free_ne.py new file mode 100644 index 000000000..968f02677 --- /dev/null +++ b/torchrec/metrics/tests/test_cali_free_ne.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest +from typing import Dict, Type + +import torch +from torchrec.metrics.cali_free_ne import ( + CaliFreeNEMetric, + compute_cali_free_ne, + compute_cross_entropy, +) +from torchrec.metrics.rec_metric import RecComputeMode, RecMetric +from torchrec.metrics.test_utils import ( + metric_test_helper, + rec_metric_gpu_sync_test_launcher, + rec_metric_value_test_launcher, + sync_test_helper, + TestMetric, +) + + +WORLD_SIZE = 4 + + +class TestCaliFreeNEMetric(TestMetric): + eta: float = 1e-12 + + @staticmethod + def _get_states( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + ) -> Dict[str, torch.Tensor]: + cross_entropy = compute_cross_entropy( + labels, predictions, weights, TestCaliFreeNEMetric.eta + ) + cross_entropy_sum = torch.sum(cross_entropy) + weighted_num_samples = torch.sum(weights) + pos_labels = torch.sum(weights * labels) + neg_labels = torch.sum(weights * (1.0 - labels)) + weighted_sum_predictions = torch.sum(weights * predictions) + return { + "cross_entropy_sum": cross_entropy_sum, + "weighted_num_samples": weighted_num_samples, + "pos_labels": pos_labels, + "neg_labels": neg_labels, + "num_samples": torch.tensor(labels.size()).long(), + "weighted_sum_predictions": weighted_sum_predictions, + } + + @staticmethod + def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + allow_missing_label_with_zero_weight = False + if not states["weighted_num_samples"].all(): + allow_missing_label_with_zero_weight = True + + return compute_cali_free_ne( + states["cross_entropy_sum"], + states["weighted_num_samples"], + pos_labels=states["pos_labels"], + neg_labels=states["neg_labels"], + weighted_sum_predictions=states["weighted_sum_predictions"], + eta=TestCaliFreeNEMetric.eta, + allow_missing_label_with_zero_weight=allow_missing_label_with_zero_weight, + ) + + +class CaliFreeNEMetricTest(unittest.TestCase): + target_clazz: Type[RecMetric] = CaliFreeNEMetric + target_compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION + task_name: str = "cali_free_ne" + + def test_cali_free_ne_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=CaliFreeNEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestCaliFreeNEMetric, + metric_name=CaliFreeNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_cali_free_ne_fused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=CaliFreeNEMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + test_clazz=TestCaliFreeNEMetric, + metric_name=CaliFreeNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_cali_free_ne_update_fused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=CaliFreeNEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestCaliFreeNEMetric, + metric_name=CaliFreeNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=5, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + rec_metric_value_test_launcher( + target_clazz=CaliFreeNEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestCaliFreeNEMetric, + metric_name=CaliFreeNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=100, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + batch_window_size=10, + ) + + def test_cali_free_ne_zero_weights(self) -> None: + rec_metric_value_test_launcher( + target_clazz=CaliFreeNEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestCaliFreeNEMetric, + metric_name=CaliFreeNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + zero_weights=True, + ) + + +class CaliFreeNEGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = CaliFreeNEMetric + task_name: str = "cali_free_ne" + + def test_sync_cali_free_ne(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=CaliFreeNEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestCaliFreeNEMetric, + metric_name=CaliFreeNEGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, + ) diff --git a/torchrec/metrics/tests/test_unweighted_ne.py b/torchrec/metrics/tests/test_unweighted_ne.py new file mode 100644 index 000000000..d80c10ae6 --- /dev/null +++ b/torchrec/metrics/tests/test_unweighted_ne.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest +from typing import Dict, Type + +import torch +from torchrec.metrics.rec_metric import RecComputeMode, RecMetric +from torchrec.metrics.test_utils import ( + metric_test_helper, + rec_metric_gpu_sync_test_launcher, + rec_metric_value_test_launcher, + sync_test_helper, + TestMetric, +) +from torchrec.metrics.unweighted_ne import ( + compute_cross_entropy, + compute_ne, + UnweightedNEMetric, +) + + +WORLD_SIZE = 4 + + +class TestUnweightedNEMetric(TestMetric): + eta: float = 1e-12 + + @staticmethod + def _get_states( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor + ) -> Dict[str, torch.Tensor]: + # Override the weights to be all ones + weights = torch.ones_like(labels) + cross_entropy = compute_cross_entropy( + labels, predictions, weights, TestUnweightedNEMetric.eta + ) + cross_entropy_sum = torch.sum(cross_entropy) + weighted_num_samples = torch.sum(weights) + pos_labels = torch.sum(weights * labels) + neg_labels = torch.sum(weights * (1.0 - labels)) + return { + "cross_entropy_sum": cross_entropy_sum, + "weighted_num_samples": weighted_num_samples, + "pos_labels": pos_labels, + "neg_labels": neg_labels, + "num_samples": torch.tensor(labels.size()).long(), + } + + @staticmethod + def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + allow_missing_label_with_zero_weight = False + if not states["weighted_num_samples"].all(): + allow_missing_label_with_zero_weight = True + + return compute_ne( + states["cross_entropy_sum"], + states["weighted_num_samples"], + pos_labels=states["pos_labels"], + neg_labels=states["neg_labels"], + eta=TestUnweightedNEMetric.eta, + allow_missing_label_with_zero_weight=allow_missing_label_with_zero_weight, + ) + + +class UnweightedNEMetricTest(unittest.TestCase): + target_clazz: Type[RecMetric] = UnweightedNEMetric + target_compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION + task_name: str = "unweighted_ne" + + def test_unweighted_ne_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=UnweightedNEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestUnweightedNEMetric, + metric_name=UnweightedNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_unweighted_ne_fused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=UnweightedNEMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + test_clazz=TestUnweightedNEMetric, + metric_name=UnweightedNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_unweighted_ne_update_fused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=UnweightedNEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestUnweightedNEMetric, + metric_name=UnweightedNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=5, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + rec_metric_value_test_launcher( + target_clazz=UnweightedNEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestUnweightedNEMetric, + metric_name=UnweightedNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=100, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + batch_window_size=10, + ) + + def test_unweighted_ne_zero_weights(self) -> None: + rec_metric_value_test_launcher( + target_clazz=UnweightedNEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestUnweightedNEMetric, + metric_name=UnweightedNEMetricTest.task_name, + task_names=["t1", "t2", "t3"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + zero_weights=True, + ) + + +class UnweightedNEGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = UnweightedNEMetric + task_name: str = "unweighted_ne" + + def test_sync_unweighted_ne(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=UnweightedNEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestUnweightedNEMetric, + metric_name=UnweightedNEGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, + ) diff --git a/torchrec/metrics/unweighted_ne.py b/torchrec/metrics/unweighted_ne.py new file mode 100644 index 000000000..74f77ce9b --- /dev/null +++ b/torchrec/metrics/unweighted_ne.py @@ -0,0 +1,202 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, cast, Dict, List, Optional, Type + +import torch +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, + RecMetricException, +) +from torchrec.pt2.utils import pt2_compile_callable + + +def compute_cross_entropy( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + eta: float, +) -> torch.Tensor: + predictions = predictions.double() + predictions.clamp_(min=eta, max=1 - eta) + cross_entropy = -weights * labels * torch.log2(predictions) - weights * ( + 1.0 - labels + ) * torch.log2(1.0 - predictions) + return cross_entropy + + +def _compute_cross_entropy_norm( + mean_label: torch.Tensor, + pos_labels: torch.Tensor, + neg_labels: torch.Tensor, + eta: float, +) -> torch.Tensor: + mean_label = mean_label.double() + mean_label.clamp_(min=eta, max=1 - eta) + return -pos_labels * torch.log2(mean_label) - neg_labels * torch.log2( + 1.0 - mean_label + ) + + +@torch.fx.wrap +def compute_ne( + ce_sum: torch.Tensor, + weighted_num_samples: torch.Tensor, + pos_labels: torch.Tensor, + neg_labels: torch.Tensor, + eta: float, + allow_missing_label_with_zero_weight: bool = False, +) -> torch.Tensor: + if allow_missing_label_with_zero_weight and not weighted_num_samples.all(): + # If nan were to occur, return a dummy value instead of nan if + # allow_missing_label_with_zero_weight is True + return torch.tensor([eta]) + + # Goes into this block if all elements in weighted_num_samples > 0 + weighted_num_samples = weighted_num_samples.double().clamp(min=eta) + mean_label = pos_labels / weighted_num_samples + ce_norm = _compute_cross_entropy_norm(mean_label, pos_labels, neg_labels, eta) + return ce_sum / ce_norm + + +def get_unweighted_ne_states( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + eta: float, +) -> Dict[str, torch.Tensor]: + # Allow for unweighted NE computation by passing in a weights tensor of all ones + weights = torch.ones_like(labels) + cross_entropy = compute_cross_entropy( + labels, + predictions, + weights, + eta, + ) + return { + "cross_entropy_sum": torch.sum(cross_entropy, dim=-1), + "weighted_num_samples": torch.sum(weights, dim=-1), + "pos_labels": torch.sum(weights * labels, dim=-1), + "neg_labels": torch.sum(weights * (1.0 - labels), dim=-1), + } + + +class UnweightedNEMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for Unweighted NE, i.e. Normalized Entropy. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. + + Args: + allow_missing_label_with_zero_weight (bool): allow missing label to have weight 0, instead of throwing exception. + """ + + def __init__( + self, + *args: Any, + allow_missing_label_with_zero_weight: bool = False, + **kwargs: Any, + ) -> None: + self._allow_missing_label_with_zero_weight: bool = ( + allow_missing_label_with_zero_weight + ) + super().__init__(*args, **kwargs) + self._add_state( + "cross_entropy_sum", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "weighted_num_samples", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "pos_labels", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "neg_labels", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self.eta = 1e-12 + + @pt2_compile_callable + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + if predictions is None or weights is None: + raise RecMetricException( + "Inputs 'predictions' and 'weights' should not be None for UnweightedNEMetricComputation update. Weight will not be used for this metric." + ) + states = get_unweighted_ne_states( + labels, + predictions, + weights, + self.eta, + ) + num_samples = predictions.shape[-1] + + for state_name, state_value in states.items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + def _compute(self) -> List[MetricComputationReport]: + reports = [ + MetricComputationReport( + name=MetricName.UNWEIGHTED_NE, + metric_prefix=MetricPrefix.LIFETIME, + value=compute_ne( + cast(torch.Tensor, self.cross_entropy_sum), + cast(torch.Tensor, self.weighted_num_samples), + cast(torch.Tensor, self.pos_labels), + cast(torch.Tensor, self.neg_labels), + self.eta, + self._allow_missing_label_with_zero_weight, + ), + ), + MetricComputationReport( + name=MetricName.UNWEIGHTED_NE, + metric_prefix=MetricPrefix.WINDOW, + value=compute_ne( + self.get_window_state("cross_entropy_sum"), + self.get_window_state("weighted_num_samples"), + self.get_window_state("pos_labels"), + self.get_window_state("neg_labels"), + self.eta, + self._allow_missing_label_with_zero_weight, + ), + ), + ] + return reports + + +class UnweightedNEMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.UNWEIGHTED_NE + _computation_class: Type[RecMetricComputation] = UnweightedNEMetricComputation