Skip to content

Commit

Permalink
PT2 Compile Individual Metrics (#2517)
Browse files Browse the repository at this point in the history
Summary:

We find that PT2 compiling the TorchMetrics helper functions that are called during  `.update()` and `._compute()` for `XMetricComputation` classes improves both the CPU  wall time (95.4ms -> 65.4ms) and GPU wall time (3.3ms -> 1.2ms)

We use `MetricsConfig.enable_pt2_config` to set the `enable_pt2_config` attribute within a class. The `pt2_compile` decorator compiles methods within a class based on whether the `pt2_compile` method is set to True.


Without PT2 compile
 {F1941555284}

With PT2 Compile
{F1941555918}

Reviewed By: iamzainhuda

Differential Revision: D64244129
  • Loading branch information
Nayef Ahmed authored and facebook-github-bot committed Oct 24, 2024
1 parent 9669707 commit 2eec35a
Show file tree
Hide file tree
Showing 28 changed files with 88 additions and 1 deletion.
2 changes: 2 additions & 0 deletions torchrec/metrics/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable


THRESHOLD = "threshold"
Expand Down Expand Up @@ -84,6 +85,7 @@ def __init__(self, *args: Any, threshold: float = 0.5, **kwargs: Any) -> None:
)
self._threshold: float = threshold

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable

PREDICTIONS = "predictions"
LABELS = "labels"
Expand Down Expand Up @@ -243,6 +244,7 @@ def _init_states(self) -> None:
if self._grouped_auc:
getattr(self, GROUPING_KEYS).append(torch.tensor([-1], device=self.device))

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/auprc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable

PREDICTIONS = "predictions"
LABELS = "labels"
Expand Down Expand Up @@ -235,6 +236,7 @@ def _init_states(self) -> None:
if self._grouped_auprc:
getattr(self, GROUPING_KEYS).append(torch.tensor([-1], device=self.device))

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable

CALIBRATION_NUM = "calibration_num"
CALIBRATION_DENOM = "calibration_denom"
Expand Down Expand Up @@ -65,6 +66,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
persistent=True,
)

@pt2_compile_callable
def update(
self,
*,
Expand Down
3 changes: 3 additions & 0 deletions torchrec/metrics/ctr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
from typing import Any, cast, Dict, List, Optional, Type

import torch

from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
from torchrec.metrics.rec_metric import (
MetricComputationReport,
RecMetric,
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable

CTR_NUM = "ctr_num"
CTR_DENOM = "ctr_denom"
Expand Down Expand Up @@ -61,6 +63,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
persistent=True,
)

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable


ERROR_SUM = "error_sum"
Expand Down Expand Up @@ -72,6 +73,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
persistent=True,
)

@pt2_compile_callable
# pyre-fixme[14]: `update` overrides method defined in `RecMetricComputation`
# inconsistently.
def update(
Expand Down
4 changes: 3 additions & 1 deletion torchrec/metrics/metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,8 @@ def _generate_rec_metrics(
if metric_def and metric_def.arguments is not None:
kwargs = metric_def.arguments

kwargs["enable_pt2_compile"] = metrics_config.enable_pt2_compile

rec_tasks: List[RecTaskInfo] = []
if metric_def.rec_tasks and metric_def.rec_task_indices:
raise ValueError(
Expand Down Expand Up @@ -468,7 +470,7 @@ def generate_metric_module(
metrics_config, world_size, my_rank, batch_size, process_group
)
"""
Batch_size_stages currently only used by ThroughputMetric to ensure total_example correct so
Batch_size_stages currently only used by ThroughputMetric to ensure total_example correct so
different training jobs have aligned mertics.
TODO: update metrics other than ThroughputMetric if it has dependency on batch_size
"""
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/metrics_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ class MetricsConfig:
should_validate_update (bool): whether to check the inputs of update() and skip
update if the inputs are invalid. Invalid inputs include the case where all
examples have 0 weights for a batch.
enable_pt2_compile (bool): whether to enable PT2 compilation for metrics.
"""

rec_tasks: List[RecTaskInfo] = field(default_factory=list)
Expand All @@ -171,6 +172,7 @@ class MetricsConfig:
max_compute_interval: float = float("inf")
compute_on_all_ranks: bool = False
should_validate_update: bool = False
enable_pt2_compile: bool = False


DefaultTaskInfo = RecTaskInfo(
Expand Down
3 changes: 3 additions & 0 deletions torchrec/metrics/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
from typing import Any, cast, Dict, List, Optional, Type

import torch

from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
from torchrec.metrics.rec_metric import (
MetricComputationReport,
RecMetric,
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable


ERROR_SUM = "error_sum"
Expand Down Expand Up @@ -80,6 +82,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
persistent=True,
)

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/multiclass_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable


def compute_true_positives_at_k(
Expand Down Expand Up @@ -109,6 +110,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
persistent=True,
)

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/ndcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable

SUM_NDCG = "sum_ndcg"
NUM_SESSIONS = "num_sessions"
Expand Down Expand Up @@ -331,6 +332,7 @@ def __init__(
persistent=True,
)

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/ne.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable


def compute_cross_entropy(
Expand Down Expand Up @@ -148,6 +149,7 @@ def __init__(
)
self.eta = 1e-12

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/ne_positive.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable


def compute_cross_entropy_positive(
Expand Down Expand Up @@ -130,6 +131,7 @@ def __init__(
)
self.eta = 1e-12

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
RecMetricException,
RecTaskInfo,
)
from torchrec.pt2.utils import pt2_compile_callable


class OutputMetricComputation(RecMetricComputation):
Expand All @@ -46,6 +47,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
persistent=False,
)

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable


THRESHOLD = "threshold"
Expand Down Expand Up @@ -96,6 +97,7 @@ def __init__(self, *args: Any, threshold: float = 0.5, **kwargs: Any) -> None:
)
self._threshold: float = threshold

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/rauc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable

PREDICTIONS = "predictions"
LABELS = "labels"
Expand Down Expand Up @@ -287,6 +288,7 @@ def _init_states(self) -> None:
if self._grouped_rauc:
getattr(self, GROUPING_KEYS).append(torch.tensor([-1], device=self.device))

@pt2_compile_callable
def update(
self,
*,
Expand Down
4 changes: 4 additions & 0 deletions torchrec/metrics/rec_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
MetricNamespaceBase,
MetricPrefix,
)
from torchrec.pt2.utils import pt2_compile_callable


RecModelOutput = Union[torch.Tensor, Dict[str, torch.Tensor]]
Expand Down Expand Up @@ -136,6 +137,7 @@ def __init__(
process_group: Optional[dist.ProcessGroup] = None,
fused_update_limit: int = 0,
allow_missing_label_with_zero_weight: bool = False,
enable_pt2_compile: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
Expand All @@ -159,6 +161,7 @@ def __init__(
dist_reduce_fx=lambda x: torch.any(x, dim=0).byte(),
persistent=True,
)
self.enable_pt2_compile = enable_pt2_compile

@staticmethod
def get_window_state_name(state_name: str) -> str:
Expand Down Expand Up @@ -244,6 +247,7 @@ def pre_compute(self) -> None:
"""
return

@pt2_compile_callable
def compute(self) -> List[MetricComputationReport]:
with record_function(f"## {self.__class__.__name__}:compute ##"):
if self._my_rank == 0 or self._compute_on_all_ranks:
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable


THRESHOLD = "threshold"
Expand Down Expand Up @@ -96,6 +97,7 @@ def __init__(self, *args: Any, threshold: float = 0.5, **kwargs: Any) -> None:
)
self._threshold: float = threshold

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/recall_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable

logger: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -128,6 +129,7 @@ def __init__(
self.run_ranking_of_labels: bool = session_metric_def.run_ranking_of_labels
self.session_var_name: Optional[str] = session_metric_def.session_var_name

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RecMetric,
RecMetricComputation,
)
from torchrec.pt2.utils import pt2_compile_callable


class ScalarMetricComputation(RecMetricComputation):
Expand All @@ -41,6 +42,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
persistent=False,
)

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/segmented_ne.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable

PREDICTIONS = "predictions"
LABELS = "labels"
Expand Down Expand Up @@ -206,6 +207,7 @@ def __init__(
)
self.eta = 1e-12

@pt2_compile_callable
def update(
self,
*,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/serving_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable

CALIBRATION_NUM = "calibration_num"
CALIBRATION_DENOM = "calibration_denom"
Expand Down Expand Up @@ -57,6 +58,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
persistent=True,
)

@pt2_compile_callable
def update(
self,
*,
Expand Down
Loading

0 comments on commit 2eec35a

Please sign in to comment.