[dtensor] move ops to private (pytorch#131211)
As titled: move the DTensor op modules under torch/distributed/_tensor/ops to private, underscore-prefixed names (for example embedding_ops.py becomes _embedding_ops.py, and basic_strategy.py becomes _einsum_strategy.py) and update all in-tree imports accordingly. A compatibility sketch for out-of-tree callers follows the commit metadata below.

Differential Revision: [D60132519](https://our.internmc.facebook.com/intern/diff/D60132519)
Pull Request resolved: pytorch#131211
Approved by: https://github.com/XilunWu, https://github.com/wz337
ghstack dependencies: pytorch#131212
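
For out-of-tree code that imported the previously public module paths, a minimal compatibility sketch (hypothetical, not part of this commit; both paths are taken from the removed and added lines in this diff):

try:
    # Path after this change (module file names are now private).
    from torch.distributed._tensor.ops._embedding_ops import _MaskPartial
except ImportError:
    # Older PyTorch builds that predate this rename still expose the old path.
    from torch.distributed._tensor.ops.embedding_ops import _MaskPartial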
wanchaol authored and pytorchmergebot committed Jul 25, 2024
1 parent 605dfd8 commit 1c58aac
Showing 20 changed files with 24 additions and 24 deletions.
2 changes: 1 addition & 1 deletion test/distributed/_tensor/test_common_rules.py
@@ -4,7 +4,7 @@
import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor._op_schema import OpSchema
-from torch.distributed._tensor.ops.common_rules import einop_rule, pointwise_rule
+from torch.distributed._tensor.ops._common_rules import einop_rule, pointwise_rule
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
2 changes: 1 addition & 1 deletion test/distributed/_tensor/test_embedding_ops.py
@@ -167,7 +167,7 @@ def test_sharded_embedding_rowwise(self):
self._run_embedding_op_test(mesh, 0, [6, 7, 6], 13, 22)
self._run_embedding_op_test(mesh, 0, [34], 15, 14, padding_idx=10)

-from torch.distributed._tensor.ops.embedding_ops import _MaskPartial
+from torch.distributed._tensor.ops._embedding_ops import _MaskPartial

# test collectives
embedding_mod = torch.nn.Embedding(10, 20, device=self.device_type)
8 changes: 4 additions & 4 deletions test/distributed/_tensor/test_op_strategy.py
@@ -6,7 +6,7 @@
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed._tensor._collective_utils import redistribute_cost
from torch.distributed._tensor._op_schema import OpSchema, OpStrategy, PlacementStrategy
-from torch.distributed._tensor.ops.basic_strategy import (
+from torch.distributed._tensor.ops._einsum_strategy import (
EinsumDims,
gen_einsum_strategies,
)
@@ -169,7 +169,7 @@ def test_redistribute_cost_mesh_1d(self):

def test_redistribute_cost_latency(self):
# test cost model on addmm op
-from torch.distributed._tensor.ops.matrix_ops import addmm_strategy
+from torch.distributed._tensor.ops._matrix_ops import addmm_strategy

mesh = self.build_device_mesh()
shard0_placement = (Shard(0),)
@@ -246,7 +246,7 @@ def test_redistribute_cost_mesh_2d(self):
self.assertTrue(allreduce_cost > reduce_scatter_cost)

def test_mm_strategies(self):
-from torch.distributed._tensor.ops.matrix_ops import mm_strategy
+from torch.distributed._tensor.ops._matrix_ops import mm_strategy

mesh = self.build_device_mesh()
lhs_tensor = torch.randn(6, 8)
@@ -292,7 +292,7 @@ def test_mm_strategies(self):
self.assertFalse(output_sharding.needs_redistribute)

def test_bmm_strategies(self):
-from torch.distributed._tensor.ops.matrix_ops import bmm_strategy
+from torch.distributed._tensor.ops._matrix_ops import bmm_strategy

mesh = self.build_device_mesh()
lhs_tensor = torch.randn(8, 6, 8)
2 changes: 1 addition & 1 deletion test/distributed/_tensor/test_tensor_ops.py
@@ -445,7 +445,7 @@ def test_gather(self):
# case 2 input sharding: input sharded, index replicated, output mask partial
# only works when index has size 1 on the gather dimension and
# input is sharded on the gather dimension
-from torch.distributed._tensor.ops.embedding_ops import _MaskPartial
+from torch.distributed._tensor.ops._embedding_ops import _MaskPartial

gather_dim = 1
global_input = torch.randn(12, 8, 16)
2 changes: 1 addition & 1 deletion test/distributed/_tensor/test_view_ops.py
@@ -9,7 +9,7 @@
from torch import rand, randn, Tensor
from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard
from torch.distributed._tensor.debug import CommDebugMode
-from torch.distributed._tensor.ops.view_ops import (
+from torch.distributed._tensor.ops._view_ops import (
Broadcast,
dim_maps,
Flatten,
2 changes: 1 addition & 1 deletion torch/distributed/_spmd/batch_dim_utils.py
@@ -6,7 +6,7 @@
import torch.utils._pytree as pytree
from torch import Tensor
from torch.distributed._tensor import DeviceMesh, Replicate, Shard
-from torch.distributed._tensor.ops.view_ops import dim_maps, DimSpec, InputDim
+from torch.distributed._tensor.ops._view_ops import dim_maps, DimSpec, InputDim
from torch.distributed._tensor.placement_types import _Partial, DTensorSpec


2 changes: 1 addition & 1 deletion torch/distributed/_spmd/experimental_ops.py
@@ -5,7 +5,7 @@

import torch
from torch.distributed._tensor._op_schema import OpSchema, OutputSharding
-from torch.distributed._tensor.ops.common_rules import pointwise_rule
+from torch.distributed._tensor.ops._common_rules import pointwise_rule
from torch.distributed._tensor.ops.utils import register_prop_rule
from torch.distributed._tensor.placement_types import (
_Partial,
18 changes: 9 additions & 9 deletions torch/distributed/_tensor/ops/__init__.py
@@ -1,10 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
-from .conv_ops import * # noqa: F403
-from .embedding_ops import * # noqa: F403
-from .experimental_ops import * # noqa: F403
-from .math_ops import * # noqa: F403
-from .matrix_ops import * # noqa: F403
-from .pointwise_ops import * # noqa: F403
-from .random_ops import * # noqa: F403
-from .tensor_ops import * # noqa: F403
-from .view_ops import * # noqa: F403
+from ._conv_ops import * # noqa: F403
+from ._embedding_ops import * # noqa: F403
+from ._experimental_ops import * # noqa: F403
+from ._math_ops import * # noqa: F403
+from ._matrix_ops import * # noqa: F403
+from ._pointwise_ops import * # noqa: F403
+from ._random_ops import * # noqa: F403
+from ._tensor_ops import * # noqa: F403
+from ._view_ops import * # noqa: F403
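
The wildcard imports in this __init__.py are what trigger each op module's strategy/rule registration at import time, so renaming the module files to private names leaves DTensor's public behavior unchanged. A rough smoke-check sketch, assuming a single-rank gloo process group and CPU tensors (not part of this commit):

import os

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, distribute_tensor, Shard

# Single-process "world" so the sketch is self-contained.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = DeviceMesh("cpu", [0])
dtensor = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])

# Dispatch still goes through the sharding strategies registered by the
# modules imported above; only their file names became private.
result = torch.mm(dtensor, dtensor)
print(result.placements)

dist.destroy_process_group()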
File renamed without changes.
File renamed without changes.
@@ -9,7 +9,7 @@
PlacementList,
PlacementStrategy,
)
-from torch.distributed._tensor.ops.basic_strategy import gen_einsum_strategies
+from torch.distributed._tensor.ops._einsum_strategy import gen_einsum_strategies
from torch.distributed._tensor.ops.utils import (
expand_to_full_mesh_op_strategy,
generate_redistribute_costs,
File renamed without changes.
@@ -15,8 +15,8 @@
StrategyType,
TupleStrategy,
)
-from torch.distributed._tensor.ops.common_rules import pointwise_rule
-from torch.distributed._tensor.ops.embedding_ops import _MaskPartial
+from torch.distributed._tensor.ops._common_rules import pointwise_rule
+from torch.distributed._tensor.ops._embedding_ops import _MaskPartial
from torch.distributed._tensor.ops.utils import (
expand_to_full_mesh_op_strategy,
is_tensor_dim_sharded,
File renamed without changes.
4 changes: 2 additions & 2 deletions torch/distributed/tensor/parallel/loss.py
@@ -9,8 +9,8 @@
import torch.distributed.distributed_c10d as c10d
from torch import Tensor
from torch.distributed._tensor import DTensor, Replicate, Shard
-from torch.distributed._tensor.ops.embedding_ops import _MaskPartial
-from torch.distributed._tensor.ops.math_ops import (
+from torch.distributed._tensor.ops._embedding_ops import _MaskPartial
+from torch.distributed._tensor.ops._math_ops import (
_skip_dim,
Reduction,
replicate_reduction_dims,
