From fae26f474d9c65acc9bf5474069044dd8649e5b8 Mon Sep 17 00:00:00 2001
From: Shabab Ayub
Date: Wed, 16 Aug 2023 11:57:13 -0700
Subject: [PATCH] Register output dist for inference modules and add set_device helper (#1323)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/1323

- Add `set_device` to EmbeddingsAllToOne and SequenceEmbeddingsAllToOne (the inference output dist modules), so the device on these modules can be set from an external runtime
- Register the output dists of the torchrec inference sharded modules so they can actually be found among the named_modules of the top-level module

Reviewed By: sayitmemory

Differential Revision: D48249155

fbshipit-source-id: cfaccb4ceaa9909a58e94632496643c17f9dd0ea
---
 torchrec/distributed/dist_data.py                | 12 ++++++++++++
 torchrec/distributed/quant_embedding.py          |  4 +++-
 torchrec/distributed/quant_embeddingbag.py       |  5 ++++-
 .../distributed/sharding/tw_sequence_sharding.py |  2 +-
 torchrec/distributed/sharding/tw_sharding.py     |  2 +-
 5 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py
index 81be979d4..a2ec1e6d7 100644
--- a/torchrec/distributed/dist_data.py
+++ b/torchrec/distributed/dist_data.py
@@ -714,6 +714,12 @@ def __init__(
         self._world_size = world_size
         self._cat_dim = cat_dim
 
+    # This method can be used by an inference runtime to update the
+    # device information for this module.
+    @torch.jit.export
+    def set_device(self, device_str: str) -> None:
+        self._device = torch.device(device_str)
+
     def forward(self, tensors: List[torch.Tensor]) -> torch.Tensor:
         """
         Performs AlltoOne operation on pooled/sequence embeddings tensors.
@@ -762,6 +768,12 @@ def __init__(
         self._device = device
         self._world_size = world_size
 
+    # This method can be used by an inference runtime to update the
+    # device information for this module.
+    @torch.jit.export
+    def set_device(self, device_str: str) -> None:
+        self._device = torch.device(device_str)
+
     def forward(self, tensors: List[torch.Tensor]) -> List[torch.Tensor]:
         """
         Performs AlltoOne operation on pooled embeddings tensors.
diff --git a/torchrec/distributed/quant_embedding.py b/torchrec/distributed/quant_embedding.py
index 59987430a..bb2dcb2da 100644
--- a/torchrec/distributed/quant_embedding.py
+++ b/torchrec/distributed/quant_embedding.py
@@ -389,7 +389,9 @@ def __init__(
         self._input_dists: List[nn.Module] = []
         self._lookups: List[nn.Module] = []
         self._create_lookups(fused_params, device)
-        self._output_dists: List[nn.Module] = []
+
+        # Ensure the output dist is registered for post-processing from an inference runtime (i.e. setting the device from the runtime).
+        self._output_dists: torch.nn.ModuleList = torch.nn.ModuleList()
         self._feature_splits: List[int] = []
         self._features_order: List[int] = []
diff --git a/torchrec/distributed/quant_embeddingbag.py b/torchrec/distributed/quant_embeddingbag.py
index 011bf8a16..b8341aa95 100644
--- a/torchrec/distributed/quant_embeddingbag.py
+++ b/torchrec/distributed/quant_embeddingbag.py
@@ -124,7 +124,10 @@ def __init__(
         self._input_dists: List[nn.Module] = []
         self._lookups: List[nn.Module] = []
         self._create_lookups(fused_params, device)
-        self._output_dists: List[nn.Module] = []
+
+        # Ensure the output dist is registered for post-processing from an inference runtime (i.e. setting the device from the runtime).
+        self._output_dists: torch.nn.ModuleList = torch.nn.ModuleList()
+
         self._embedding_names: List[str] = []
         self._embedding_dims: List[int] = []
         self._feature_splits: List[int] = []
diff --git a/torchrec/distributed/sharding/tw_sequence_sharding.py b/torchrec/distributed/sharding/tw_sequence_sharding.py
index 9098af49a..e4498d604 100644
--- a/torchrec/distributed/sharding/tw_sequence_sharding.py
+++ b/torchrec/distributed/sharding/tw_sequence_sharding.py
@@ -183,7 +183,7 @@ def forward(
         Returns:
             Awaitable[torch.Tensor]: awaitable of sequence embeddings.
         """
-        return self._dist.forward(local_embs)
+        return self._dist(local_embs)
 
 
 class InferTwSequenceEmbeddingSharding(
diff --git a/torchrec/distributed/sharding/tw_sharding.py b/torchrec/distributed/sharding/tw_sharding.py
index 8799b9647..9fa158e47 100644
--- a/torchrec/distributed/sharding/tw_sharding.py
+++ b/torchrec/distributed/sharding/tw_sharding.py
@@ -419,7 +419,7 @@ def forward(
             Awaitable[torch.Tensor]: awaitable of merged pooled embedding
             tensor.
         """
-        return self._dist.forward(local_embs)
+        return self._dist(local_embs)
 
 
 class InferTwEmbeddingSharding(
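Usage sketch, not part of the patch: a minimal example of how an external inference runtime might use the two changes together, assuming `model` is the top-level sharded inference module. The helper name `retarget_output_dists` and the target device string are hypothetical; EmbeddingsAllToOne, SequenceEmbeddingsAllToOne, and set_device come from this diff.

    # Because the output dists are now held in a torch.nn.ModuleList, they
    # show up in named_modules() of the top-level module, and each one
    # exposes the scripted set_device helper added in dist_data.py.
    import torch

    from torchrec.distributed.dist_data import (
        EmbeddingsAllToOne,
        SequenceEmbeddingsAllToOne,
    )

    def retarget_output_dists(model: torch.nn.Module, device_str: str) -> None:
        # Hypothetical helper: walk the module tree and point every
        # inference output dist at the new device.
        for _name, module in model.named_modules():
            if isinstance(module, (EmbeddingsAllToOne, SequenceEmbeddingsAllToOne)):
                module.set_device(device_str)

    # e.g. after placing the model on a different GPU in the runtime:
    # retarget_output_dists(model, "cuda:1")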