From ad98dc944be4308f405ab34e78fa85b16c7d3709 Mon Sep 17 00:00:00 2001 From: Longjie Zheng <32992656+zhenglongjiepheonix@users.noreply.github.com> Date: Mon, 2 Sep 2024 10:40:14 -0400 Subject: [PATCH] Modify Parallelization Strategy to Make it More General (#1988) * modify parallelization strategy * only support model id in api now * more comments * more comments * address comments * remove idle runner * fix * format * more comments * nit --- .../workflows/test_fx_automatic_parallel.yml | 2 +- optimum/fx/parallelization/api.py | 87 ++-- optimum/fx/parallelization/core.py | 5 + optimum/fx/parallelization/decomp.py | 225 +++++++++ .../parallelization/op_registry/__init__.py | 15 + .../op_registry/op_handlers.py | 450 ++++++++++++++++++ optimum/fx/parallelization/passes.py | 350 +++++--------- optimum/fx/parallelization/utils.py | 29 +- 8 files changed, 878 insertions(+), 285 deletions(-) create mode 100644 optimum/fx/parallelization/decomp.py create mode 100644 optimum/fx/parallelization/op_registry/__init__.py create mode 100644 optimum/fx/parallelization/op_registry/op_handlers.py diff --git a/.github/workflows/test_fx_automatic_parallel.yml b/.github/workflows/test_fx_automatic_parallel.yml index 3c913e3f7e..d8af6e40ca 100644 --- a/.github/workflows/test_fx_automatic_parallel.yml +++ b/.github/workflows/test_fx_automatic_parallel.yml @@ -24,7 +24,7 @@ jobs: config: - name: GPU-enabled Optimum Test Suite image: nvidia/cuda:12.4.1-devel-ubuntu22.04 - gpu_target: ["nvidia-multi-gpu-l4-runners", "nvidia-multi-gpu-a10-runners"] + gpu_target: ["nvidia-multi-gpu-a10-runners"] name: ${{ matrix.config.name }} runs-on: diff --git a/optimum/fx/parallelization/api.py b/optimum/fx/parallelization/api.py index bd307bd93c..9700b491e5 100644 --- a/optimum/fx/parallelization/api.py +++ b/optimum/fx/parallelization/api.py @@ -15,10 +15,11 @@ import importlib import os from functools import partial -from typing import List, Union +from typing import Callable, List import torch from torch.fx import GraphModule +from transformers import AutoConfig from .core import Config, ParallelExecutionCtx from .passes import build_parallel_pass_pipeline @@ -43,30 +44,31 @@ def parallelize_backend( def parallelize_model( - model: Union[torch.nn.Module, str], + model: str, parallel_ctx: ParallelExecutionCtx, *model_args, **kwargs, -): +) -> Callable: """ API for automatic model parallelism through Pytorch FX. Args: - model (Union[torch.nn.Module, str]): - Model to parallelize, could either be a module or a model id on the Huggingface Hub. - parallel_ctx (ParallelExecutionCtx): + model (`str`): + Model to parallelize, a model id on the Huggingface Hub or path to a local directory containing config and weights + of the model. + parallel_ctx (`ParallelExecutionCtx`): Parallel execution context containing process groups the current process belongs to. - *model_args (Any): + *model_args (`Any`): Additional postional arguments for intializing the model if a model id is passed. - revision (str, defaults to `main`): + revision (`str`, defaults to `main`): Model revision for weights downloading if a model id is passed. - cache_dir (Optional[str], defaults to `None`): + cache_dir (`Optional[str]`, defaults to `None`): Cache directory to store downloaded weights. Defaults to None. - local_files_only (bool, defaults to `False`): + local_files_only (`bool`, defaults to `False`): Whether to use local files only, will avoid downloading from remote if set to `True`. 
- skip_load_weights (bool, defaults to `False`): + skip_load_weights (`bool`, defaults to `False`): Whether to skip loading weights from disk to model. - **kwargs (Dict[str, Any]): + **kwargs (`Dict[str, Any]`): Addtional keyword arguments for overriding fields in parallel config, model config and `Model.__init__`. """ revision = kwargs.pop("revision", "main") @@ -80,44 +82,41 @@ def parallelize_model( setattr(parallel_config, k, v) kwargs.pop(k) - if isinstance(model, str): - from transformers import AutoConfig - - is_local = os.path.isdir(model) - if not is_local: - hf_folder = download_model_from_hf( - model_name_or_path=model, - cache_dir=cache_dir, - revision=revision, - local_files_only=local_files_only, - skip_download_weights=skip_load_weights, - ) - else: - hf_folder = model - - # should be able to load config using only local files - model_config, kwargs = AutoConfig.from_pretrained( - hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs + is_local = os.path.isdir(model) + if not is_local: + hf_folder = download_model_from_hf( + model_name_or_path=model, + cache_dir=cache_dir, + revision=revision, + local_files_only=local_files_only, + skip_download_weights=skip_load_weights, ) + else: + hf_folder = model - # try getting model class info from config - model_arch = model_config.architectures - model_cls = getattr(importlib.import_module("transformers"), model_arch[0]) + # should be able to load config using only local files + model_config, kwargs = AutoConfig.from_pretrained( + hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs + ) - if not skip_load_weights: - parallel_ctx.weight_map = try_collect_weight_map(model, cache_dir, hf_folder) + # try getting model class info from config + model_arch = model_config.architectures + model_cls = getattr(importlib.import_module("transformers"), model_arch[0]) - torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None - if torch_dtype is not None: - dtype_orig = model_cls._set_default_torch_dtype(torch_dtype) + if not skip_load_weights: + parallel_ctx.weight_map = try_collect_weight_map(model, cache_dir, hf_folder) - with MetaAwareMethodsPatcher(): - model = model_cls(model_config, *model_args, **kwargs) - # TODO: remove this once support training-time trace - model.eval() + torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None + if torch_dtype is not None: + dtype_orig = model_cls._set_default_torch_dtype(torch_dtype) - if dtype_orig is not None: - torch.set_default_dtype(dtype_orig) + with MetaAwareMethodsPatcher(): + model = model_cls(model_config, *model_args, **kwargs) + # TODO: remove this once support training-time trace + model.eval() + + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) move_model_to_device(model, device=parallel_ctx.current_device) initialize_parameter_meta(model) diff --git a/optimum/fx/parallelization/core.py b/optimum/fx/parallelization/core.py index 1d13b00b46..84737292f0 100644 --- a/optimum/fx/parallelization/core.py +++ b/optimum/fx/parallelization/core.py @@ -166,8 +166,13 @@ class Config: - weight_init_fn (`Callable`, defaults to `partial(nn.init.normal_, std=0.02)`) Initialization function of weights in `nn.Linear` and `nn.Embedding` layers, if not provided weights loading path. + + - enable_sequence_parallel (`bool`, defaults to `False`): + Whether to enable Megatron-style sequence parallelism in searching parallelization + strategies. 
""" lint_and_recompile: bool = True clean_markers_after_all_passes: bool = True weight_init_fn: Callable = partial(nn.init.normal_, std=0.02) + enable_sequence_parallel: bool = False diff --git a/optimum/fx/parallelization/decomp.py b/optimum/fx/parallelization/decomp.py new file mode 100644 index 0000000000..26258d451b --- /dev/null +++ b/optimum/fx/parallelization/decomp.py @@ -0,0 +1,225 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +from typing import Callable, Dict, List + +import torch +import torch.nn.functional as F +import torch.utils._pytree as pytree +from torch import SymBool, SymFloat, SymInt +from torch._decomp import core_aten_decompositions +from torch._functorch._aot_autograd.functional_utils import from_fun, to_fun +from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode, disable_functional_mode +from torch.fx import Graph, GraphModule, Interpreter, Proxy, traceback +from torch.fx.experimental.proxy_tensor import ( + ProxyTorchDispatchMode, + _ProxyTensor, + _SymNodeDict, + decompose, + disable_proxy_modes_tracing, + fetch_object_proxy, + fetch_sym_proxy, + get_proxy_slot, + track_tensor_tree, +) +from torch.fx.proxy import GraphAppendingTracer +from torch.utils.weak import WeakTensorKeyDictionary + + +def is_leaf_module(m): + return (m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn")) and not isinstance( + m, torch.nn.Sequential + ) + + +@contextlib.contextmanager +def trace_decomp_origin(): + creat_node = Graph.create_node + + def create_node_(*args, **kwargs): + node = creat_node(*args, **kwargs) + node.meta["traced_from"] = traceback.get_current_meta()["from_node"] + return node + + try: + Graph.create_node = create_node_ + yield + finally: + Graph.create_node = creat_node + + +class DecompTracer(GraphAppendingTracer): + """ + DecompTracer is a tracer class which works together with `DecompositionInterpreter`, it keeps track of tensors and their + corresponding proxy objects during execution process. When invoked with `create_proxy`, it creates a node in the containing + graph and associate the output tensor of the node with the created proxy. + + See https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/proxy_tensor.py for more details. + """ + + def __init__(self, graph: Graph): + super().__init__(graph) + self.tensor_tracker = WeakTensorKeyDictionary() + self.symnode_tracker = _SymNodeDict() + + +class DecompositionInterpreter(Interpreter): + """ + DecompositionInterpreter takes the high-level graph module, run the iternal nodes following the topo order, and decompose + high-level pytorch operators into core aten operators by utilizing torch dispatch infrastructure along the way. 
+ + Notes: + - Certain primitive layers(like `nn.Linear`, `nn.Embedding`, and activation layers) are preserved because we have specific + heuristic based parallelization strategy for them so that we can conveniently replace them into their parallelized counterparts + in the orignal graph module. + + - The traced graph is a low-level equivalent representation of the original graph module, and is only used for + parallel axis propagation and analysis, the original graph module is still used for real execution. + """ + + def __init__( + self, module: GraphModule, new_graph: Graph, decomposition_table=None, leaf_function_targets=None, **kwargs + ): + super().__init__(module, **kwargs) + self.new_graph = new_graph + self.tracer = DecompTracer(new_graph) + + self.decomposition_table = decomposition_table + if self.decomposition_table is None: + self.decomposition_table = {} + + self.leaf_function_targets = leaf_function_targets + if self.leaf_function_targets is None: + self.leaf_function_targets = [] + + self.fun_mode = FunctionalTensorMode() + self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real") + + def placeholder(self, target, args, kwargs): + out = super().placeholder(target, args, kwargs) + out = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), out) + proxy = self.tracer.create_proxy("placeholder", target, args, kwargs) + + with disable_proxy_modes_tracing(): + track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) + + out = pytree.tree_map_only(torch.Tensor, lambda x: to_fun(x), out) + return out + + def call_function(self, target, args, kwargs): + if target in self.leaf_function_targets: + args = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), args) + kwargs = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), kwargs) + + with disable_proxy_modes_tracing(), disable_functional_mode(): + out = target(*args, **kwargs) + + args, kwargs = pytree.tree_map_only((torch.Tensor,), fetch_object_proxy(self.tracer), (args, kwargs)) + proxy_args, proxy_kwargs = pytree.tree_map_only( + (SymInt, SymFloat, SymBool), + fetch_sym_proxy(self.tracer), + pytree.tree_map_only(_ProxyTensor, lambda e: e.proxy, (args, kwargs)), + ) + proxy = self.tracer.create_proxy("call_function", target, proxy_args, proxy_kwargs) + + with disable_proxy_modes_tracing(): + track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) + + out = pytree.tree_map_only(torch.Tensor, lambda x: to_fun(x), out) + return out + + return super().call_function(target, args, kwargs) + + def call_module(self, target, args, kwargs): + assert isinstance(target, str) + submod = self.fetch_attr(target) + if not is_leaf_module(submod): + return super().call_module(target, args, kwargs) + + args = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), args) + kwargs = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), kwargs) + + with disable_proxy_modes_tracing(), disable_functional_mode(): + out = submod(*args, **kwargs) + + args, kwargs = pytree.tree_map_only((torch.Tensor,), fetch_object_proxy(self.tracer), (args, kwargs)) + proxy_args, proxy_kwargs = pytree.tree_map_only( + (SymInt, SymFloat, SymBool), + fetch_sym_proxy(self.tracer), + pytree.tree_map_only(_ProxyTensor, lambda e: e.proxy, (args, kwargs)), + ) + proxy = self.tracer.create_proxy("call_module", target, proxy_args, proxy_kwargs) + + with disable_proxy_modes_tracing(): + track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) + + out = pytree.tree_map_only(torch.Tensor, lambda x: 
to_fun(x), out) + return out + + def get_attr(self, target, args, kwargs): + out = super().get_attr(target, args, kwargs) + proxy = Proxy(self.new_graph.get_attr(target), self.tracer) + with disable_proxy_modes_tracing(): + track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) + return out + + def output(self, target, args, kwargs): + args = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), args) + kwargs = pytree.tree_map_only(FunctionalTensor, lambda x: from_fun(x), kwargs) + out = super().output(target, args, kwargs) + + def unwrap(e): + return get_proxy_slot(e, self.tracer, e, lambda x: x.proxy.node) + + self.new_graph.output(pytree.tree_map(unwrap, out)) + return out + + def run(self, *args, **kwargs): + with self.fun_mode: + args = pytree.tree_map_only(torch.Tensor, lambda x: to_fun(x), args) + kwargs = pytree.tree_map_only(torch.Tensor, lambda x: to_fun(x), kwargs) + with traceback.preserve_node_meta(), trace_decomp_origin(), decompose(self.decomposition_table), self.mode: + return super().run(*args, **kwargs) + + +def decompose_and_functionalize( + graph_module: GraphModule, + decomposition_table: Dict[torch._ops.OperatorBase, Callable] = core_aten_decompositions(), + leaf_function_targets: List[Callable] = [F.scaled_dot_product_attention], +) -> Callable: + """ + API to decompose and functionalize a high-level graph module. + + Args: + graph_module (`GraphModule`): + The high-level graph module to be decomposed and functionalized. + decomposition_table (`Dict[torch._ops.OperatorBase, Callable]`, defaults to `core_aten_decompostions()`): + The lookup table which maps high-level torch op to their equivalent low-level implementation. + leaf_function_targets (`List[Callable]`, defaults to `[F.scaled_dot_product_attention]`): + Functions which will not be traced through for convenience, `F.scaled_dot_product_attention` is + treated as a leaf function by default so that we don't have to deal with all detailed version of + sdpas in the traced graph. + + Returns: + Callable: a wrapper which returns the traced low-level graph when called with concrete arguments. + """ + new_graph = Graph(owning_module=graph_module) + interp = DecompositionInterpreter(graph_module, new_graph, decomposition_table, leaf_function_targets) + + def wrapper(*args, **kwargs): + interp.run(*args, **kwargs) + return new_graph + + return wrapper diff --git a/optimum/fx/parallelization/op_registry/__init__.py b/optimum/fx/parallelization/op_registry/__init__.py new file mode 100644 index 0000000000..8f8df0f7bd --- /dev/null +++ b/optimum/fx/parallelization/op_registry/__init__.py @@ -0,0 +1,15 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from .op_handlers import REGISTRY, FallbackParallelAxisPropagateHandler diff --git a/optimum/fx/parallelization/op_registry/op_handlers.py b/optimum/fx/parallelization/op_registry/op_handlers.py new file mode 100644 index 0000000000..56b8fc16bc --- /dev/null +++ b/optimum/fx/parallelization/op_registry/op_handlers.py @@ -0,0 +1,450 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import abstractmethod +from typing import Any, List, Optional + +import torch +from torch.fx import Node + +from ..core import Config +from ..utils import is_activation, is_embedding, is_linear + + +class Registry: + """ + Registry class handles registration of parallel axis propagation handlers of different aten ops. + To support a new aten op, you need to register the corresponding handler class by decorating it with `register` function. + """ + + def __init__(self) -> None: + self.mapping = {} + + def register(self, op_types): + def wrapper(cls): + if isinstance(op_types, (list, tuple)): + for op_type in op_types: + self.mapping[op_type] = cls + else: + self.mapping[op_types] = cls + return cls + + return wrapper + + def is_supported(self, op_type) -> bool: + return op_type in self.mapping + + +REGISTRY = Registry() + + +class OpParallelAxisPropagateHandler: + def __init__(self, node: Node, meta_key: str, config: Config) -> None: + self.node = node + self.meta_key = meta_key + self.config = config + + def extract_axis(self, arg: Any) -> Optional[int]: + if not isinstance(arg, Node): + return None + return arg.meta[self.meta_key].get("parallel_axis", None) + + @abstractmethod + def propagate(self) -> List[int]: + raise NotImplementedError + + +@REGISTRY.register( + [ + torch.ops.aten.pow.Tensor_Scalar, + torch.ops.aten.rsqrt.default, + torch.ops.aten.clone.default, + torch.ops.aten.bitwise_not.default, + torch.ops.aten.abs.default, + torch.ops.aten._to_copy.default, + torch.ops.aten.acos.default, + torch.ops.aten.acosh.default, + torch.ops.aten.alias.default, + torch.ops.aten.asin.default, + torch.ops.aten.asinh.default, + torch.ops.aten.atan.default, + torch.ops.aten.atanh.default, + torch.ops.aten.ceil.default, + torch.ops.aten.clamp.default, + torch.ops.aten.cos.default, + torch.ops.aten.cosh.default, + torch.ops.aten.erf.default, + torch.ops.aten.exp.default, + torch.ops.aten.trunc.default, + torch.ops.aten.tanh.default, + torch.ops.aten.tan.default, + torch.ops.aten.add.Scalar, + torch.ops.aten.sub.Scalar, + torch.ops.aten.sqrt.default, + torch.ops.aten.sin.default, + torch.ops.aten.sinh.default, + torch.ops.aten.sign.default, + torch.ops.aten.sigmoid.default, + torch.ops.aten.round.default, + torch.ops.aten.remainder.Scalar, + torch.ops.aten.relu.default, + torch.ops.aten.reciprocal.default, + torch.ops.aten.neg.default, + torch.ops.aten.ne.Scalar, + torch.ops.aten.native_dropout.default, + torch.ops.aten.mul.Scalar, + torch.ops.aten.logical_not.default, + torch.ops.aten.lt.Scalar, + torch.ops.aten.le.Scalar, + 
torch.ops.aten.log.default, + torch.ops.aten.log10.default, + torch.ops.aten.log2.default, + torch.ops.aten.log1p.default, + torch.ops.aten.leaky_relu.default, + torch.ops.aten.isnan.default, + torch.ops.aten.isinf.default, + torch.ops.aten.hardtanh.default, + torch.ops.aten.gt.Scalar, + torch.ops.aten.gelu.default, + torch.ops.aten.ge.Scalar, + torch.ops.aten.fmod.Scalar, + torch.ops.aten.floor.default, + torch.ops.aten.fill.Scalar, + torch.ops.aten.div.Scalar_mode, + torch.ops.aten.div.Scalar, + torch.ops.aten.bitwise_and.Scalar, + torch.ops.aten.bitwise_or.Scalar, + torch.ops.aten.bitwise_xor.Scalar, + ] +) +class UnaryOpParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + arg = self.node.all_input_nodes[0] + axis = self.extract_axis(arg) + return [axis] + + +@REGISTRY.register( + [ + torch.ops.aten.atan2.default, + torch.ops.aten.add.Tensor, + torch.ops.aten.bitwise_and.Tensor, + torch.ops.aten.bitwise_or.Tensor, + torch.ops.aten.bitwise_xor.Tensor, + torch.ops.aten.div.Tensor, + torch.ops.aten.div.Tensor_mode, + torch.ops.aten.eq.Tensor, + torch.ops.aten.fmod.Tensor, + torch.ops.aten.ge.Tensor, + torch.ops.aten.gt.Tensor, + torch.ops.aten.le.Tensor, + torch.ops.aten.logical_and.default, + torch.ops.aten.logical_or.default, + torch.ops.aten.logical_xor.default, + torch.ops.aten.lt.Tensor, + torch.ops.aten.maximum.default, + torch.ops.aten.minimum.default, + torch.ops.aten.mul.Tensor, + torch.ops.aten.ne.Tensor, + torch.ops.aten.pow.Tensor_Tensor, + torch.ops.aten.remainder.Tensor, + torch.ops.aten.sub.Tensor, + ] +) +class BinaryOpParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + input_nodes = self.node.all_input_nodes + # only one node + if len(input_nodes) == 1: + return UnaryOpParallelAxisPropagateHandler(self.node, self.meta_key, self.config).propagate() + + assert len(input_nodes) == 2, "binary op should have exact two nodes as inputs" + lhs_shape, rhs_shape = input_nodes[0].meta["val"].shape, input_nodes[1].meta["val"].shape + lhs_axis = self.extract_axis(input_nodes[0]) + rhs_axis = self.extract_axis(input_nodes[1]) + i, j = len(lhs_shape) - 1, len(rhs_shape) - 1 + while i >= 0 and j >= 0: + k = max(lhs_shape[i], rhs_shape[j]) + assert ( + k % min(lhs_shape[i], rhs_shape[j]) == 0 + ), f"shape {lhs_shape} and {rhs_shape} are not broadcastable!" 
+ i -= 1 + j -= 1 + + if i < 0 and lhs_axis is not None: + lhs_axis += j + 1 + if j < 0 and rhs_axis is not None: + rhs_axis += i + 1 + + if lhs_axis is None: + return [rhs_axis] + elif rhs_axis is None: + return [lhs_axis] + elif lhs_axis != rhs_axis: + return [] + return [lhs_axis] + + +@REGISTRY.register( + [ + torch.ops.aten.amax.default, + torch.ops.aten.amin.default, + torch.ops.aten.any.dim, + torch.ops.aten._log_softmax.default, + torch.ops.aten._softmax.default, + torch.ops.aten.cumsum.default, + torch.ops.aten.mean.dim, + # torch.ops.aten.min.dim, + # torch.ops.aten.max.dim, + torch.ops.aten.var.dim, + torch.ops.aten.sum.dim_IntList, + torch.ops.aten.prod.dim_int, + ] +) +class ReductionOpParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def extract_dims( + self, + ) -> List[int]: + ndim = self.node.meta["val"].ndim + dims = None + if "dim" in self.node.kwargs: + dims = self.node.kwargs["dim"] + elif len(self.node.args) > 1 and isinstance(self.node.args[1], (int, list)): + dims = self.node.args[1] + + if isinstance(dims, int): + dims = [dims] + if not dims: + dims = list(range(ndim)) + dims = [(dim + ndim) % ndim for dim in dims] + + keepdim = False + if "keepdim" in self.node.kwargs: + keepdim = self.node.kwargs + elif len(self.node.args) > 2 and isinstance(self.node.args[2], bool): + keepdim = self.node.args[2] + + return dims, keepdim + + def propagate(self) -> List[int]: + dims, keepdim = self.extract_dims() + arg = self.node.all_input_nodes[0] + axis = self.extract_axis(arg) + if axis in dims: + return [] + if axis is None: + return [None] + if keepdim: + return [axis] + return [axis - sum([1 if dim < axis else 0 for dim in dims])] + + +@REGISTRY.register(torch.ops.aten.view.default) +class ViewLikeOpParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + arg = self.node.args[0] + axis = self.extract_axis(arg) + if axis is None: + return [None] + shape_before, shape_after = arg.meta["val"].shape, self.node.meta["val"].shape + size = 1 + for i in range(len(shape_before) - 1, axis - 1, -1): + size *= shape_before[i] + + cur, i, res = 1, len(shape_after) - 1, [] + while cur <= size and i >= 0: + cur *= shape_after[i] + if cur == size: + res.append(i) + i -= 1 + + return res + + +@REGISTRY.register(torch.ops.aten.unsqueeze.default) +class UnsqueezeParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + arg, dim = self.node.args[0], self.node.args[1] + ndim = arg.meta["val"].ndim + axis = self.extract_axis(arg) + if axis is None: + return [None] + dim = (dim + ndim) % ndim + if dim <= axis: + return [axis + 1] + return [axis] + + +@REGISTRY.register( + [ + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze.dims, + ] +) +class SqueezeParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + arg, dims = self.node.args[0], self.node.args[1] + axis = self.extract_axis(arg) + if axis is None: + return [None] + + ndim = self.node.args[0].meta["val"].ndim + if isinstance(dims, int): + dims = [dims] + dims = [(dim + ndim) % ndim for dim in dims] + if axis in dims: + # being conservative + return [] + return [axis - sum([1 if dim < axis else 0 for dim in dims])] + + +@REGISTRY.register(torch.ops.aten.permute.default) +class PermuteParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + arg, dims = self.node.args[0], self.node.args[1] + ndim = arg.meta["val"].ndim + axis = self.extract_axis(arg) + if 
axis is None: + return [None] + + for i, dim in enumerate(dims): + if (dim + ndim) % ndim == axis: + return [i] + return [] + + +@REGISTRY.register(torch.ops.aten.slice.Tensor) +class SliceParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + arg, slice_dim = self.node.args[0], self.node.args[1] + axis = self.extract_axis(arg) + if axis is None: + return [None] + ndim = arg.meta["val"].ndim + slice_dim = (slice_dim + ndim) % ndim + if slice_dim == axis: + # slice on the parallel axis is not allowed + return [] + return [axis] + + +@REGISTRY.register(torch.ops.aten.expand.default) +class ExpandParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + arg, size = self.node.args[0], self.node.args[1] + axis = self.extract_axis(arg) + if axis is None: + return [None] + assert len(size) >= arg.meta["val"].ndim, "input size must be broadcastable to the target size in expand" + return [axis + len(size) - arg.meta["val"].ndim] + + +@REGISTRY.register(torch.ops.aten.cat.default) +class CatParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + nodes, cat_axis = self.node.all_input_nodes, self.node.args[1] + axis, ndim = self.extract_axis(nodes[0]), nodes[0].meta["val"].ndim + cat_axis = (cat_axis + ndim) % ndim + if cat_axis == axis: + return [] + for i in range(1, len(nodes)): + if self.extract_axis(nodes[i]) != axis: + return [] + return [axis] + + +@REGISTRY.register(torch.ops.aten.constant_pad_nd.default) +class PadParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + pad, ndim = self.node.args[1], self.node.args[0].meta["val"].ndim + axis = self.extract_axis(self.node.args[0]) + if axis is None: + return [None] + if axis >= ndim - pad // 2: + return [] + return [axis] + + +@REGISTRY.register(torch.ops.aten.copy.default) +class CopyParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + dst, src = self.node.all_input_nodes + axis_dst = self.extract_axis(dst) + axis_src = self.extract_axis(src) + if axis_dst != axis_src: + return [] + return [axis_dst] + + +@REGISTRY.register(torch.nn.functional.scaled_dot_product_attention) +class SpdaAttnParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + q, k, v = self.node.args[:3] + q_axis = self.extract_axis(q) + # parallel axis must be the head dimension if being parallelized + if q_axis != self.extract_axis(k) or q_axis != self.extract_axis(v) or q_axis not in {None, 1}: + return [] + return [q_axis] + + +class FallbackParallelAxisPropagateHandler(OpParallelAxisPropagateHandler): + def propagate(self) -> List[int]: + # by default we don't parallelize inputs and constants(except parameters embeded in modules) + if self.node.op in ["placeholder", "get_attr"]: + return [None] + elif self.node.op == "output": + for node in self.node.all_input_nodes: + # TODO: allow parallelized nodes in output, and append comm ops in graph tp all-gather + # parallelized output if intructed + if self.extract_axis(node) is not None: + return [] + return [None] + elif is_linear(self.node): + input_arg = self.node.all_input_nodes[0] + axis = self.extract_axis(input_arg) + if axis is None: + # with input being not parallelized, output can be parallelized on the head dimension, + # i.e., `ColumnLinear`, or not being parallelized by all-gather at the end + return [2, None] + elif self.config.enable_sequence_parallel 
and axis == 1: + # with input being parallelized on sequence dimension, output can be parallelized on + # the head dimension, i.e., `ColumnLinear` with sequence parallel, or not being parallelized + # by all-gather at the end + return [2, None] + elif axis == 2: + # with input being parallelized on head dimension, output can be parallelized on the + # sequence dimension or not parallelized by all-reduce at the end, i.e., `RowLinear` + # when sp is not enabled + return [1, None] if self.config.enable_sequence_parallel else [None] + else: + return [] + elif is_embedding(self.node): + input_arg = self.node.all_input_nodes[0] + axis = self.extract_axis(input_arg) + if axis is None: + # only support the embedding parameter being parallelized on `vocab` dim or not parallelized for now, + # the output can be parallelized on sequence dim or not parallelized + return [1, None] if self.config.enable_sequence_parallel else [None] + else: + return [] + elif is_activation(self.node): + return UnaryOpParallelAxisPropagateHandler(self.node, self.meta_key, self.config).propagate() + + # last resort, if no input is being parallelized, then we make output also not parallelized, + # this will give us relief on writing policies for strange ops which don't actually need + # parallelization in most cases + if all(self.extract_axis(arg) is None for arg in self.node.all_input_nodes): + return [None] + + raise NotImplementedError(f"don't know how to propagate axis for {self.node.target}") diff --git a/optimum/fx/parallelization/passes.py b/optimum/fx/parallelization/passes.py index 379b027d40..14b652fff7 100644 --- a/optimum/fx/parallelization/passes.py +++ b/optimum/fx/parallelization/passes.py @@ -23,15 +23,14 @@ from torch.fx import Graph, GraphModule, Node from .core import Config, ParallelExecutionCtx, ParameterMeta +from .decomp import decompose_and_functionalize from .distributed import scatter +from .op_registry import REGISTRY, FallbackParallelAxisPropagateHandler from .parallel_layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding from .utils import ( is_embedding, is_linear, - is_permute, is_shape_consumer, - is_shape_generator, - is_transpose, stable_topological_sort, ) @@ -135,238 +134,151 @@ def clean_all(self, graph_module: GraphModule) -> None: self.clear_marker_per_node(node) -class ParallelLayerAnnotatePass(AnalyzeBase): +class ParallelAxisSolverPass(AnalyzeBase): """ - A pass which tries to automatically identify parallel layers in the graph. Note that for simplicity - we only consider classical ways of parallelizing layers in transformers architecture for now, we are not - solving an optimization problem which tries to give a best solution of parallelizing any model under - memory/hardware constraints. - - For `nn.Embedding` layers, we parallelize them on the vocabulary dim by default, because they are often tied - to the `lm_head` of the model, which is usually a `ColumnLinear`(parallelized on vocab dim). - - For `nn.Linear` layers, we parallelize them by grouping them as `upstream` nodes and `downstream` nodes, and - `upstream` nodes are marked as `ColumnLinear`, `downstream` nodes are marked as `RowLinear`. 
- - Typical examples in transformer models: - - Attention Bert-style MLP Llama-style MLP - __________________________________________________________________________ - Linear Linear Linear Linear - \\ / | \\ --> upstream - Matmul Linear Activation Activation Linear - __________________________________________________________________________ - \\ / | \\ / - \\ / ___________ \\ / - Matmul / Linear \ Mul - | / \ | - _______________________________/ \___________________________ - Linear Linear --> downstream - - Note that there are some patterns that can not be clearly marked, like this one: - - Linear - | \\ - | Linear <-- which label should we mark for the intermediate linear, `upstream` or `downstream` - | / - Add - | - Linear - - For patterns like this we will be conservative and raise errors directly because we don't know how to parallelize - it. Another concern is about the correctness, it's possible that we might end up with a wrong parallelization solution - even if the pattern itself is clear, but for now we are mainly targeting on transformer models and the current solution - should work fairly well. + A pass which tries to automatically identify parallel layers in the graph. There are three steps + involved to find a possible parallel solution given the traced graph module and process group. + + - Decompostion & Functionalization + The vanilla graph traced by dynamo frontend is a high-level graph which contains high-level + pytorch ops, and there could be thousands of them, which makes graph analysis hard in order + to cover all cases. So we decompose the high-level graph into low-level graph which only + conrtains core aten ops, which is a much smaller set. And functionalization is also needed + to remove inplace ops in the graph so that we get `aten.Add` instead of `aten.Add_` in the + graph, which furthur reduces the op set we need to consider. + + - Parallel Axis Propagation + We need to write parallel axis propagation rules for aten ops in the decomposed and functionalized + graph, note that we don't need to cover every possible parallelization strategy because in general + only certain ops(usually involves computation) can be parallelized in transformer models. And we just + need to write rules for a subset of core aten op set in order to support most of the transformer models. + + - Backtracking Search + After we have defined parallel axis propagation rules for each op in the graph, we do a brute force + backtracking search to try to find a possible solution which respects the propagation rule of every + op in the graph. + + + Note that there are several practical concerns + + - Time Complexity. Although brute force backtracking introduces an exponential time complexity, we reduces + the search space by injecting human heuristics. First, we only consider parallelization on the head dimension + (for tensor parallel) or the sequence dimension(to support sequence parallel), then at any time the tensor is + parallelized on at most one dimension. Second, we only allow axis switch around certain layers(like `nn.Linear` + or `nn.Embedding), and all other ops fall into their places by the parallel axis of their input and rules we write. + + - Optimal Solution. Note that since we return the first solution we find, then it might not be optimal in terms of + memory consumption and communication overhead. But again we can adjust the order of search and try parallelize + as much as we can first before fall back to non-parallelized search paths. 
And we don't pay too much attention + on calculating communication overhead because in practice they are bounded under the constraint that only certain + layers are allowed to communicate. + + Our goal is not to solve an optimization problem which tries to give a best solution of parallelizing any model under memory/hardware + constraints, but rather a cheap solution which relieves you from writing boilerplate code for parallelizing layers of different models. """ - def try_form_parallel_linear_groups(self, linear: Node) -> None: - """ - We try to form linears by forming closures in a greedy way, we start with an unmarked linear node, and traverses down - recusively to find all the potential `downstream` linears, note that once we have reached a linear, the recursion stops. - And the newly found `downstream` linears are used as new seeds to traverse upwards to find all the potential `upstream` - linears, the process goes on until number of linears on both sides converges. - Args: - linear (Node): the first linear node used as `upstream` node seed to form closure. - - Raises: - RuntimeError: - raises runtime error when the pattern itself is not clear, there are no clear boundaries that can be drawn. - """ - upstream_nodes, downstream_nodes = {linear}, set() - - seeds, next_seeds = [(linear, "down")], [] - - def traverse(start: Node, cur: Node, direction: str = "down"): - if is_linear(cur) and cur is not start: - if direction == "up" and cur not in upstream_nodes: - upstream_nodes.add(cur) - next_seeds.append((cur, "down")) - elif direction == "down" and cur not in downstream_nodes: - downstream_nodes.add(cur) - next_seeds.append((cur, "up")) - return - - next_nodes = cur.all_input_nodes if direction == "up" else cur.users - for node in next_nodes: - # we should ignore shape-related dependencies - if is_shape_generator(node): - continue - traverse(start, node, direction) - - while seeds: - next_seeds = [] - for node, direction in seeds: - traverse(start=node, cur=node, direction=direction) - seeds = next_seeds - - if any(self.already_executed_per_node(node) for node in (upstream_nodes | downstream_nodes)) or ( - upstream_nodes & downstream_nodes - ): - raise RuntimeError( - "Failed to automatically group and parallelize ops in graph in greedy way: " - "no clear boudaries between `upstream` and `downstream` ops." 
- ) - - for node in upstream_nodes: - self.place_marker_per_node(node, {"axis": "column", "gather_output": False if downstream_nodes else True}) - - for node in downstream_nodes: - self.place_marker_per_node(node, {"axis": "row", "input_is_parallel": True}) + def trace_back(self, graph_module: GraphModule, decomp_graph: Graph) -> None: + node_map = {node.name: node for node in graph_module.graph.nodes} + + for node in decomp_graph.nodes: + if "traced_from" in node.meta: + node_name, _ = node.meta["traced_from"][0] + assert node_name in node_map, f"un-recognized node origin {node_name} not in graph being traced" + orig_node = node_map[node_name] + self.clear_marker_per_node(orig_node) + self.place_marker_per_node( + orig_node, {"parallel_axis": self.get_stored_field_info(node, field="parallel_axis")} + ) def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: - graph: Graph = graph_module.graph + graph: Graph = decompose_and_functionalize(graph_module)(*ctx.example_inputs) stable_topological_sort(graph) - for node in graph.nodes: - if is_linear(node) and not self.already_executed_per_node(node): - self.try_form_parallel_linear_groups(node) - elif is_embedding(node): - # directly mark `nn.Embedding` layers - self.place_marker_per_node(node, {"axis": "vocab"}) - return graph_module + nodes = list(graph.nodes) + def search(idx: int): + if idx == len(nodes): + return True -class ParallelAxisPropagationPass(AnalyzeBase): - """ - A pass which tries to track which axis is being parallelized in the dataflow. For transformer models, the - axis being paralled for tensor parallism is almost always 2, i.e., the attention head axis, except for - Q and K matrices which need to swap the sequence length axis and head axis to do the attention computation, - so we focus on operations like `transpose` or `permute` which swaps axis, and try inducting the parallel - axis after these operations. 
- """ + node = nodes[idx] + if node.op == "call_function" and REGISTRY.is_supported(node.target): + prop_cls = REGISTRY.mapping[node.target] + else: + prop_cls = FallbackParallelAxisPropagateHandler - def propagate_transpose(self, node: Node, parallel_axis: int) -> bool: - dims = node.meta["example_value"].dim() - if "dim0" in node.kwargs and "dim1" in node.kwargs: - dim0, dim1 = node.kwargs["dim0"], node.kwargs["dim1"] - elif len(node.args) == 3: - dim0, dim1 = node.args[1:] - - dim0 = (dim0 + dims) % dims - dim1 = (dim1 + dims) % dims - - if dim0 == parallel_axis: - self.place_marker_per_node(node, {"parallel_axis": dim1}) - return True - elif dim1 == parallel_axis: - self.place_marker_per_node(node, {"parallel_axis": dim0}) - return True - return False - - def propagate_permute(self, node: Node, parallel_axis: int) -> bool: - if "dims" in node.kwargs: - dims = node.kwargs["dims"] - else: - dims = ( - list(node.args[1]) - if isinstance(node.args[1], tuple) - else [arg for arg in node.args if isinstance(arg, int)] - ) + prop = prop_cls(node, self.meta_key(), config) + axis_candidates = prop.propagate() + for axis in axis_candidates: + self.place_marker_per_node(node, {"parallel_axis": axis}) + if search(idx + 1): + return True + self.clear_marker_per_node(node) - dim_len = node.meta["example_value"].dim() - dims = [dim + dim_len if dim < 0 else dim for dim in dims] + return False - for i, dim in enumerate(dims): - if dim == parallel_axis: - self.place_marker_per_node(node, {"parallel_axis": i}) - return True - return False - - def propagate_getitem(self, node: Node, parallel_axis: int) -> bool: - slices = node.args[1] - dims = node.meta["example_value"].dim() - assert parallel_axis < dims - inc, i, j = 0, 0, 0 - - while i < parallel_axis and j < len(slices): - if isinstance(slices[j], int): - inc -= 1 - i += 1 - elif slices[j] is None: - inc += 1 - elif slices[j] is Ellipsis: - i = dims - k = j - while k < len(slices): - if slices[k] is not Ellipsis: - i -= 1 - k += 1 - else: - i += 1 - j += 1 + if not search(0): + raise RuntimeError("Failed to find a solution to automatically parallelize ops in graph in greedy way.") - if inc != 0: - assert parallel_axis + inc < dims and parallel_axis + inc >= 0 - self.place_marker_per_node(node, {"parallel_axis": parallel_axis + inc}) - return True - return False + self.trace_back(graph_module, graph) + return graph_module + + +class ParallelLayerAnnotatePass(AnalyzeBase): + """ + This pass annotates layers which have different parallel axis(requires communication inside the layer) in their + input and output tensors. Since heuristics applied during the searching process respect traditional classical ways of + parallelizing layers(like Megatron-style `ColumnLinear` or `RowLinear`), we are guaranteed to match a valid replacement + annotation according to parallelization strategy of input and output tensors. 
+ """ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Config) -> GraphModule: - g: Graph = graph_module.graph - stable_topological_sort(g) + for node in graph_module.graph.nodes: + if is_linear(node): + axis_before = ParallelAxisSolverPass.get_stored_field_info(node.args[0], "parallel_axis") + axis_after = ParallelAxisSolverPass.get_stored_field_info(node, "parallel_axis") + info = {} + if axis_before is None: + info["axis"] = "column" + info["gather_output"] = True if axis_after is None else False + elif axis_before == 1: + assert ( + config.enable_sequence_parallel + ), "illegal parallel axis for sequence parallelism deactivated setting" + info["axis"] = "column" + info["sequence_parallel"] = True + info["gather_output"] = True if axis_after is None else False + elif axis_before == 2: + info["axis"] = "row" + info["input_is_parallel"] = True + if axis_after == 1: + assert ( + config.enable_sequence_parallel + ), "illegal parallel axis for sequence parallelism deactivated setting" + info["sequence_parallel"] = True + else: + info["sequence_parallel"] = False + self.place_marker_per_node(node, info) - for node in g.nodes: - if ParallelLayerAnnotatePass.already_executed_per_node(node): - # start propagating at ColumnLinear, marking the beginning of parallelized region - axis = ParallelLayerAnnotatePass.get_stored_field_info(node, field="axis", must_have=True) - gather_output = ParallelLayerAnnotatePass.get_stored_field_info(node, field="gather_output") - if axis == "column" and not gather_output: - self.place_marker_per_node(node, {"parallel_axis": 2}) - # stop propagating at RowLinear, concluding the ending of parallelized region + elif is_embedding(node): + axis_before = ParallelAxisSolverPass.get_stored_field_info(node.args[0], "parallel_axis") + axis_after = ParallelAxisSolverPass.get_stored_field_info(node, "parallel_axis") + assert axis_before is None and axis_after in [1, None] + info = {"axis": "vocab"} + if axis_after == 1: + assert ( + config.enable_sequence_parallel + ), "illegal parallel axis for sequence parallelism deactivated setting" + info["sequence_parallel"] = True else: - continue - else: - already_marked_args, parallel_axis = [], None - for arg in node.all_input_nodes: - if not self.already_executed_per_node(arg): - continue - if parallel_axis is None: - parallel_axis = self.get_stored_field_info(arg, field="parallel_axis", must_have=True) - else: - assert parallel_axis == self.get_stored_field_info( - arg, field="parallel_axis", must_have=True - ), "`parallel_axis` should be equal for all arguments in any related ops" - already_marked_args.append(arg) - - if not already_marked_args: - continue - - marked = False - if is_transpose(node): - marked = self.propagate_transpose(node, parallel_axis) - elif is_permute(node): - marked = self.propagate_permute(node, parallel_axis) - - # fall back - if not marked: - self.place_marker_per_node(node, {"parallel_axis": parallel_axis}) + info["sequence_parallel"] = False + self.place_marker_per_node(node, info) + return graph_module class ParallelLayerReplacePass(PassBase): """ - A pass which modifies graph according to information provided by previous analytical passes, - in general it does two things for now: + A pass which modifies graph according to information provided by previous analytical passes, in general it does two things for now: 1. replaces linears and embedding layers with their parallel counterparts. 2. 
modifies hard-coded arguments like the number of attention heads in the graph by dividing it by parallelism level. """ @@ -453,7 +365,7 @@ def update(node: Node, new_shape: List[Any], parallel_axis: int): else: node.update_arg(parallel_axis + 1, shape[parallel_axis]) - parallel_axis = ParallelAxisPropagationPass.get_stored_field_info(node, field="parallel_axis") + parallel_axis = ParallelAxisSolverPass.get_stored_field_info(node, field="parallel_axis") if parallel_axis is None: return @@ -582,18 +494,18 @@ def run(self, graph_module: GraphModule, ctx: ParallelExecutionCtx, config: Conf def build_parallel_pass_pipeline() -> PassPipeline: """ Ensemble a pass pipeline which contains the following passes: - 1. `ParallelLayerAnnotatePass` to annoate which linears are `ColumnLinear`, which are `RowLinear` - 2. `ParallelAxisPropagationPass` to propate parallel axis along the data flow - 3. `ParallelLinearReplacePass` to do the actual replacement and modification of hard-coded attributes - 4. `InitializeOrLoadWeightsPass` to load or initialize weights for parameters + 1. `ParallelAxisSolverPass` to find a parallelization solution of tensors in the graph. + 2. `ParallelLayerAnnotatePass` to annotate parallelized layers according to the solution found in the first step. + 3. `ParallelLinearReplacePass` to do the actual replacement and modification of hard-coded attributes. + 4. `InitializeOrLoadWeightsPass` to load or initialize weights for parameters. Returns: PassPipeline: the pipeline used for automatic parallelism. """ return PassPipeline( [ + ParallelAxisSolverPass(), ParallelLayerAnnotatePass(), - ParallelAxisPropagationPass(), ParallelLayerReplacePass(), InitializeOrLoadWeightsPass(), ] diff --git a/optimum/fx/parallelization/utils.py b/optimum/fx/parallelization/utils.py index f129ffbd40..b7b1ccd41c 100644 --- a/optimum/fx/parallelization/utils.py +++ b/optimum/fx/parallelization/utils.py @@ -17,7 +17,6 @@ import hashlib import importlib import json -import operator import os import re import tempfile @@ -45,6 +44,14 @@ def ensure_divisibility(numerator: int, denominator: int) -> None: ) +def is_activation(node: Node) -> bool: + # only consider leaf Module activations + if node.op != "call_module": + return False + mod = node.graph.owning_module + return getattr(mod.get_submodule(node.target), "__module__", "").startswith("torch.nn.modules.activation") + + def is_linear(node: Node) -> bool: if node.op != "call_module": return False @@ -67,26 +74,6 @@ def is_shape_consumer(node: Node) -> bool: return False -def is_transpose(node: Node) -> bool: - if node.op == "call_method": - return node.target in {"transpose", "transpose_"} - elif node.op == "call_function": - return node.target is torch.transpose - return False - - -def is_permute(node: Node) -> bool: - if node.op == "call_method": - return node.target in {"permute"} - elif node.op == "call_function": - return node.target is torch.permute - return False - - -def is_getitem(node: Node) -> bool: - return node.op == "call_function" and node.target is operator.getitem - - def is_output(node: Node) -> bool: return node.op == "output"
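
To see how the new string-only entry point introduced in `api.py` is meant to be driven, here is a minimal usage sketch under `torchrun`. The `ParallelExecutionCtx(tp_group=..., current_device=...)` constructor arguments, the model id, and the input shape are assumptions for illustration; the patch only shows that the context carries the process groups and the current device, and that unused kwargs matching parallel-config fields (such as `enable_sequence_parallel`) override the `Config`.

```python
# Illustrative sketch only: ParallelExecutionCtx constructor arguments are assumed,
# the model id and shapes are placeholders.
import torch
import torch.distributed as dist

from optimum.fx.parallelization.api import parallelize_model
from optimum.fx.parallelization.core import ParallelExecutionCtx

# launched e.g. with `torchrun --nproc_per_node=2 example.py`
dist.init_process_group(backend="nccl")
device = torch.device("cuda", dist.get_rank() % torch.cuda.device_count())
torch.cuda.set_device(device)

# assumed: the context is built from the tensor-parallel group and the current device
ctx = ParallelExecutionCtx(tp_group=dist.group.WORLD, current_device=device)

# only a model id (or a local checkpoint directory) is accepted now; kwargs that match
# parallel-config fields, e.g. `enable_sequence_parallel`, override the Config
model = parallelize_model(
    "meta-llama/Llama-2-7b-hf",      # placeholder model id
    ctx,
    revision="main",
    skip_load_weights=False,
    enable_sequence_parallel=True,
)

input_ids = torch.randint(0, 32000, (1, 128), device=device)
logits = model(input_ids)            # first call triggers tracing and parallelization
```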
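
The new `decompose_and_functionalize` helper is used by the solver pass as `decompose_and_functionalize(graph_module)(*ctx.example_inputs)`. The sketch below exercises it on a toy module; feeding it a `symbolic_trace` graph instead of the dynamo-produced module used in the real pipeline is an assumption, and the toy module itself is illustrative.

```python
# Sketch under assumptions: a symbolically traced toy module stands in for the
# dynamo-traced graph module used in parallelize_backend.
import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace

from optimum.fx.parallelization.decomp import decompose_and_functionalize


class Toy(torch.nn.Module):
    def forward(self, x):
        return F.softmax(x, dim=-1) + 1


gm = symbolic_trace(Toy())
wrapper = decompose_and_functionalize(gm)        # default decomposition table and leaf functions
low_level_graph = wrapper(torch.randn(2, 8))     # run with concrete example inputs

# prints the core aten ops recorded in the functionalized low-level graph
print([n.target for n in low_level_graph.nodes if n.op == "call_function"])
```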
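
The backtracking described in `ParallelAxisSolverPass` can be illustrated without torch at all: each node exposes, for a given input axis, the output axes it can accept, and the solver assigns axes in topological order, undoing an assignment whenever a later node is left with no candidate. The toy below mirrors that control flow with sequence parallelism left out; the rules are simplified stand-ins for the handlers in `op_handlers.py`, not the actual ones.

```python
from typing import Callable, List, Optional

Axis = Optional[int]  # None means "not parallelized", 2 is the head dimension


def linear_rule(axis_in: Axis) -> List[Axis]:
    # simplified fallback rule for nn.Linear with sequence parallelism disabled:
    # an unparallelized input may come out head-parallel (ColumnLinear) or stay
    # unparallelized; a head-parallel input must be reduced back (RowLinear)
    if axis_in is None:
        return [2, None]
    if axis_in == 2:
        return [None]
    return []


def sdpa_rule(axis_in: Axis) -> List[Axis]:
    # attention only tolerates parallelism on the head dimension
    return [axis_in] if axis_in in (None, 2) else []


chain: List[Callable[[Axis], List[Axis]]] = [linear_rule, sdpa_rule, linear_rule]


def search(idx: int, axis_in: Axis, assignment: List[Axis]) -> bool:
    if idx == len(chain):
        return axis_in is None  # graph outputs must end up unparallelized
    for axis_out in chain[idx](axis_in):
        assignment.append(axis_out)
        if search(idx + 1, axis_out, assignment):
            return True
        assignment.pop()  # backtrack when a later node has no viable candidate
    return False


solution: List[Axis] = []
assert search(0, None, solution)
print(solution)  # [2, 2, None]: ColumnLinear -> head-parallel attention -> RowLinear
```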
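
Supporting an additional aten op only requires registering a propagation handler with `REGISTRY`. As a sketch, the handler below adds a rule for `aten.tril`, which this patch does not cover: the op keeps every dimension but mixes values within the last two, so a parallel axis survives only if it lies outside of them. The op choice and the rule are illustrative, not part of the diff.

```python
from typing import List

import torch

from optimum.fx.parallelization.op_registry import REGISTRY
from optimum.fx.parallelization.op_registry.op_handlers import OpParallelAxisPropagateHandler


@REGISTRY.register(torch.ops.aten.tril.default)
class TrilParallelAxisPropagateHandler(OpParallelAxisPropagateHandler):
    def propagate(self) -> List[int]:
        arg = self.node.all_input_nodes[0]
        axis = self.extract_axis(arg)
        if axis is None:
            return [None]  # an unparallelized input stays unparallelized
        ndim = arg.meta["val"].ndim
        if axis >= ndim - 2:
            return []  # tril mixes the last two dims, prune this branch
        return [axis]  # otherwise the parallel axis passes through unchanged
```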
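
Finally, the decision table that `ParallelLayerAnnotatePass` applies to `nn.Linear` nodes, mapping the solved axes of the input and output tensors to a `ColumnParallelLinear` or `RowParallelLinear` annotation, can be condensed into a small pure function. This is a paraphrase of the logic in the pass written out for readability; it is not an additional API.

```python
from typing import Dict, Optional


def annotate_linear(axis_before: Optional[int], axis_after: Optional[int],
                    sequence_parallel_enabled: bool) -> Dict:
    if axis_before is None:
        # unparallelized input -> ColumnLinear; gather the output unless the
        # solver wants it to stay head-parallel
        return {"axis": "column", "gather_output": axis_after is None}
    if axis_before == 1:
        # sequence-parallel input is only legal when sequence parallelism is enabled
        assert sequence_parallel_enabled
        return {"axis": "column", "sequence_parallel": True,
                "gather_output": axis_after is None}
    if axis_before == 2:
        # head-parallel input -> RowLinear, optionally scattering back to the
        # sequence dimension when sequence parallelism is enabled
        assert axis_after is None or (axis_after == 1 and sequence_parallel_enabled)
        return {"axis": "row", "input_is_parallel": True,
                "sequence_parallel": axis_after == 1}
    raise ValueError(f"unexpected parallel axis {axis_before}")


print(annotate_linear(None, 2, False))   # ColumnLinear keeping a sharded output
print(annotate_linear(2, None, False))   # RowLinear reducing across ranks at the end
```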