From ca1245f0cf4d8c4af8b7756a4b93d2489cfa17ff Mon Sep 17 00:00:00 2001
From: Emma Yang
Date: Wed, 30 Oct 2024 15:33:29 -0400
Subject: [PATCH] Updated logic for extracting data sizes, made all methods
 static

---
 torch/distributed/_tools/fake_collectives.py | 163 ++++++-------------
 1 file changed, 50 insertions(+), 113 deletions(-)

diff --git a/torch/distributed/_tools/fake_collectives.py b/torch/distributed/_tools/fake_collectives.py
index 7de6189c1303b..7d247e1e0b875 100644
--- a/torch/distributed/_tools/fake_collectives.py
+++ b/torch/distributed/_tools/fake_collectives.py
@@ -10,7 +10,8 @@
 from contextlib import contextmanager, nullcontext
 import logging
 from datetime import timedelta
-from typing import cast, Optional, overload
+from typing import cast, Optional, overload, Any
+from torch.utils._pytree import tree_map_only
 
 aten = torch.ops.aten
 c10d = torch.ops.c10d
@@ -108,6 +109,12 @@ def _alltoall_meta(*args):
     fakework_script_obj = fakework.boxed()
     return (args[0], fakework_script_obj)
 
+def _alltoall_base_meta(*args):
+    fakework = FakeWork()
+    fakework.__setattr__("getFuture", fakework.get_future)
+    fakework_script_obj = fakework.boxed()
+    return fakework_script_obj
+
 
 def _send_meta(*args):
     fakework = FakeWork()
@@ -145,6 +152,7 @@ def _barrier_meta(*args):
     lib_impl.impl("gather_", _gather_meta, "Meta")
     lib_impl.impl("scatter_", _scatter_meta, "Meta")
     lib_impl.impl("alltoall_", _alltoall_meta, "Meta")
+    lib_impl.impl("alltoall_base_", _alltoall_base_meta, "Meta")
     lib_impl.impl("barrier", _barrier_meta, "Meta")
     lib_impl.impl("send", _send_meta, "Meta")
     lib_impl.impl("recv_", _recv_meta, "Meta")
@@ -164,6 +172,7 @@ def _barrier_meta(*args):
     c10d.gather_.default,
     c10d.scatter_.default,
     c10d.alltoall_.default,
+    c10d.alltoall_base_.default,
     _c10d_functional.broadcast.default,
     _c10d_functional.all_reduce.default,
     _c10d_functional.all_to_all_single.default,
@@ -174,85 +183,21 @@
 class IgnoreDistMode(TorchDispatchMode):
     class CollectiveOp:
-        # CollectiveOp stores the name and data sizes of the collective operation
-        # so that the cost of the operation can be calculated
-        def __init__(self, func, res, *args, **kwargs):
-            self.func = func
-            self.tensor_args = self.parse_collective_args(args)
-            self.data_size = self.get_data_sizes()
-            self.pg_name, self.world_size = self.get_process_group_properties(args)
-            self.output_data_size = self.get_output_data_size(args, res)
-
-            # process group, world size of process group
-
-            # logging.info(f"Data size: {self.data_size}")
-            # logging.info(f"Process Group: {self.pg_name}, {self.world_size}")
-            logging.info(f"res: {res}")
-            logging.info(f"Output data size: {self.output_data_size}")
-
-        # Extract the tensor from the collective operator arguments
-        def parse_collective_args(self, args):
-            if self.func in [
-                c10d.broadcast_.default,
-                _c10d_functional.broadcast.default,
-                c10d.allreduce_.default,
-                _c10d_functional.all_reduce.default,
-                c10d.reduce_.default,
-                c10d.send.default,
-                c10d.recv_.default,
-                _c10d_functional.all_to_all_single.default,
-            ]:
-                return args[0]
-            elif self.func in [
-                c10d.allgather_.default,
-                c10d._allgather_base_.default,
-                c10d.reduce_scatter_.default,
-                c10d._reduce_scatter_base_.default,
-                c10d.gather_.default,
-                c10d.scatter_.default,
-                c10d.alltoall_.default,
-            ]:
-                return args[1]
-            elif self.func in [
-                _c10d_functional.all_gather_into_tensor.default,
-                _c10d_functional.reduce_scatter_tensor.default,
-            ]:
-                return args[0]
-            else:
-                return None
-
-        def get_data_sizes(self):
-            if self.func in [
-                c10d.broadcast_.default,
-                c10d.allreduce_.default,
-                c10d.allgather_.default,
-                c10d.gather_.default,
-                c10d.reduce_.default,
-                c10d.alltoall_.default,
-                c10d.send.default,
-                c10d.recv_.default,
-            ]:
-                return [t.untyped_storage().nbytes() for t in self.tensor_args]
-            elif self.func in [
-                c10d.reduce_scatter_.default,
-                c10d.scatter_.default,
-            ]:
-                return [t.untyped_storage().nbytes() for l in self.tensor_args for t in l]
-            elif self.func in [
-                _c10d_functional.broadcast.default,
-                _c10d_functional.all_reduce.default,
-                c10d._allgather_base_.default,
-                _c10d_functional.all_gather_into_tensor.default,
-                _c10d_functional.reduce_scatter_tensor.default,
-                c10d._reduce_scatter_base_.default,
-                _c10d_functional.all_to_all_single.default,
-            ]:
-                return self.tensor_args.untyped_storage().nbytes()
-            else:
-                return None
-
-        def get_process_group_properties(self, args):
-            if self.func in [
+        @staticmethod
+        def sum_tensors(arg: Any) -> int:
+            # Calculate the total memory consumed by the tensors nested in the
+            # collective's input or output arguments.
+            total_memory = 0
+
+            def sum_bytes(t: torch.Tensor) -> None:
+                nonlocal total_memory
+                total_memory += t.untyped_storage().nbytes()
+
+            tree_map_only(torch.Tensor, sum_bytes, arg)
+            return total_memory
+
+        @staticmethod
+        def get_process_group_properties(args, func):
+            if func in [
                 c10d.broadcast_.default,
                 c10d.allreduce_.default,
                 c10d.reduce_.default,
                 c10d.send.default,
@@ -260,7 +205,7 @@ def get_process_group_properties(self, args):
                 c10d.recv_.default,
             ]:
                 pg = ProcessGroup.unbox(args[1])
-            elif self.func in [
+            elif func in [
                 c10d.allgather_.default,
                 c10d._allgather_base_.default,
                 c10d.reduce_scatter_.default,
@@ -275,37 +220,26 @@ def get_process_group_properties(self, args):
             else:
                 return None, None
             return pg.name(), pg.size()
-
-        def get_output_data_size(self, args, res):
-            if self.func in [
-                c10d._allgather_base_.default,
-                c10d._reduce_scatter_base_.default,
-            ]:
-                return args[0].untyped_storage().nbytes()
-            elif self.func in [
-                c10d.reduce_scatter_.default,
-                c10d.scatter_.default,
-                c10d.alltoall_.default,
-            ]:
-                return [t.untyped_storage().nbytes() for t in args[0]]
-            elif self.func in [
-                c10d.allgather_.default,
-                c10d.gather_.default,
-            ]:
-                return [t.untyped_storage().nbytes() for l in args[0] for t in l]
-            elif self.func in [
-                c10d.broadcast_.default,
-                c10d.allreduce_.default,
-            ]:
-                return [t.untyped_storage().nbytes() for t in res[0]]
-            elif type(res) == list and type(res[0]) == torch.FakeTensor:
-                return res[0].untyped_storage().nbytes()
-            elif type(res) == FakeTensor:
-                return res.untyped_storage().nbytes()
-            else:
-                return None
-
-    collectives = []
+
+        @staticmethod
+        def get_tensor_size(args, func, kwargs, res):
+            match func:
+                case c10d.broadcast_.default:
+                    return args[0][0].untyped_storage().nbytes()
+                case c10d.allreduce_.default | c10d.send.default | c10d.recv_.default | c10d.allgather_.default | c10d.gather_.default | c10d.reduce_.default:
+                    return IgnoreDistMode.CollectiveOp.sum_tensors(args[0])
+                case c10d.reduce_scatter_.default | c10d.scatter_.default:
+                    return IgnoreDistMode.CollectiveOp.sum_tensors(args[1])
+                case c10d._reduce_scatter_base_.default:
+                    return args[1].untyped_storage().nbytes()
+                case c10d._allgather_base_.default | _c10d_functional.broadcast.default | _c10d_functional.all_reduce.default | _c10d_functional.all_to_all_single.default | _c10d_functional.all_gather_into_tensor.default:
+                    return args[0].untyped_storage().nbytes()
+                case _c10d_functional.reduce_scatter_tensor.default:
+                    return res.untyped_storage().nbytes()
+                case c10d.alltoall_.default:
+                    return max(IgnoreDistMode.CollectiveOp.sum_tensors(args[0]), IgnoreDistMode.CollectiveOp.sum_tensors(args[1]))
+                case c10d.alltoall_base_.default:
+                    return max(args[0].untyped_storage().nbytes(), args[1].untyped_storage().nbytes())
 
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         logging.info(f"Function name: {str(func.__name__)}")
@@ -316,7 +250,10 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         res = func(*args, **kwargs or {})
 
         if func in collective_op_funcs:
-            self.collectives.append(self.CollectiveOp(func, res, *args, **kwargs))
+            pg_name, pg_size = self.CollectiveOp.get_process_group_properties(args, func)
+            size = self.CollectiveOp.get_tensor_size(args, func, kwargs, res)
+            logging.info(f"Process Group: {pg_name} ({pg_size})")
+            logging.info(f"Tensor Size: {size}")
         return res
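
Note (not part of the patch): the new sum_tensors helper relies on torch.utils._pytree.tree_map_only, which applies a callback to every torch.Tensor inside an arbitrarily nested structure. This is why one helper can replace both the flat-list and list-of-lists comprehensions that the patch deletes from get_data_sizes and get_output_data_size. A minimal standalone sketch of the same pattern, where sum_tensor_bytes is a hypothetical name used only for this illustration:

    import torch
    from torch.utils._pytree import tree_map_only

    def sum_tensor_bytes(arg) -> int:
        # Accumulate storage bytes over every tensor found in `arg`,
        # whether it is a bare tensor, a list, or a list of lists.
        total = 0

        def add_bytes(t: torch.Tensor) -> None:
            nonlocal total
            total += t.untyped_storage().nbytes()

        tree_map_only(torch.Tensor, add_bytes, arg)
        return total

    # Two float32 tensors of four elements each: 2 * 16 = 32 bytes.
    print(sum_tensor_bytes([torch.ones(4), [torch.ones(2, 2)]]))  # 32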
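
Note (not part of the patch): a hypothetical driver for IgnoreDistMode, assuming the "fake" process-group backend from torch.testing._internal.distributed.fake_pg. It only sketches how the mode would log collective names, process groups, and byte sizes without performing real communication; it is not a confirmed usage from this PR:

    import logging
    import torch
    import torch.distributed as dist
    from torch._subclasses.fake_tensor import FakeTensorMode
    # Importing FakeStore also registers the "fake" backend.
    from torch.testing._internal.distributed.fake_pg import FakeStore
    from torch.distributed._tools.fake_collectives import IgnoreDistMode

    logging.basicConfig(level=logging.INFO)

    # Rank and world size are arbitrary: the fake backend never communicates.
    dist.init_process_group("fake", store=FakeStore(), rank=0, world_size=8)

    with FakeTensorMode(), IgnoreDistMode():
        t = torch.randn(1024, 1024)  # 4 MiB of float32
        dist.all_reduce(t)           # expected log: Tensor Size: 4194304

    dist.destroy_process_group()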