TorchRec inference output dtype
Differential Revision: D65445160
PaulZhang12 authored and facebook-github-bot committed Nov 4, 2024
1 parent d2ed744 commit 215beb9
Showing 1 changed file with 1 addition and 2 deletions.
torchrec/inference/modules.py (1 addition, 2 deletions)
@@ -418,7 +418,6 @@ def _quantize_fp_module(
     model: torch.nn.Module,
     fp_module: FeatureProcessedEmbeddingBagCollection,
     fp_module_fqn: str,
-    activation_dtype: torch.dtype = torch.float,
     weight_dtype: torch.dtype = DEFAULT_QUANTIZATION_DTYPE,
     per_fp_table_weight_dtype: Optional[Dict[str, torch.dtype]] = None,
 ) -> None:
@@ -428,7 +427,7 @@ def _quantize_fp_module(
 
     quant_prep_enable_register_tbes(model, [FeatureProcessedEmbeddingBagCollection])
     fp_module.qconfig = QuantConfig(
-        activation=quant.PlaceholderObserver.with_args(dtype=activation_dtype),
+        activation=quant.PlaceholderObserver.with_args(dtype=output_dtype),
         weight=quant.PlaceholderObserver.with_args(dtype=weight_dtype),
         per_table_weight_dtype=per_fp_table_weight_dtype,
    )
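The net effect is that the feature-processed EmbeddingBagCollection's activation observer now carries the output dtype resolved by the enclosing quantization helper instead of a separate activation_dtype argument. Below is a minimal, hypothetical sketch of the resulting qconfig, assuming quant refers to torch.ao.quantization and re-declaring QuantConfig locally as a namedtuple with the three fields visible in the diff; the values of output_dtype and the torch.qint8 stand-in for DEFAULT_QUANTIZATION_DTYPE are illustrative, not the library's actual defaults.

# Minimal sketch (not TorchRec's actual code): how the qconfig is assembled
# after this change, with the activation observer carrying the output dtype.
import torch
import torch.ao.quantization as quant
from collections import namedtuple

# Mirrors the three fields passed to QuantConfig in the diff above (assumption).
QuantConfig = namedtuple("QuantConfig", ["activation", "weight", "per_table_weight_dtype"])

output_dtype = torch.float  # illustrative; supplied by the enclosing quantization helper
weight_dtype = torch.qint8  # illustrative stand-in for DEFAULT_QUANTIZATION_DTYPE

qconfig = QuantConfig(
    # PlaceholderObserver records only the dtype; it gathers no statistics.
    activation=quant.PlaceholderObserver.with_args(dtype=output_dtype),
    weight=quant.PlaceholderObserver.with_args(dtype=weight_dtype),
    per_table_weight_dtype=None,
)

Since PlaceholderObserver only passes its configured dtype through to the quantization workflow, the activation entry effectively declares what dtype the quantized module should emit, while the weight entry keeps the embedding tables at the quantized weight dtype.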
