diff --git a/torchrec/distributed/quant_embeddingbag.py b/torchrec/distributed/quant_embeddingbag.py index 366352180..30d69b7ae 100644 --- a/torchrec/distributed/quant_embeddingbag.py +++ b/torchrec/distributed/quant_embeddingbag.py @@ -5,6 +5,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import copy from typing import Any, Dict, List, Optional, Type import torch @@ -410,7 +411,7 @@ def __init__( breakpoint() for i in range(env.world_size): self.feature_processors_per_rank.append( - feature_processor + copy.deepcopy(feature_processor) if device_type == "meta" else copy_to_device( feature_processor,