2024-11-06 nightly release (509b0d2)
pytorchbot committed Nov 6, 2024
1 parent 5e46ad6 commit ccfcc94
Showing 7 changed files with 64 additions and 8 deletions.
14 changes: 9 additions & 5 deletions torchrec/distributed/planner/stats.py
@@ -16,6 +16,7 @@

from torch import nn

from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner.constants import BIGINT_DTYPE, NUM_POOLINGS
from torchrec.distributed.planner.shard_estimators import _calculate_shard_io_sizes
from torchrec.distributed.planner.storage_reservations import (
@@ -421,11 +422,14 @@ def log(
if hasattr(sharder, "fused_params") and sharder.fused_params
else None
)
cache_load_factor = str(
so.cache_load_factor
if so.cache_load_factor is not None
else sharder_cache_load_factor
)
cache_load_factor = "None"
# Surfacing cache load factor does not make sense if not using uvm caching.
if so.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING.value:
cache_load_factor = str(
so.cache_load_factor
if so.cache_load_factor is not None
else sharder_cache_load_factor
)
hash_size = so.tensor.shape[0]
param_table.append(
[
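Net effect of the stats.py change, extracted into a standalone helper for clarity (the helper name and its `so` / `sharder_cache_load_factor` arguments are taken from the hunk above; the function itself is only an illustrative sketch, not part of the commit): a cache load factor is surfaced only when the sharding option runs on the FUSED_UVM_CACHING kernel, otherwise the stats column reads "None".

```python
from torchrec.distributed.embedding_types import EmbeddingComputeKernel


def surfaced_cache_load_factor(so, sharder_cache_load_factor) -> str:
    # Only FUSED_UVM_CACHING kernels maintain a UVM cache, so the stats table
    # reports "None" for every other compute kernel.
    if so.compute_kernel != EmbeddingComputeKernel.FUSED_UVM_CACHING.value:
        return "None"
    # Prefer the per-sharding-option value, falling back to the sharder-level
    # value pulled from fused_params.
    return str(
        so.cache_load_factor
        if so.cache_load_factor is not None
        else sharder_cache_load_factor
    )
```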
@@ -32,7 +32,7 @@
#include "torchrec/inference/BatchingQueue.h"
#include "torchrec/inference/Observer.h"
#include "torchrec/inference/ResultSplit.h"
#include "torchrec/inference/include/torchrec/inference/Observer.h"
#include "torchrec/inference/include/torchrec/inference/Observer.h" // @manual

namespace torchrec {

2 changes: 1 addition & 1 deletion torchrec/inference/inference_legacy/src/GPUExecutor.cpp
@@ -25,7 +25,7 @@
#include <folly/stop_watch.h>
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/profiler.h> // @manual

// remove this after we switch over to multipy externally for torchrec
#ifdef FBCODE_CAFFE2
2 changes: 2 additions & 0 deletions torchrec/inference/modules.py
@@ -488,6 +488,7 @@ def shard_quant_model(
sharders: Optional[List[ModuleSharder[torch.nn.Module]]] = None,
device_memory_size: Optional[int] = None,
constraints: Optional[Dict[str, ParameterConstraints]] = None,
ddr_cap: Optional[int] = None,
) -> Tuple[torch.nn.Module, ShardingPlan]:
"""
Shard a quantized TorchRec model, used for generating the most optimal model for inference and
@@ -557,6 +558,7 @@ def shard_quant_model(
compute_device=compute_device,
local_world_size=world_size,
hbm_cap=hbm_cap,
ddr_cap=ddr_cap,
)
batch_size = 1
model_plan = trec_dist.planner.EmbeddingShardingPlanner(
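The modules.py change passes the new `ddr_cap` argument through to the planner's Topology (`ddr_cap=ddr_cap` above), letting callers bound host DDR memory in addition to device HBM when planning an inference sharding. A minimal sketch of a call site, assuming only the parameters visible in this hunk; the `quant_model` handle, the positional placement, and the concrete budget are illustrative, not taken from the commit:

```python
import torch

from torchrec.inference.modules import shard_quant_model


def plan_for_inference(quant_model: torch.nn.Module):
    # ddr_cap (bytes) is the knob added by this commit; it flows into
    # Topology(ddr_cap=...) inside shard_quant_model.
    sharded_model, sharding_plan = shard_quant_model(
        quant_model,
        sharders=None,           # fall back to the default quantized sharders
        constraints=None,        # no per-table ParameterConstraints
        ddr_cap=256 * 1024**3,   # illustrative ~256 GiB host-memory budget
    )
    return sharded_model, sharding_plan
```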
23 changes: 22 additions & 1 deletion torchrec/modules/fp_embedding_modules.py
@@ -7,7 +7,7 @@

# pyre-strict

from typing import Dict, List, Set, Union
from typing import Dict, List, Set, Tuple, Union

import torch
import torch.nn as nn
@@ -55,6 +55,15 @@ def apply_feature_processors_to_kjt(
)


class FeatureProcessorDictWrapper(FeatureProcessorsCollection):
def __init__(self, feature_processors: nn.ModuleDict) -> None:
super().__init__()
self._feature_processors = feature_processors

def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor:
return apply_feature_processors_to_kjt(features, self._feature_processors)


class FeatureProcessedEmbeddingBagCollection(nn.Module):
"""
FeatureProcessedEmbeddingBagCollection represents an EmbeddingBagCollection module and a set of feature processor modules.
@@ -125,6 +134,18 @@ def __init__(
feature_names_set.update(table_config.feature_names)
self._feature_names: List[str] = list(feature_names_set)

def split(
self,
) -> Tuple[FeatureProcessorsCollection, EmbeddingBagCollection]:
if isinstance(self._feature_processors, nn.ModuleDict):
return (
FeatureProcessorDictWrapper(self._feature_processors),
self._embedding_bag_collection,
)
else:
assert isinstance(self._feature_processors, FeatureProcessorsCollection)
return self._feature_processors, self._embedding_bag_collection

def forward(
self,
features: KeyedJaggedTensor,
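The new `split()` method (together with the `FeatureProcessorDictWrapper` added above for ModuleDict-based processors) exposes the two stages of a FeatureProcessedEmbeddingBagCollection as separate modules. A short usage sketch mirroring the tests below; `fp_ebc` and `features` stand for a constructed FeatureProcessedEmbeddingBagCollection and an input KeyedJaggedTensor:

```python
# Split into the feature-processor stage and the embedding-lookup stage.
fp, ebc = fp_ebc.split()

# Stage 1: run the feature processors (e.g. position weighting) over the KJT.
weighted_features = fp(features)

# Stage 2: plain pooled embedding lookup on the processed features.
pooled_embeddings = ebc(weighted_features)
```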
18 changes: 18 additions & 0 deletions torchrec/modules/tests/test_fp_embedding_modules.py
@@ -95,6 +95,15 @@ def test_position_weighted_module_ebc_with_excessive_features(self) -> None:
self.assertEqual(pooled_embeddings.values().size(), (3, 16))
self.assertEqual(pooled_embeddings.offset_per_key(), [0, 8, 16])

# Test split method, FP then EBC
fp, ebc = fp_ebc.split()
fp_kjt = fp(features)
pooled_embeddings_split = ebc(fp_kjt)

self.assertEqual(pooled_embeddings_split.keys(), ["f1", "f2"])
self.assertEqual(pooled_embeddings_split.values().size(), (3, 16))
self.assertEqual(pooled_embeddings_split.offset_per_key(), [0, 8, 16])


class PositionWeightedModuleCollectionEmbeddingBagCollectionTest(unittest.TestCase):
def generate_fp_ebc(self) -> FeatureProcessedEmbeddingBagCollection:
@@ -144,3 +153,12 @@ def test_position_weighted_collection_module_ebc(self) -> None:
pooled_embeddings_gm_script.offset_per_key(),
pooled_embeddings.offset_per_key(),
)

# Test split method, FP then EBC
fp, ebc = fp_ebc.split()
fp_kjt = fp(features)
pooled_embeddings_split = ebc(fp_kjt)

self.assertEqual(pooled_embeddings_split.keys(), ["f1", "f2"])
self.assertEqual(pooled_embeddings_split.values().size(), (3, 16))
self.assertEqual(pooled_embeddings_split.offset_per_key(), [0, 8, 16])
11 changes: 11 additions & 0 deletions torchrec/modules/utils.py
@@ -48,6 +48,17 @@ def _slice_1d_tensor(tensor: torch.Tensor, start: int, end: int) -> torch.Tensor
return tensor[start:end]


# PLEASE DO NOT USE THIS FUNCTION, THIS FUNCTION IS FOR BACKWARD COMPATIBILITY ONLY
# USE THE ONE IN torchrec/quant/embedding_modules.py
# TODO(@shuaoxiong): remove this function after we make sure all models switch to the new reference
@torch.fx.wrap
def _get_unflattened_lengths(lengths: torch.Tensor, num_features: int) -> torch.Tensor:
"""
Unflatten lengths tensor from [F * B] to [F, B].
"""
return lengths.view(num_features, -1)


def extract_module_or_tensor_callable(
module_or_callable: Union[
Callable[[], torch.nn.Module],
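For reference, the backward-compatibility helper `_get_unflattened_lengths` added above is a single view call: KJT lengths stored flat with shape [F * B] are reshaped to [F, B]. A tiny worked example (the 2-feature, batch-size-3 numbers are arbitrary):

```python
import torch

# 2 features, batch size 3 -> lengths arrive flattened with shape [6].
flat_lengths = torch.tensor([1, 0, 2, 3, 1, 1])

# Equivalent to _get_unflattened_lengths(flat_lengths, num_features=2).
unflattened = flat_lengths.view(2, -1)

assert unflattened.shape == (2, 3)
assert unflattened.tolist() == [[1, 0, 2], [3, 1, 1]]
```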
