remove .int() for cpu indices and values (#1590)
Summary:
Pull Request resolved: #1590

CPU DI serves large models with big embedding tables (2TB), so the values and indices would overflow under the .int() conversion. Remove .int() for CPU only.
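
To make the overflow concrete, here is a minimal sketch of the arithmetic in plain Python. The row size and the helper `wrap_to_int32` are assumptions for illustration only; they are not taken from this commit or from torchrec, and the helper merely models a two's-complement narrowing to 32 bits rather than calling `.int()` itself.

```python
INT32_MAX = 2**31 - 1


def wrap_to_int32(value: int) -> int:
    """Model a two's-complement narrowing of an integer to 32 bits."""
    return (value + 2**31) % 2**32 - 2**31


# Hypothetical sizing: 2 TB of embedding storage at an assumed 128 bytes per row.
TABLE_BYTES = 2 * 10**12
ROW_BYTES = 128  # assumption, not stated in the commit
num_rows = TABLE_BYTES // ROW_BYTES

# The largest row index such a table would need to address.
largest_index = num_rows - 1

# True when the index space no longer fits in a signed 32-bit integer.
needs_int64 = largest_index > INT32_MAX

# The value that index would become after narrowing to 32 bits.
wrapped = wrap_to_int32(largest_index)

print(f"rows={num_rows}, needs_int64={needs_int64}, wrapped={wrapped}")
```

Under these assumed numbers the index space is several times larger than int32 can represent, so a narrowed index would point at the wrong row (or be negative). Keeping the tensors in their original 64-bit dtype avoids this.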

Reviewed By: zyan0, tissue3

Differential Revision: D52225777

fbshipit-source-id: 0bf7973a91a7b7daed6eaed3a55bb8dca25fcdef
jiayisuse authored and facebook-github-bot committed Dec 18, 2023
1 parent 80b19a2 commit 2980010
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion torchrec/distributed/quant_embedding_kernel.py
@@ -118,7 +118,14 @@ def _quantize_weight(
 def _unwrap_kjt(
     features: KeyedJaggedTensor,
 ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-    return features.values().int(), features.offsets().int(), features.weights_or_none()
+    if features.device().type == "cuda":
+        return (
+            features.values().int(),
+            features.offsets().int(),
+            features.weights_or_none(),
+        )
+    else:
+        return features.values(), features.offsets(), features.weights_or_none()
 
 
 class QuantBatchedEmbeddingBag(