Revise skip torchrec logic (pytorch#2240)
Summary:
Pull Request resolved: pytorch#2240

X-link: pytorch/pytorch#130783

The previous logic added files to the skip list when they were imported, which happens at a very early stage. However, skip_torchrec may be set at a later stage (e.g., in APS we set it during trainer execution). In that case the skip logic would still take effect, because the skipped files had already been added.

So in this diff we revise the logic so that it adapts to changes to skip_torchrec made at later stages.

Since we revise the skip-torchrec logic, we also update the related torchrec tests.
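
As an illustration of the later-stage scenario above, here is a minimal sketch (the train_step helper, model, and batch are hypothetical, not part of this change): with the revised logic, flipping the flag right before compilation is enough, even though torchrec was imported much earlier.

```python
# Sketch only; assumes a PyTorch build that exposes torch._dynamo.config.skip_torchrec.
import torch
import torch._dynamo


def train_step(model: torch.nn.Module, batch):
    # torchrec was imported long before this point; with the revised logic the
    # flag is consulted at trace time, so setting it here still takes effect.
    torch._dynamo.config.skip_torchrec = False
    compiled_model = torch.compile(model, backend="inductor")
    return compiled_model(batch)
```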

Reviewed By: IvanKobzarev, yanboliang

Differential Revision: D59779153
Microve authored and facebook-github-bot committed Jul 30, 2024
1 parent 2771a90 commit c62554a
Showing 4 changed files with 55 additions and 46 deletions.
27 changes: 0 additions & 27 deletions torchrec/distributed/test_utils/infer_utils.py
@@ -977,33 +977,6 @@ def assert_weight_spec(
assert wqcomp_spec.sharding_type == sharding_type


# TODO(ivankobzarev): Remove once torchrec is not in dynamo skipfiles
@contextmanager
# pyre-ignore
def dynamo_skipfiles_allow(exclude_from_skipfiles_pattern: str):
    replaced: bool = False
    try:
        # Temporary wrapping, as preparation for removal of trace_rules.FBCODE_SKIP_DIRS_RE
        # Remove dynamo_skipfiles_allow once trace_rules.FBCODE_SKIP_DIRS removed
        original_FBCODE_SKIP_DIRS_RE = copy.deepcopy(trace_rules.FBCODE_SKIP_DIRS_RE)
        new_FBCODE_SKIP_DIRS = {
            s
            for s in trace_rules.FBCODE_SKIP_DIRS
            if exclude_from_skipfiles_pattern not in s
        }
        trace_rules.FBCODE_SKIP_DIRS_RE = re.compile(
            # pyre-ignore
            f".*({'|'.join(map(re.escape, new_FBCODE_SKIP_DIRS))})"
        )
        replaced = True
    except Exception:
        pass
    yield
    if replaced:
        # pyre-ignore
        trace_rules.FBCODE_SKIP_DIRS_RE = original_FBCODE_SKIP_DIRS_RE


class MockTBE(nn.Module):
def __init__(
self,
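The tests below replace the removed dynamo_skipfiles_allow("torchrec") helper by patching the config flag directly. A standalone sketch of that pattern (run_compiled and fn are placeholders for illustration):

```python
import unittest.mock

import torch


def run_compiled(fn, *args):
    # Equivalent in spirit to the removed dynamo_skipfiles_allow("torchrec"):
    # make sure Dynamo traces into torchrec frames for the duration of the call.
    with unittest.mock.patch("torch._dynamo.config.skip_torchrec", False):
        return torch.compile(fn, backend="eager")(*args)
```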
13 changes: 10 additions & 3 deletions torchrec/distributed/tests/test_comm.py
@@ -23,9 +23,10 @@
import torchrec.distributed.comm_ops as comm_ops
from hypothesis import given, settings
from torch.distributed.distributed_c10d import GroupMember
from torchrec.distributed.test_utils.infer_utils import dynamo_skipfiles_allow
from torchrec.test_utils import get_free_port, seed_and_log

torch.ops.import_module("fbgemm_gpu.sparse_ops")


@dataclass
class _CompileConfig:
@@ -128,7 +129,10 @@ def _test_async_sync_compile(
if compile_config.backend is not None:
fn_transform = compile_config_to_fn_transform(compile_config)

with dynamo_skipfiles_allow("torchrec"):
with unittest.mock.patch(
"torch._dynamo.config.skip_torchrec",
False,
):
if compile_config.test_compiled_with_noncompiled_ranks and rank == 1:
# Turn off compilation for rank==1 to test compatibility of compiled rank and non-compiled
fn_transform = lambda x: x
@@ -241,7 +245,10 @@ def fn(*args, **kwargs) -> List[torch.Tensor]:

fn_transform = compile_config_to_fn_transform(compile_config)

with dynamo_skipfiles_allow("torchrec"):
with unittest.mock.patch(
"torch._dynamo.config.skip_torchrec",
False,
):
v_embs_out = fn_transform(fn)(
input_embeddings, out_split=out_split, group=pg if specify_pg else None
)
46 changes: 36 additions & 10 deletions torchrec/distributed/tests/test_pt2.py
@@ -53,7 +53,6 @@
from torchrec.distributed.test_utils.infer_utils import (
assert_close,
create_test_model_ebc_only,
dynamo_skipfiles_allow,
KJTInputExportWrapper,
prep_inputs,
replace_registered_tbes_with_mock_tbes,
@@ -241,7 +240,10 @@ def _test_compile_fwd_bwd(
eager_loss.backward()
eager_bwd_out = _grad_detach_clone(eager_input)

with dynamo_skipfiles_allow("torchrec"):
with unittest.mock.patch(
"torch._dynamo.config.skip_torchrec",
False,
):
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True

@@ -281,7 +283,10 @@ def _test_kjt_input_module(
test_aot_inductor: bool = True,
test_pt2_ir_export: bool = False,
) -> None:
with dynamo_skipfiles_allow("torchrec"):
with unittest.mock.patch(
"torch._dynamo.config.skip_torchrec",
False,
):
EM: torch.nn.Module = KJTInputExportWrapper(kjt_input_module, kjt.keys())
em_inputs = (kjt.values(), kjt.lengths(), kjt.weights_or_none(), *inputs)
eager_output = EM(*em_inputs)
@@ -332,7 +337,10 @@ def _test_kjt_input_module_dynamo_compile(
inputs,
backend: str = "eager",
) -> None:
with dynamo_skipfiles_allow("torchrec"):
with unittest.mock.patch(
"torch._dynamo.config.skip_torchrec",
False,
):
EM: torch.nn.Module = KJTInputExportWrapperWithStrides(
kjt_input_module, kjt_keys
)
@@ -442,7 +450,10 @@ def forward(self, inputs: List[KeyedTensor]) -> Dict[str, torch.Tensor]:
torch._dynamo.decorators.mark_dynamic(t, 1)

eager_output = m(inputs)
with dynamo_skipfiles_allow("torchrec"):
with unittest.mock.patch(
"torch._dynamo.config.skip_torchrec",
False,
):
torch_compile_backend = "eager"

torch._dynamo.config.capture_scalar_outputs = True
@@ -581,7 +592,10 @@ def kjt_to_inputs(kjt):

device: str = "cuda"

with dynamo_skipfiles_allow("torchrec"):
with unittest.mock.patch(
"torch._dynamo.config.skip_torchrec",
False,
):
tracing_values = kjt.values()
tracing_lengths = kjt.lengths()
torch._dynamo.mark_dynamic(tracing_values, 0)
@@ -690,7 +704,10 @@ def test_maybe_compute_kjt_to_jt_dict(self) -> None:
)

def test_kjt_values_specialization(self):
with dynamo_skipfiles_allow("torchrec"):
with unittest.mock.patch(
"torch._dynamo.config.skip_torchrec",
False,
):
from torch._dynamo.testing import CompileCounter

kjt0 = KeyedJaggedTensor(
@@ -722,7 +739,10 @@ def f(kjt):
self.assertEqual(counter.frame_count, 1)

def test_kjt_values_specialization_utils(self):
with dynamo_skipfiles_allow("torchrec"):
with unittest.mock.patch(
"torch._dynamo.config.skip_torchrec",
False,
):
from torch._dynamo.testing import CompileCounter

kjt0 = KeyedJaggedTensor(
@@ -890,7 +910,10 @@ def get_weights(m):

# COMPILE
orig_compile_weights = get_weights(m_compile)
with dynamo_skipfiles_allow("torchrec"):
with unittest.mock.patch(
"torch._dynamo.config.skip_torchrec",
False,
):
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True

@@ -969,7 +992,10 @@ def get_weights(m):

# COMPILE
orig_compile_weights = get_weights(m_compile)
with dynamo_skipfiles_allow("torchrec"):
with unittest.mock.patch(
"torch._dynamo.config.skip_torchrec",
False,
):
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True

15 changes: 9 additions & 6 deletions torchrec/distributed/tests/test_pt2_multiprocess.py
@@ -40,10 +40,7 @@
from torchrec.distributed.planner.types import ShardingPlan
from torchrec.distributed.sharding_plan import EmbeddingBagCollectionSharder

from torchrec.distributed.test_utils.infer_utils import (
dynamo_skipfiles_allow,
TestModelInfo,
)
from torchrec.distributed.test_utils.infer_utils import TestModelInfo

from torchrec.distributed.test_utils.multi_process import (
MultiProcessContext,
@@ -424,7 +421,10 @@ def get_weights(dmp: DistributedModelParallel) -> torch.Tensor:

##### COMPILE #####
run_compile_backward: bool = torch_compile_backend in ["aot_eager", "inductor"]
with dynamo_skipfiles_allow("torchrec"):
with unittest.mock.patch(
"torch._dynamo.config.skip_torchrec",
False,
):
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._dynamo.config.force_unspec_int_unbacked_size_like_on_torchrec_kjt = (
@@ -457,7 +457,10 @@ def get_weights(dmp: DistributedModelParallel) -> torch.Tensor:
##### COMPILE END #####

##### NUMERIC CHECK #####
with dynamo_skipfiles_allow("torchrec"):
with unittest.mock.patch(
"torch._dynamo.config.skip_torchrec",
False,
):
for i in range(n_extra_numerics_checks):
local_model_input = ins[1 + i][rank].to(device)
kjt = local_model_input.idlist_features
