From 55682a953a933f371a723056fb1059198accadc5 Mon Sep 17 00:00:00 2001 From: Nayef Ahmed Date: Mon, 21 Oct 2024 19:10:28 -0700 Subject: [PATCH] Horizontal Task Fusion for Metrics Computation (#2498) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2498 For each metric we iterate through all the tasks and run update in the baseline. With Torchrec we can run a fused update across all tasks as long as the inputs to the metrics are in the correct format. After this optimization we see that only a single update is called for each metric. The CPU wall time for metrics update goes from 95ms to 7ms and the GPU wall timereduces from 2.7ms to 0.6ms. See this doc for more details: https://docs.google.com/document/d/15ELwQ1mehjecYoJJxryWDXURBJMHiTWW8iWxK-I3Y-Q/edit Reviewed By: iamzainhuda Differential Revision: D64205895 fbshipit-source-id: a1a4728831b6b0c7b1603d38648af0f5e11c60eb --- torchrec/metrics/rec_metric.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/torchrec/metrics/rec_metric.py b/torchrec/metrics/rec_metric.py index 9ff4e52eb..67df6c29f 100644 --- a/torchrec/metrics/rec_metric.py +++ b/torchrec/metrics/rec_metric.py @@ -11,6 +11,7 @@ import abc import itertools +import logging import math from collections import defaultdict, deque from dataclasses import dataclass @@ -47,6 +48,7 @@ MetricPrefix, ) +logger: logging.Logger = logging.getLogger(__name__) RecModelOutput = Union[torch.Tensor, Dict[str, torch.Tensor]] @@ -520,6 +522,31 @@ def _update( ) -> None: with torch.no_grad(): if self._compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION: + task_names = [task.name for task in self._tasks] + + if not isinstance(predictions, torch.Tensor): + logger.info( + "Converting predictions to tensors for RecComputeMode.FUSED_TASKS_COMPUTATION" + ) + predictions = torch.stack( + [predictions[task_name] for task_name in task_names] + ) + + if not isinstance(labels, torch.Tensor): + logger.info( + "Converting labels to tensors for RecComputeMode.FUSED_TASKS_COMPUTATION" + ) + labels = torch.stack( + [labels[task_name] for task_name in task_names] + ) + if weights is not None and not isinstance(weights, torch.Tensor): + logger.info( + "Converting weights to tensors for RecComputeMode.FUSED_TASKS_COMPUTATION" + ) + weights = torch.stack( + [weights[task_name] for task_name in task_names] + ) + assert isinstance(predictions, torch.Tensor) and isinstance( labels, torch.Tensor )