Skip to content

Commit

Permalink
Register output dist for inference modules and add set_device helper (p…
Browse files Browse the repository at this point in the history
…ytorch#1323)

Summary:
Pull Request resolved: pytorch#1323

- add `set_device` to EmbeddingsAllToOne and SequenceEmbeddingsAllToOne (the infernece output dist modules), so you have the option to set these modules from external runtime
- Register the output dist of torchrec inference sharded modules so you can actually find these modules from the named_modules of the top level module

Reviewed By: sayitmemory

Differential Revision: D48249155

fbshipit-source-id: c0fc405e46a12e80ddc17b91637b0130a2dc697a
  • Loading branch information
s4ayub authored and facebook-github-bot committed Aug 11, 2023
1 parent 9a6c5ee commit d274ac8
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 4 deletions.
12 changes: 12 additions & 0 deletions torchrec/distributed/dist_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,12 @@ def __init__(
self._world_size = world_size
self._cat_dim = cat_dim

# This method can be used by an inference runtime to update the
# device information for this module.
@torch.jit.export
def set_device(self, device_str: str) -> None:
self._device = torch.device(device_str)

def forward(self, tensors: List[torch.Tensor]) -> torch.Tensor:
"""
Performs AlltoOne operation on pooled/sequence embeddings tensors.
Expand Down Expand Up @@ -762,6 +768,12 @@ def __init__(
self._device = device
self._world_size = world_size

# This method can be used by an inference runtime to update the
# device information for this module.
@torch.jit.export
def set_device(self, device_str: str) -> None:
self._device = torch.device(device_str)

def forward(self, tensors: List[torch.Tensor]) -> List[torch.Tensor]:
"""
Performs AlltoOne operation on pooled embeddings tensors.
Expand Down
4 changes: 3 additions & 1 deletion torchrec/distributed/quant_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,9 @@ def __init__(
self._input_dists: List[nn.Module] = []
self._lookups: List[nn.Module] = []
self._create_lookups(fused_params, device)
self._output_dists: List[nn.Module] = []

# Ensure output dist is set for post processing from an inference runtime (ie. setting device from runtime).
self._output_dists: torch.nn.ModuleList = torch.nn.ModuleList()

self._feature_splits: List[int] = []
self._features_order: List[int] = []
Expand Down
5 changes: 4 additions & 1 deletion torchrec/distributed/quant_embeddingbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,10 @@ def __init__(
self._input_dists: List[nn.Module] = []
self._lookups: List[nn.Module] = []
self._create_lookups(fused_params, device)
self._output_dists: List[nn.Module] = []

# Ensure output dist is set for post processing from an inference runtime (ie. setting device from runtime).
self._output_dists: torch.nn.ModuleList = torch.nn.ModuleList()

self._embedding_names: List[str] = []
self._embedding_dims: List[int] = []
self._feature_splits: List[int] = []
Expand Down
2 changes: 1 addition & 1 deletion torchrec/distributed/sharding/tw_sequence_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def forward(
Returns:
Awaitable[torch.Tensor]: awaitable of sequence embeddings.
"""
return self._dist.forward(local_embs)
return self._dist(local_embs)


class InferTwSequenceEmbeddingSharding(
Expand Down
2 changes: 1 addition & 1 deletion torchrec/distributed/sharding/tw_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def forward(
Awaitable[torch.Tensor]: awaitable of merged pooled embedding tensor.
"""

return self._dist.forward(local_embs)
return self._dist(local_embs)


class InferTwEmbeddingSharding(
Expand Down

0 comments on commit d274ac8

Please sign in to comment.