Skip to content

Commit

Permalink
Deepcopy FP module even if on meta device (pytorch#1676)
Browse files Browse the repository at this point in the history
Summary:

When we fx trace, even if there are 2 FP modules (because 2 cards), since it was sharded on meta, the ranks just have a reference to the FP on rank 0

and for whatever reason, FX eliminates the FP on rank 1 and it just shows the one on rank 0

do a deepcopy even when on meta device so each rank explicitly has their own copy, fx will persist it

Reviewed By: lequytra, tissue3

Differential Revision: D53294788
  • Loading branch information
s4ayub authored and facebook-github-bot committed Feb 1, 2024
1 parent 58cc035 commit 15ba21c
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion torchrec/distributed/quant_embeddingbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
from typing import Any, Dict, List, Optional, Type

import torch
Expand Down Expand Up @@ -409,7 +410,7 @@ def __init__(
self.feature_processors_per_rank: nn.ModuleList = torch.nn.ModuleList()
for i in range(env.world_size):
self.feature_processors_per_rank.append(
feature_processor
copy.deepcopy(feature_processor)
if device_type == "meta"
else copy_to_device(
feature_processor,
Expand Down

0 comments on commit 15ba21c

Please sign in to comment.