Skip to content

Commit

Permalink
Horizontal Task Fusion for Metrics Computation (#2498)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2498

For each metric we iterate through all the tasks and run update in the baseline. With Torchrec we can run a fused update across all tasks as long as the inputs to the metrics are in the correct format. After this optimization we see that only a single update is called for each metric.

The CPU wall time for metrics update goes from 95ms to 7ms and the GPU wall timereduces from 2.7ms to 0.6ms.

See this doc for more details:
https://docs.google.com/document/d/15ELwQ1mehjecYoJJxryWDXURBJMHiTWW8iWxK-I3Y-Q/edit

Reviewed By: iamzainhuda

Differential Revision: D64205895

fbshipit-source-id: a1a4728831b6b0c7b1603d38648af0f5e11c60eb
  • Loading branch information
Nayef Ahmed authored and facebook-github-bot committed Oct 22, 2024
1 parent 3e8de05 commit 55682a9
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions torchrec/metrics/rec_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import abc
import itertools
import logging
import math
from collections import defaultdict, deque
from dataclasses import dataclass
Expand Down Expand Up @@ -47,6 +48,7 @@
MetricPrefix,
)

logger: logging.Logger = logging.getLogger(__name__)

RecModelOutput = Union[torch.Tensor, Dict[str, torch.Tensor]]

Expand Down Expand Up @@ -520,6 +522,31 @@ def _update(
) -> None:
with torch.no_grad():
if self._compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION:
task_names = [task.name for task in self._tasks]

if not isinstance(predictions, torch.Tensor):
logger.info(
"Converting predictions to tensors for RecComputeMode.FUSED_TASKS_COMPUTATION"
)
predictions = torch.stack(
[predictions[task_name] for task_name in task_names]
)

if not isinstance(labels, torch.Tensor):
logger.info(
"Converting labels to tensors for RecComputeMode.FUSED_TASKS_COMPUTATION"
)
labels = torch.stack(
[labels[task_name] for task_name in task_names]
)
if weights is not None and not isinstance(weights, torch.Tensor):
logger.info(
"Converting weights to tensors for RecComputeMode.FUSED_TASKS_COMPUTATION"
)
weights = torch.stack(
[weights[task_name] for task_name in task_names]
)

assert isinstance(predictions, torch.Tensor) and isinstance(
labels, torch.Tensor
)
Expand Down

0 comments on commit 55682a9

Please sign in to comment.