Updated logic for extracting data sizes, made all methods static
emmay78 committed Oct 30, 2024
1 parent 821bdff commit ca1245f
Showing 1 changed file with 50 additions and 113 deletions.
torch/distributed/_tools/fake_collectives.py: 163 changes (50 additions, 113 deletions)
@@ -10,7 +10,8 @@
from contextlib import contextmanager, nullcontext
import logging
from datetime import timedelta
from typing import cast, Optional, overload
from typing import cast, Optional, overload, Any
from torch.utils._pytree import tree_map_only

aten = torch.ops.aten
c10d = torch.ops.c10d
@@ -108,6 +109,12 @@ def _alltoall_meta(*args):
fakework_script_obj = fakework.boxed()
return (args[0], fakework_script_obj)

def _alltoall_base_meta(*args):
fakework = FakeWork()
fakework.__setattr__("getFuture", fakework.get_future)
fakework_script_obj = fakework.boxed()
return fakework_script_obj


def _send_meta(*args):
fakework = FakeWork()
@@ -145,6 +152,7 @@ def _barrier_meta(*args):
lib_impl.impl("gather_", _gather_meta, "Meta")
lib_impl.impl("scatter_", _scatter_meta, "Meta")
lib_impl.impl("alltoall_", _alltoall_meta, "Meta")
lib_impl.impl("alltoall_base_", _alltoall_base_meta, "Meta")
lib_impl.impl("barrier", _barrier_meta, "Meta")
lib_impl.impl("send", _send_meta, "Meta")
lib_impl.impl("recv_", _recv_meta, "Meta")
@@ -164,6 +172,7 @@ def _barrier_meta(*args):
c10d.gather_.default,
c10d.scatter_.default,
c10d.alltoall_.default,
c10d.alltoall_base_.default,
_c10d_functional.broadcast.default,
_c10d_functional.all_reduce.default,
_c10d_functional.all_to_all_single.default,
@@ -174,93 +183,29 @@ def _barrier_meta(*args):
class IgnoreDistMode(TorchDispatchMode):

class CollectiveOp:
# CollectiveOp stores the name and data sizes of the collective operation
# so that the cost of the operation can be calculated
def __init__(self, func, res, *args, **kwargs):
self.func = func
self.tensor_args = self.parse_collective_args(args)
self.data_size = self.get_data_sizes()
self.pg_name, self.world_size = self.get_process_group_properties(args)
self.output_data_size = self.get_output_data_size(args, res)

# process group, world size of process group

# logging.info(f"Data size: {self.data_size}")
# logging.info(f"Process Group: {self.pg_name}, {self.world_size}")
logging.info(f"res: {res}")
logging.info(f"Output data size: {self.output_data_size}")

# Extract the tensor from the collective operator arguments
def parse_collective_args(self, args):
if self.func in [
c10d.broadcast_.default,
_c10d_functional.broadcast.default,
c10d.allreduce_.default,
_c10d_functional.all_reduce.default,
c10d.reduce_.default,
c10d.send.default,
c10d.recv_.default,
_c10d_functional.all_to_all_single.default,
]:
return args[0]
elif self.func in [
c10d.allgather_.default,
c10d._allgather_base_.default,
c10d.reduce_scatter_.default,
c10d._reduce_scatter_base_.default,
c10d.gather_.default,
c10d.scatter_.default,
c10d.alltoall_.default,
]:
return args[1]
elif self.func in [
_c10d_functional.all_gather_into_tensor.default,
_c10d_functional.reduce_scatter_tensor.default,
]:
return args[0]
else:
return None

def get_data_sizes(self):
if self.func in [
c10d.broadcast_.default,
c10d.allreduce_.default,
c10d.allgather_.default,
c10d.gather_.default,
c10d.reduce_.default,
c10d.alltoall_.default,
c10d.send.default,
c10d.recv_.default,
]:
return [t.untyped_storage().nbytes() for t in self.tensor_args]
elif self.func in [
c10d.reduce_scatter_.default,
c10d.scatter_.default,
]:
return [t.untyped_storage().nbytes() for l in self.tensor_args for t in l]
elif self.func in [
_c10d_functional.broadcast.default,
_c10d_functional.all_reduce.default,
c10d._allgather_base_.default,
_c10d_functional.all_gather_into_tensor.default,
_c10d_functional.reduce_scatter_tensor.default,
c10d._reduce_scatter_base_.default,
_c10d_functional.all_to_all_single.default,
]:
return self.tensor_args.untyped_storage().nbytes()
else:
return None

def get_process_group_properties(self, args):
if self.func in [
@staticmethod
def sum_tensors(arg: Any) -> int:
# Calculate the total memory consumed by all tensors nested in the argument.
total_memory = 0

def sum_bytes(t: torch.Tensor) -> None:
nonlocal total_memory
total_memory += t.untyped_storage().nbytes()

tree_map_only(torch.Tensor, sum_bytes, arg)
return total_memory

@staticmethod
def get_process_group_properties(args, func):
if func in [
c10d.broadcast_.default,
c10d.allreduce_.default,
c10d.reduce_.default,
c10d.send.default,
c10d.recv_.default,
]:
pg = ProcessGroup.unbox(args[1])
elif self.func in [
elif func in [
c10d.allgather_.default,
c10d._allgather_base_.default,
c10d.reduce_scatter_.default,
@@ -275,37 +220,26 @@ def get_process_group_properties(self, args):
return None, None

return pg.name(), pg.size()

def get_output_data_size(self, args, res):
if self.func in [
c10d._allgather_base_.default,
c10d._reduce_scatter_base_.default,
]:
return args[0].untyped_storage().nbytes()
elif self.func in [
c10d.reduce_scatter_.default,
c10d.scatter_.default,
c10d.alltoall_.default,
]:
return [t.untyped_storage().nbytes() for t in args[0]]
elif self.func in [
c10d.allgather_.default,
c10d.gather_.default,
]:
return [t.untyped_storage().nbytes() for l in args[0] for t in l]
elif self.func in [
c10d.broadcast_.default,
c10d.allreduce_.default,
]:
return [t.untyped_storage().nbytes() for t in res[0]]
elif type(res) == list and type(res[0]) == FakeTensor:
return res[0].untyped_storage().nbytes()
elif type(res) == FakeTensor:
return res.untyped_storage().nbytes()
else:
return None

collectives = []

@staticmethod
def get_tensor_size(args, func, kwargs, res):
match func:
case c10d.broadcast_.default:
return args[0][0].untyped_storage().nbytes()
case c10d.allreduce_.default | c10d.send.default | c10d.recv_.default | c10d.allgather_.default | c10d.gather_.default | c10d.reduce_.default:
return IgnoreDistMode.CollectiveOp.sum_tensors(args[0])
case c10d.reduce_scatter_.default | c10d.scatter_.default:
return IgnoreDistMode.CollectiveOp.sum_tensors(args[1])
case c10d._reduce_scatter_base_.default:
return args[1].untyped_storage().nbytes()
case c10d._allgather_base_.default | _c10d_functional.broadcast.default | _c10d_functional.all_reduce.default | _c10d_functional.all_to_all_single.default | _c10d_functional.all_gather_into_tensor.default:
return args[0].untyped_storage().nbytes()
case _c10d_functional.reduce_scatter_tensor.default:
return res.untyped_storage().nbytes()
case c10d.alltoall_.default:
return max(IgnoreDistMode.CollectiveOp.sum_tensors(args[0]), IgnoreDistMode.CollectiveOp.sum_tensors(args[1]))
case c10d.alltoall_base_.default:
return max(args[0].untyped_storage().nbytes(), args[1].untyped_storage().nbytes())

def __torch_dispatch__(self, func, types, args=(), kwargs=None):
logging.info(f"Function name: {str(func.__name__)}")
@@ -316,7 +250,10 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
res = func(*args, **kwargs or {})

if func in collective_op_funcs:
self.collectives.append(self.CollectiveOp(func, res, *args, **kwargs))
pg_name, pg_size = self.CollectiveOp.get_process_group_properties(args, func)
size = self.CollectiveOp.get_tensor_size(args, func, kwargs, res)
logging.info(f"Process Group: {pg_name} ({pg_size})")
logging.info(f"Tensor Size: {size}")
return res
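
As an aside (not part of this commit): a minimal, self-contained sketch of the byte-counting pattern the new sum_tensors static method relies on, built on torch.utils._pytree.tree_map_only. The helper name sum_tensor_bytes and the example tensor shapes are illustrative assumptions, not code from the diff.

import torch
from torch.utils._pytree import tree_map_only

def sum_tensor_bytes(arg) -> int:
    # Walk an arbitrarily nested structure and add up the untyped-storage
    # size of every tensor leaf, mirroring the CollectiveOp.sum_tensors
    # logic shown in the diff above.
    total = 0

    def add_bytes(t: torch.Tensor) -> None:
        nonlocal total
        total += t.untyped_storage().nbytes()

    tree_map_only(torch.Tensor, add_bytes, arg)
    return total

if __name__ == "__main__":
    # Nested tensor lists, roughly the shape of the tensor-list arguments
    # passed to collectives such as allgather_ or scatter_.
    nested = [[torch.zeros(4, 4), torch.zeros(2)], [torch.zeros(8)]]
    print(sum_tensor_bytes(nested))  # 64 + 8 + 32 = 104 bytes for float32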


