From e2328837df571befb631fc0167c259537616c359 Mon Sep 17 00:00:00 2001 From: ANDREW Date: Tue, 22 Oct 2024 19:59:42 -0400 Subject: [PATCH 1/2] [models]: added all but conv2d predictors --- .gitattributes | 8 + .../distributed/_tools/a100_models/bmm.joblib | 3 + .../distributed/_tools/a100_models/mm.joblib | 3 + .../_tools/a100_models/sdpa.joblib | 3 + .../_tools/a100_models/sdpa_backward.joblib | 3 + .../distributed/_tools/h100_models/bmm.joblib | 3 + .../distributed/_tools/h100_models/mm.joblib | 3 + .../_tools/h100_models/sdpa.joblib | 3 + .../_tools/h100_models/sdpa_backward.joblib | 3 + torch/distributed/_tools/runtime_estimator.py | 460 +++++++++++++++--- 10 files changed, 437 insertions(+), 55 deletions(-) create mode 100644 torch/distributed/_tools/a100_models/bmm.joblib create mode 100644 torch/distributed/_tools/a100_models/mm.joblib create mode 100644 torch/distributed/_tools/a100_models/sdpa.joblib create mode 100644 torch/distributed/_tools/a100_models/sdpa_backward.joblib create mode 100644 torch/distributed/_tools/h100_models/bmm.joblib create mode 100644 torch/distributed/_tools/h100_models/mm.joblib create mode 100644 torch/distributed/_tools/h100_models/sdpa.joblib create mode 100644 torch/distributed/_tools/h100_models/sdpa_backward.joblib diff --git a/.gitattributes b/.gitattributes index e904301752950..3a9521336f1e2 100644 --- a/.gitattributes +++ b/.gitattributes @@ -5,3 +5,11 @@ .github/scripts/gql_mocks.json linguist-generated=true third_party/LICENSES_BUNDLED.txt linguist-generated=true tools/build/bazel/requirements.txt linguist-generated=true +torch/distributed/_tools/a100_models/bmm.joblib filter=lfs diff=lfs merge=lfs -text +torch/distributed/_tools/a100_models/mm.joblib filter=lfs diff=lfs merge=lfs -text +torch/distributed/_tools/a100_models/sdpa.joblib filter=lfs diff=lfs merge=lfs -text +torch/distributed/_tools/a100_models/sdpa_backward.joblib filter=lfs diff=lfs merge=lfs -text +torch/distributed/_tools/h100_models/sdpa.joblib filter=lfs diff=lfs merge=lfs -text +torch/distributed/_tools/h100_models/sdpa_backward.joblib filter=lfs diff=lfs merge=lfs -text +torch/distributed/_tools/h100_models/bmm.joblib filter=lfs diff=lfs merge=lfs -text +torch/distributed/_tools/h100_models/mm.joblib filter=lfs diff=lfs merge=lfs -text diff --git a/torch/distributed/_tools/a100_models/bmm.joblib b/torch/distributed/_tools/a100_models/bmm.joblib new file mode 100644 index 0000000000000..4c0d2073d622c --- /dev/null +++ b/torch/distributed/_tools/a100_models/bmm.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fe40334437e5ad1a4cd2d496c6eb157481b114418408d738e6a7b9a4c99ab530 +size 78938299 diff --git a/torch/distributed/_tools/a100_models/mm.joblib b/torch/distributed/_tools/a100_models/mm.joblib new file mode 100644 index 0000000000000..b0acfd47b1b04 --- /dev/null +++ b/torch/distributed/_tools/a100_models/mm.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:00c789a3e8a99dd76c59619dc245b367fecf5c59250c6aa919c8b12cde413338 +size 80274914 diff --git a/torch/distributed/_tools/a100_models/sdpa.joblib b/torch/distributed/_tools/a100_models/sdpa.joblib new file mode 100644 index 0000000000000..8bfc8fd5c9069 --- /dev/null +++ b/torch/distributed/_tools/a100_models/sdpa.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c58e2fb8637f68562ba0ad16e61fd9dd499b2ab56617b072c14957c86333e0f +size 146079412 diff --git a/torch/distributed/_tools/a100_models/sdpa_backward.joblib 
b/torch/distributed/_tools/a100_models/sdpa_backward.joblib new file mode 100644 index 0000000000000..5814fda0fa6ae --- /dev/null +++ b/torch/distributed/_tools/a100_models/sdpa_backward.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:35a6edecc6eb2297a34ee1661f7c796c3f6e68a02cd91cccdfb30c3e6ad6e26c +size 165674354 diff --git a/torch/distributed/_tools/h100_models/bmm.joblib b/torch/distributed/_tools/h100_models/bmm.joblib new file mode 100644 index 0000000000000..81d064a32d971 --- /dev/null +++ b/torch/distributed/_tools/h100_models/bmm.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:516ed420b674477cda3b3821b30dbe64c3251199712e418864a0307c8ef6d10f +size 83504981 diff --git a/torch/distributed/_tools/h100_models/mm.joblib b/torch/distributed/_tools/h100_models/mm.joblib new file mode 100644 index 0000000000000..ec87f7de15bf6 --- /dev/null +++ b/torch/distributed/_tools/h100_models/mm.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:130a7338e1e88fab9465f66ba4693e24290061908108cd37e8be1dd7faba7339 +size 97798975 diff --git a/torch/distributed/_tools/h100_models/sdpa.joblib b/torch/distributed/_tools/h100_models/sdpa.joblib new file mode 100644 index 0000000000000..9e0be5b6f9277 --- /dev/null +++ b/torch/distributed/_tools/h100_models/sdpa.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f9cc36c6e6734a28860fcb6ec18ed2901bef9740d9cfd5385df8eea80d88b53b +size 175298513 diff --git a/torch/distributed/_tools/h100_models/sdpa_backward.joblib b/torch/distributed/_tools/h100_models/sdpa_backward.joblib new file mode 100644 index 0000000000000..4ac20630f1d98 --- /dev/null +++ b/torch/distributed/_tools/h100_models/sdpa_backward.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0d60cce92bdc450be21d3fe3980dd93012d218754ce0b505784b45b0a2837558 +size 185574443 diff --git a/torch/distributed/_tools/runtime_estimator.py b/torch/distributed/_tools/runtime_estimator.py index 87f4d3f36b60e..fa3b6b3f738d2 100644 --- a/torch/distributed/_tools/runtime_estimator.py +++ b/torch/distributed/_tools/runtime_estimator.py @@ -1,6 +1,10 @@ # Owner(s): ["module: unknown"] import math import os +import joblib +import subprocess +import numpy as np +import time from collections import defaultdict from typing import Any, Callable, Dict, List, Set, Tuple from typing_extensions import Self @@ -73,6 +77,13 @@ _IGNORE_OPS = _VIEW_OPS | _CREATE_OPS +# Similar to `flop_registry`, stores the functions that make predictions +_LEARNED_OPS: Dict[Any, Any] = {} + +# Caches the learned models that predict ops' runtimes. 
+_LEARNED_OPS_PREDICTORS: Dict[str, Any] = {} + + __all__ = ["RuntimeEstimator"] @@ -144,6 +155,22 @@ def __init__(self) -> None: self.mod_bw_post_order: List[str] = [] self.total_runtime: float = 0.0 + self.gpu_type = self.get_device_type() + + def get_device_type(self) -> int: + try: + result = subprocess.check_output(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader']) + gpu_name = result.decode('utf-8').strip() + + if "A100" in gpu_name: + return "a100" + elif "H100" in gpu_name: + return "h100" + else: + raise ValueError("GPU type not supported") + except subprocess.CalledProcessError as e: + raise ValueError("Error retrieving GPU name") + # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_subclasses/fake_tensor.py#L1969 # noqa: PGH004,B950 # NB: returns fake tensors @classmethod @@ -275,7 +302,7 @@ def _benchmark_estimate(cls, func, args, kwargs) -> Tuple[Any, float]: # type: # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_inductor/scheduler.py#L589 # noqa: PGH004,B950 @classmethod - def _roofline_estimate(cls, func, args, kwargs) -> Tuple[Any, float]: # type: ignore[no-untyped-def] + def _get_transfer_time(cls, flat_args_kwargs, flat_outs) -> float: # type: ignore[no-untyped-def] """ Estimates the runtime of a function using a roofline cost model. @@ -309,62 +336,72 @@ def get_num_bytes(t: torch.Tensor) -> int: ) return mem_consumed - def get_compute_time(func_packet, args, kwargs, out, out_dtypes) -> float: # type: ignore[no-untyped-def] - """ - Estimates the compute time of an aten operator. + gpu_memory_bandwidth = get_gpu_dram_gbps() + read_bytes = sum( + get_num_bytes(t) + for t in flat_args_kwargs + if isinstance(t, torch.Tensor) + ) + write_bytes = sum( + get_num_bytes(t) for t in flat_outs if isinstance(t, torch.Tensor) + ) + counted_bytes = read_bytes + write_bytes + # The GPU memory bandwidth is in GB/s so the transfer time is in nanoseconds + transfer_time = counted_bytes / gpu_memory_bandwidth + return transfer_time - Args: - func_packet: The operator overload packet. - args: The arguments to the operator. - kwargs: The keyword arguments to the operator. - out: The output of the operator. - out_dtypes: The output data types. + @classmethod + def _get_compute_time(cls, func_packet, args, kwargs, out, out_dtypes) -> float: # type: ignore[no-untyped-def] + """ + Estimates the compute time of an aten operator. - Returns: - float: The estimated compute time in nanoseconds. - """ - if func_packet in flop_registry: - assert ( - len(out_dtypes) == 1 - ), f"Only support single out dtype got {out_dtypes} for {func_packet}" - dtype = out_dtypes.pop() - # This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s - peak_gpu_flops = get_device_tflops(dtype) * 1e15 - # We can expect to achieve 75% of theoretical peak flops - factor = 0.75 - peak_empirical_flops = factor * peak_gpu_flops - flop_count_func = flop_registry[func_packet] - # We divide by a factor of 2 to get the MACs (multiply and accumulate) - flop_count = flop_count_func(*args, **kwargs, out_val=out) / 2 - # We multiply by 1e9 to get the time in nano seconds - compute_time = (flop_count / peak_empirical_flops) * 1e9 - return compute_time - return 0.0 - - def get_transfer_time(flat_args_kwargs, flat_outs) -> float: # type: ignore[no-untyped-def] - """ - Estimates the memory transfer time of input and output tensors. + Args: + func_packet: The operator overload packet. 
+ args: The arguments to the operator. + kwargs: The keyword arguments to the operator. + out: The output of the operator. + out_dtypes: The output data types. - Args: - flat_args_kwargs (List[torch.Tensor]): The flat list of arguments and keyword arguments. - flat_outs (List[torch.Tensor]): The flat list of outputs. + Returns: + float: The estimated compute time in nanoseconds. + """ + if func_packet in flop_registry: + assert ( + len(out_dtypes) == 1 + ), f"Only support single out dtype got {out_dtypes} for {func_packet}" + dtype = out_dtypes.pop() + # This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s + peak_gpu_flops = get_device_tflops(dtype) * 1e15 + # We can expect to achieve 75% of theoretical peak flops + factor = 0.75 + peak_empirical_flops = factor * peak_gpu_flops + flop_count_func = flop_registry[func_packet] + # We divide by a factor of 2 to get the MACs (multiply and accumulate) + flop_count = flop_count_func(*args, **kwargs, out_val=out) / 2 + # We multiply by 1e9 to get the time in nano seconds + compute_time = (flop_count / peak_empirical_flops) * 1e9 + return compute_time + return 0.0 - Returns: - float: The estimated memory transfer time in nanoseconds. - """ - gpu_memory_bandwidth = get_gpu_dram_gbps() - read_bytes = sum( - get_num_bytes(t) - for t in flat_args_kwargs - if isinstance(t, torch.Tensor) - ) - write_bytes = sum( - get_num_bytes(t) for t in flat_outs if isinstance(t, torch.Tensor) - ) - counted_bytes = read_bytes + write_bytes - # The GPU memory bandwidth is in GB/s so the transfer time is in nanoseconds - transfer_time = counted_bytes / gpu_memory_bandwidth - return transfer_time + # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_inductor/scheduler.py#L589 # noqa: PGH004,B950 + @classmethod + def _roofline_estimate(cls, func, args, kwargs) -> Tuple[Any, float]: # type: ignore[no-untyped-def] + """ + Estimates the runtime of a function using a roofline cost model. + + Args: + func: The function to estimate. + args: The arguments to pass to the function. + kwargs: The keyword arguments to pass to the function. + out: The output of the function. + + Returns: + Tuple[Any, float]: A tuple containing the result of the function and + the mean operation time in milliseconds. + """ + assert ( + torch.cuda.is_available() + ), "Roofline estimation needs to access CUDA capabilities to make estimations" # Roofline Cost Model Explanation @@ -397,7 +434,7 @@ def get_transfer_time(flat_args_kwargs, flat_outs) -> float: # type: ignore[no- if func_packet not in _IGNORE_OPS: flat_args_kwargs, args_spec = pytree.tree_flatten((args, kwargs)) flat_outs, out_spec = pytree.tree_flatten(out) - transfer_time = get_transfer_time(flat_args_kwargs, flat_outs) + transfer_time = cls._get_transfer_time(flat_args_kwargs, flat_outs) out_dtypes = { t.dtype @@ -408,7 +445,320 @@ def get_transfer_time(flat_args_kwargs, flat_outs) -> float: # type: ignore[no- args, kwargs = pytree.tree_unflatten(flat_args_kwargs, args_spec) out = pytree.tree_unflatten(flat_outs, out_spec) - compute_time = get_compute_time(func_packet, args, kwargs, out, out_dtypes) + compute_time = cls._get_compute_time(func_packet, args, kwargs, out, out_dtypes) + # We get the estimated time as the max of the transfer time and + # compute time. 
We divide by 1e6 to get the time in ms + op_time = max(transfer_time, compute_time) / 1e6 + + return (out, op_time) + + @classmethod + def _learned_estimate_predictor(cls, func_packet, args, kwargs, out, out_dtypes) -> float: # type: ignore[no-untyped-def] + """ + TODO: + 1) the order of the features + 2) where the models are stored + + + Estimates the compute time of an aten operator. + + Args: + func_packet: The operator overload packet. + args: The arguments to the operator. + kwargs: The keyword arguments to the operator. + out: The output of the operator. + out_dtypes: The output data types. + + Returns: + float: The estimated compute time in nanoseconds. + + + # TODO: comments. + Note: for the prediction functions, we mimic the arguments for mm_flop. + """ + def get_learned_model(op: str) -> Any: + if op not in _LEARNED_OPS_PREDICTORS: + base_dir = os.path.join(os.getcwd()) + path = os.path.join(base_dir, f"{cls.gpu_type}_models", f"{op}.joblib") + + _LEARNED_OPS_PREDICTORS[op] = joblib.load(path) + return _LEARNED_OPS_PREDICTORS[op] + + + from functools import wraps + from torch.utils._pytree import tree_map + + def get_shape(i): + if isinstance(i, torch.Tensor): + return i.shape + return i + + def shape_wrapper(f): + """ + Similar to flop_counter.shape_wrapper(), but also takes takes gflops. + """ + @wraps(f) + def nf(dtype, gflops, *args, out_val=None, **kwargs): + args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out_val)) + return f(dtype, gflops, *args, out_shape=out_shape, **kwargs) + return nf + + def register_timing_formula(targets, get_raw=False): + """ + Similar to flop_counter.register_flop_formula(). + """ + def register_fun(flop_formula): + if not get_raw: + flop_formula = shape_wrapper(flop_formula) + + def register(target): + if not isinstance(target, torch._ops.OpOverloadPacket): + raise ValueError( + f"register_flop_formula(targets): expected each target to be " + f"OpOverloadPacket (i.e. torch.ops.mylib.foo), got " + f"{target} which is of type {type(target)}") + if target in flop_registry: + raise RuntimeError(f"duplicate registrations for {target}") + flop_registry[target] = flop_formula + + # To handle allowing multiple aten_ops at once + torch.utils._pytree.tree_map_(register, targets) + + return flop_formula + + return register_fun + + + def convert_dtype(dtype) -> list[int]: + """ + To use dtype in a learned model, we convert them to one-hot encodings. + + Learned model supports the dtypes: + - torch.float16 + - torch.float32 + - torch.bfloat16 + """ + dtypes = [torch.float16, torch.float32, torch.bfloat16] + return [1 if dtype == d else 0 for d in dtypes] + + @register_timing_formula([aten.mm, aten.addmm]) + def mm_time(dtype, gflops, a_shape, b_shape, *args, out_shape=None, **kwargs) -> float: + model = get_learned_model("mm") + + m, n = a_shape + n2, p = b_shape + assert n == n2 + + features = np.array([[m, n, p, gflops] + convert_dtype(dtype)]) + return model.predict(features) + + @register_timing_formula([aten.bmm, aten.baddmm]) + def bmm_time(dtype, gflops, a_shape, b_shape, out_shape=None, **kwargs) -> float: + model = get_learned_model("bmm") + + b, m, n = a_shape + b2, n2, p = b_shape + assert b == b2 and n == n2 + + features = np.array([[b, m, n, p, gflops] + convert_dtype(dtype)]) + return model.predict(features) + + def is_causal_sdpa(args: tuple) -> bool: + """ + TODO: the way that flop_counter implements sdpa args/kwargs—namely, `is_causal`—should be updated. + This is a heuristic hackaround. 
+ """ + if len(args) >= 2: + if args[0] is not None: + return True + elif len(args) > 2: + if isinstance(args[-1], bool): + return args[-1] + return False + + @register_timing_formula(aten._scaled_dot_product_cudnn_attention) + def sdpa_cudnn_time(dtype, gflops, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: + model = get_learned_model("sdpa") + + b, h, s_q, d_qk = query_shape + _b2, _h2, s_kv, _d2 = key_shape + _b3, _h3, _s3, d_v = value_shape + assert b == _b2 == _b3 and h == _h2 == _h3 and d_qk == _d2 and s_kv == _s3 and d_qk == _d2 + + backends_ohe = [1, 0, 0] + is_causal_ohe = [0, 1] if is_causal_sdpa(args) else [1, 0] + features = np.array([[b, h, s_q, s_kv, d_qk, d_v, gflops] + convert_dtype(dtype) + backends_ohe + is_causal_ohe]) + return model.predict(features) + + @register_timing_formula(aten._scaled_dot_product_efficient_attention) + def sdpa_efficient_time(dtype, gflops, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: + model = get_learned_model("sdpa") + + b, h, s_q, d_qk = query_shape + _b2, _h2, s_kv, _d2 = key_shape + _b3, _h3, _s3, d_v = value_shape + assert b == _b2 == _b3 and h == _h2 == _h3 and d_qk == _d2 and s_kv == _s3 and d_qk == _d2 + + backends_ohe = [0, 1, 0] + is_causal_ohe = [0, 1] if is_causal_sdpa(args) else [1, 0] + features = np.array([[b, h, s_q, s_kv, d_qk, d_v, gflops] + convert_dtype(dtype) + backends_ohe + is_causal_ohe]) + return model.predict(features) + + @register_timing_formula(aten._scaled_dot_product_flash_attention) + def sdpa_flash_time(dtype, gflops, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: + model = get_learned_model("sdpa") + + b, h, s_q, d_qk = query_shape + _b2, _h2, s_kv, _d2 = key_shape + _b3, _h3, _s3, d_v = value_shape + assert b == _b2 == _b3 and h == _h2 == _h3 and d_qk == _d2 and s_kv == _s3 and d_qk == _d2 + + backends_ohe = [0, 0, 1] + is_causal_ohe = [0, 1] if is_causal_sdpa(args) else [1, 0] + features = np.array([[b, h, s_q, s_kv, d_qk, d_v, gflops] + convert_dtype(dtype) + backends_ohe + is_causal_ohe]) + return model.predict(features) + + @register_timing_formula(aten._scaled_dot_product_cudnn_attention_backward) + def sdpa_backward_cudnn_time(dtype, gflops, grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: + model = get_learned_model("sdpa_backward") + + b, h, s_q, d_qk = query_shape + _b2, _h2, s_kv, _d2 = key_shape + _b3, _h3, _s3, d_v = value_shape + _b4, _h4, _s4, _d4 = grad_out_shape + assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_qk == _d2 + assert d_v == _d4 and s_kv == _s3 and s_q == _s4 + + backends_ohe = [1, 0, 0] + is_causal_ohe = [0, 1] if is_causal_sdpa(args) else [1, 0] + features = np.array([[b, h, s_q, s_kv, d_qk, d_v, gflops] + convert_dtype(dtype) + backends_ohe + is_causal_ohe]) + return model.predict(features) + + @register_timing_formula(aten._scaled_dot_product_efficient_attention_backward) + def sdpa_backward_efficient_time(dtype, gflops, grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: + model = get_learned_model("sdpa_backward") + + b, h, s_q, d_qk = query_shape + _b2, _h2, s_kv, _d2 = key_shape + _b3, _h3, _s3, d_v = value_shape + _b4, _h4, _s4, _d4 = grad_out_shape + assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_qk == _d2 + assert d_v == _d4 and s_kv == _s3 and s_q == _s4 + + backends_ohe = [0, 1, 0] + is_causal_ohe = [0, 1] if is_causal_sdpa(args) else [1, 0] + 
features = np.array([[b, h, s_q, s_kv, d_qk, d_v, gflops] + convert_dtype(dtype) + backends_ohe + is_causal_ohe]) + return model.predict(features) + + @register_timing_formula(aten._scaled_dot_product_flash_attention_backward) + def sdpa_backward_flash_time(dtype, gflops, grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: + model = get_learned_model("sdpa_backward") + + b, h, s_q, d_qk = query_shape + _b2, _h2, s_kv, _d2 = key_shape + _b3, _h3, _s3, d_v = value_shape + _b4, _h4, _s4, _d4 = grad_out_shape + assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_qk == _d2 + assert d_v == _d4 and s_kv == _s3 and s_q == _s4 + + backends_ohe = [0, 0, 1] + is_causal_ohe = [0, 1] if is_causal_sdpa(args) else [1, 0] + features = np.array([[b, h, s_q, s_kv, d_qk, d_v, gflops] + convert_dtype(dtype) + backends_ohe + is_causal_ohe]) + return model.predict(features) + + # @register_timing_formula([aten.convolution, aten._convolution]) + # def conv_time(dtype, gflops, x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> float: + # """ + # TODO: need to add support for higher dims. + # """ + # model = get_learned_model("conv") + + # # batch_size = x_shape[0] + # # conv_shape = (x_shape if transposed else out_shape)[2:] + # # c_out, c_in, *filter_size = w_shape + + # # features = np.array([[b, m, k, n, gflops]]) + # return model.predict(features) + + # @register_timing_formula(aten.convolution_backward) + # def conv_backward_time( + # dtype, + # gflops, + # grad_out_shape, + # x_shape, + # w_shape, + # _bias, + # _stride, + # _padding, + # _dilation, + # transposed, + # _output_padding, + # _groups, + # output_mask, + # out_shape) -> float: + # """ + # TODO: need to add support for higher dims. + # """ + # model = get_learned_model("conv_backward") + + # # features = np.array([[b, m, k, n, gflops]]) + # return model.predict(features) + + if func_packet in _LEARNED_OPS: + assert ( + len(out_dtypes) == 1 + ), f"Only support single out dtype got {out_dtypes} for {func_packet}" + dtype = out_dtypes.pop() + + flop_count_func = flop_registry[func_packet] + gflops = flop_count_func(*args, **kwargs, out_val=out) / 1e9 + + predictor_func = _LEARNED_OPS[func_packet] + # Returns compute time in ms, so multiply by 1e6 to get nanoseconds + compute_time = predictor_func(dtype, gflops, *args, **kwargs, out_val=out) + compute_time *= 1e6 + else: + compute_time = cls._get_compute_time(func_packet, args, kwargs, out, out_dtypes) + return 0.0 + + @classmethod + def _learned_estimate(cls, func, args, kwargs) -> Tuple[Any, float]: # type: ignore[no-untyped-def] + """ + Estimates the runtime of a function using a learned estimator. + + Args: + func: The function to estimate. + args: The arguments to pass to the function. + kwargs: The keyword arguments to pass to the function. + res: The result of the function. + + Returns: + Tuple[Any, float]: A tuple containing the result of the function and + the mean operation time in milliseconds. 
+ """ + assert ( + torch.cuda.is_available() + ), "Learned estimator needs to access CUDA capabilities to make estimations" + + kwargs = kwargs if kwargs else {} + out = func(*args, **kwargs) + op_time = 0.0 + func_packet = func._overloadpacket + if func_packet not in _IGNORE_OPS: + flat_args_kwargs, args_spec = pytree.tree_flatten((args, kwargs)) + flat_outs, out_spec = pytree.tree_flatten(out) + transfer_time = cls._get_transfer_time(flat_args_kwargs, flat_outs) + + out_dtypes = { + t.dtype + for t in flat_outs + if isinstance(t, torch.Tensor) and t.dtype in cls._float_types + } + + args, kwargs = pytree.tree_unflatten(flat_args_kwargs, args_spec) + out = pytree.tree_unflatten(flat_outs, out_spec) + + compute_time = cls._learned_estimate_predictor(func_packet, args, kwargs, out, out_dtypes) # We get the estimated time as the max of the transfer time and # compute time. We divide by 1e6 to get the time in ms op_time = max(transfer_time, compute_time) / 1e6 From e889a5e7407beaae5d93c4b52a98887e1dd7b143 Mon Sep 17 00:00:00 2001 From: Andrew Zhao Date: Sat, 26 Oct 2024 12:18:15 -0400 Subject: [PATCH 2/2] [models]: runtime_estimator works --- .../_tools/test_runtime_estimator.py | 20 +- torch/distributed/_tools/runtime_estimator.py | 529 ++++++++++-------- 2 files changed, 302 insertions(+), 247 deletions(-) diff --git a/test/distributed/_tools/test_runtime_estimator.py b/test/distributed/_tools/test_runtime_estimator.py index 400903f17673f..3c41bcbce61d6 100644 --- a/test/distributed/_tools/test_runtime_estimator.py +++ b/test/distributed/_tools/test_runtime_estimator.py @@ -133,6 +133,7 @@ def _init_model_and_args( def test_transformer_runtime( self, ): + print("Transformer Test") """Runs a basic GPT-2 model""" vocab_size = 8192 bsz, seq_len = 8, 1024 @@ -155,14 +156,20 @@ def test_transformer_runtime( roofline_estimate = self._runtime_estimate( "operator-level-cost-model", self._train_step, fake_args ) + learned_estimate = self._runtime_estimate( + "operator-level-learned-model", self._train_step, fake_args + ) benchmark_accuracy = actual_runtime / benchmark_estimate roofline_accuracy = actual_runtime / roofline_estimate + learned_accuracy = actual_runtime / learned_estimate print( f"Actual: {actual_runtime} Benchmark Estimate: {benchmark_estimate} Accuracy: {benchmark_accuracy}" - f"\n Actual: {actual_runtime} Roofline Estimatee: {roofline_estimate} Accuracy: {roofline_accuracy}" + f"\nActual: {actual_runtime} Roofline Estimate: {roofline_estimate} Accuracy: {roofline_accuracy}" + f"\nActual: {actual_runtime} Learned Estimate: {learned_estimate} Accuracy: {learned_accuracy}" ) self.assertAlmostEqual(benchmark_accuracy, 1.0, delta=0.2) self.assertAlmostEqual(roofline_accuracy, 1.0, delta=0.3) + self.assertAlmostEqual(learned_accuracy, 1.0, delta=0.3) @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653") @unittest.skipIf(not TEST_CUDA, "CUDA not available") @@ -170,6 +177,7 @@ def test_conv_model_runtime( self, ): """Runs a simple CNN model""" + print("CNN Test") num_classes = 100 bsz, img_sz = 256, 128 model_args = ConvArgs(img_sz, num_classes) @@ -183,14 +191,20 @@ def test_conv_model_runtime( roofline_estimate = self._runtime_estimate( "operator-level-cost-model", self._train_step, fake_args ) + learned_estimate = self._runtime_estimate( + "operator-level-learned-model", self._train_step, fake_args + ) benchmark_accuracy = actual_runtime / benchmark_estimate roofline_accuracy = actual_runtime / roofline_estimate + learned_accuracy = actual_runtime / learned_estimate 
print( - f"Actual: {actual_runtime} Benchmark Estimate: {benchmark_estimate} Accuracy: {benchmark_accuracy}\n" - f"Actual: {actual_runtime} Roofline Estimatee: {roofline_estimate} Accuracy: {roofline_accuracy}" + f"Actual: {actual_runtime} Benchmark Estimate: {benchmark_estimate} Accuracy: {benchmark_accuracy}" + f"\nActual: {actual_runtime} Roofline Estimate: {roofline_estimate} Accuracy: {roofline_accuracy}" + f"\nActual: {actual_runtime} Learned Estimate: {learned_estimate} Accuracy: {learned_accuracy}" ) self.assertAlmostEqual(benchmark_accuracy, 1.0, delta=0.2) self.assertAlmostEqual(roofline_accuracy, 1.0, delta=0.4) + self.assertAlmostEqual(learned_accuracy, 1.0, delta=0.4) if __name__ == "__main__": diff --git a/torch/distributed/_tools/runtime_estimator.py b/torch/distributed/_tools/runtime_estimator.py index fa3b6b3f738d2..3bd434864ba74 100644 --- a/torch/distributed/_tools/runtime_estimator.py +++ b/torch/distributed/_tools/runtime_estimator.py @@ -4,6 +4,7 @@ import joblib import subprocess import numpy as np +import pandas as pd import time from collections import defaultdict from typing import Any, Callable, Dict, List, Set, Tuple from typing_extensions import Self @@ -87,6 +88,262 @@ __all__ = ["RuntimeEstimator"] +def get_learned_model(op: str, gpu_type: str) -> Any: + if op not in _LEARNED_OPS_PREDICTORS: + base_dir = os.path.dirname(os.path.realpath(__file__)) + path = os.path.join(base_dir, f"{gpu_type}_models", f"{op}.joblib") + + _LEARNED_OPS_PREDICTORS[op] = joblib.load(path) + return _LEARNED_OPS_PREDICTORS[op] + +from functools import wraps +from torch.utils._pytree import tree_map + +def get_shape(i): + if isinstance(i, torch.Tensor): + return i.shape + return i + +def shape_wrapper(f): + """ + Similar to flop_counter.shape_wrapper(), but also takes gpu_type, dtype, and gflops. + """ + @wraps(f) + def nf(gpu_type, dtype, gflops, *args, out_val=None, **kwargs): + args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out_val)) + return f(gpu_type, dtype, gflops, *args, out_shape=out_shape, **kwargs) + return nf + +def register_timing_formula(targets, get_raw=False): + """ + Similar to flop_counter.register_flop_formula(). + """ + def register_fun(time_formula): + if not get_raw: + time_formula = shape_wrapper(time_formula) + + def register(target): + if not isinstance(target, torch._ops.OpOverloadPacket): + raise ValueError( + f"register_timing_formula(targets): expected each target to be " + f"OpOverloadPacket (i.e. torch.ops.mylib.foo), got " + f"{target} which is of type {type(target)}") + if target in _LEARNED_OPS: + raise RuntimeError(f"duplicate registrations for {target}") + _LEARNED_OPS[target] = time_formula + + # To handle allowing multiple aten_ops at once + torch.utils._pytree.tree_map_(register, targets) + + return time_formula + + return register_fun + +def convert_dtype(dtype) -> Dict[str, int]: + """ + Convert dtype to a one-hot encoding, returned as a dict.
+ + Learned model supports the dtypes: + - torch.float16 + - torch.float32 + - torch.bfloat16 + """ + dtypes = [torch.float16, torch.float32, torch.bfloat16] + dtype_one_hot = [1 if dtype == d else 0 for d in dtypes] + dtype_names = ["dtype_16", "dtype_32", "dtype_b16"] + return dict(zip(dtype_names, dtype_one_hot)) + +@register_timing_formula(aten.mm) +def mm_time(gpu_type, dtype, gflops, a_shape, b_shape, *args, out_shape=None, **kwargs) -> float: + model = get_learned_model("mm", gpu_type) + + m, n = a_shape + n2, p = b_shape + assert n == n2 + + dtypes = convert_dtype(dtype) + features = { + "n": n, + "m": m, + "p": p, + "gflops": gflops, + "dtype_16": dtypes["dtype_16"], + "dtype_32": dtypes["dtype_32"], + "dtype_b16": dtypes["dtype_b16"], + } + features_df = pd.DataFrame([features]) + return float(model.predict(features_df)[0]) + +@register_timing_formula(aten.addmm) +def addmm_time(gpu_type, dtype, gflops, self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> float: + return mm_time(gpu_type, dtype, gflops, a_shape, b_shape) + +@register_timing_formula(aten.bmm) +def bmm_time(gpu_type, dtype, gflops, a_shape, b_shape, out_shape=None, **kwargs) -> float: + model = get_learned_model("bmm", gpu_type) + + b, m, n = a_shape + b2, n2, p = b_shape + assert b == b2 and n == n2 + + dtypes = convert_dtype(dtype) + features = { + "b": b, + "n": n, + "m": m, + "p": p, + "gflops": gflops, + "dtype_16": dtypes["dtype_16"], + "dtype_32": dtypes["dtype_32"], + "dtype_b16": dtypes["dtype_b16"], + } + features_df = pd.DataFrame([features]) + return float(model.predict(features_df)[0]) + +@register_timing_formula(aten.baddbmm) +def baddbmm_time(gpu_type, dtype, gflops, self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> float: + return bmm_time(gpu_type, dtype, gflops, a_shape, b_shape) + +def is_causal_sdpa(args: tuple) -> bool: + """ + TODO: the way that flop_counter implements sdpa args/kwargs—namely, `is_causal`—should be updated. + This is a heuristic hackaround. 
+ """ + if len(args) >= 2 and args[0] is not None: + return True + if len(args) > 2 and isinstance(args[-1], bool): + return args[-1] + return False + +def build_sdpa_features(b, h, s_q, s_kv, d_qk, d_v, gflops, dtype, backend, is_causal: bool) -> pd.DataFrame: + if backend == "cudnn": + backends_ohe = [1, 0, 0] + elif backend == "efficient": + backends_ohe = [0, 1, 0] + elif backend == "flash": + backends_ohe = [0, 0, 1] + else: + raise ValueError(f"Unknown backend: {backend}") + + dtypes = convert_dtype(dtype) + is_causal_ohe = [0, 1] if is_causal else [1, 0] + + features = { + "b": b, + "h": h, + "s_q": s_q, + "s_kv": s_kv, + "d_qk": d_qk, + "d_v": d_v, + "gflops": gflops, + "dtype_16": dtypes["dtype_16"], + "dtype_32": dtypes["dtype_32"], + "dtype_b16": dtypes["dtype_b16"], + "backend_cudnn": backends_ohe[0], + "backend_efficient": backends_ohe[1], + "backend_flash": backends_ohe[2], + "is_causal_0": is_causal_ohe[0], + "is_causal_1": is_causal_ohe[1] + } + return pd.DataFrame([features]) + + +def check_sdpa_shapes(query_shape, key_shape, value_shape): + b, h, s_q, d_qk = query_shape + _b2, _h2, s_kv, _d2 = key_shape + _b3, _h3, _s3, d_v = value_shape + assert b == _b2 == _b3 and h == _h2 == _h3 and d_qk == _d2 and s_kv == _s3 and d_qk == _d2 + return b, h, s_q, s_kv, d_qk, d_v + +def check_sdpa_shapes_backward(query_shape, key_shape, value_shape, grad_out_shape): + b, h, s_q, d_qk = query_shape + _b2, _h2, s_kv, _d2 = key_shape + _b3, _h3, _s3, d_v = value_shape + _b4, _h4, _s4, _d4 = grad_out_shape + assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_qk == _d2 and d_v == _d4 and s_kv == _s3 and s_q == _s4 + return b, h, s_q, s_kv, d_qk, d_v + +@register_timing_formula(aten._scaled_dot_product_cudnn_attention) +def sdpa_cudnn_time(gpu_type, dtype, gflops, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: + model = get_learned_model("sdpa", gpu_type) + b, h, s_q, s_kv, d_qk, d_v = check_sdpa_shapes(query_shape, key_shape, value_shape) + features = build_sdpa_features(b, h, s_q, s_kv, d_qk, d_v, gflops, dtype, "cudnn", is_causal=is_causal_sdpa(args)) + return float(model.predict(features)[0]) + +@register_timing_formula(aten._scaled_dot_product_efficient_attention) +def sdpa_efficient_time(gpu_type, dtype, gflops, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: + model = get_learned_model("sdpa", gpu_type) + b, h, s_q, s_kv, d_qk, d_v = check_sdpa_shapes(query_shape, key_shape, value_shape) + features = build_sdpa_features(b, h, s_q, s_kv, d_qk, d_v, gflops, dtype, "efficient", is_causal=is_causal_sdpa(args)) + return float(model.predict(features)[0]) + +@register_timing_formula(aten._scaled_dot_product_flash_attention) +def sdpa_flash_time(gpu_type, dtype, gflops, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: + model = get_learned_model("sdpa", gpu_type) + b, h, s_q, s_kv, d_qk, d_v = check_sdpa_shapes(query_shape, key_shape, value_shape) + features = build_sdpa_features(b, h, s_q, s_kv, d_qk, d_v, gflops, dtype, "flash", is_causal=is_causal_sdpa(args)) + return float(model.predict(features)[0]) + +@register_timing_formula(aten._scaled_dot_product_cudnn_attention_backward) +def sdpa_backward_cudnn_time(gpu_type, dtype, gflops, grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: + model = get_learned_model("sdpa_backward", gpu_type) + b, h, s_q, s_kv, d_qk, d_v = check_sdpa_shapes_backward(query_shape, key_shape, value_shape, 
grad_out_shape) + features = build_sdpa_features(b, h, s_q, s_kv, d_qk, d_v, gflops, dtype, "cudnn", is_causal=is_causal_sdpa(args)) + return float(model.predict(features)[0]) + +@register_timing_formula(aten._scaled_dot_product_efficient_attention_backward) +def sdpa_backward_efficient_time(gpu_type, dtype, gflops, grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: + model = get_learned_model("sdpa_backward", gpu_type) + b, h, s_q, s_kv, d_qk, d_v = check_sdpa_shapes_backward(query_shape, key_shape, value_shape, grad_out_shape) + features = build_sdpa_features(b, h, s_q, s_kv, d_qk, d_v, gflops, dtype, "efficient", is_causal=is_causal_sdpa(args)) + return float(model.predict(features)[0]) + +@register_timing_formula(aten._scaled_dot_product_flash_attention_backward) +def sdpa_backward_flash_time(gpu_type, dtype, gflops, grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: + model = get_learned_model("sdpa_backward", gpu_type) + b, h, s_q, s_kv, d_qk, d_v = check_sdpa_shapes_backward(query_shape, key_shape, value_shape, grad_out_shape) + features = build_sdpa_features(b, h, s_q, s_kv, d_qk, d_v, gflops, dtype, "flash", is_causal=is_causal_sdpa(args)) + return float(model.predict(features)[0]) + +# @register_timing_formula([aten.convolution, aten._convolution]) +# def conv_time(gpu_type, dtype, gflops, x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> float: +# """ +# TODO: need to add support for higher dims. +# """ +# model = get_learned_model("conv", gpu_type) + +# # batch_size = x_shape[0] +# # conv_shape = (x_shape if transposed else out_shape)[2:] +# # c_out, c_in, *filter_size = w_shape + +# # features = np.array([[b, m, k, n, gflops]]) +# return model.predict(features) + +# @register_timing_formula(aten.convolution_backward) +# def conv_backward_time( +# gpu_type, +# dtype, +# gflops, +# grad_out_shape, +# x_shape, +# w_shape, +# _bias, +# _stride, +# _padding, +# _dilation, +# transposed, +# _output_padding, +# _groups, +# output_mask, +# out_shape) -> float: +# """ +# TODO: need to add support for higher dims. +# """ +# model = get_learned_model("conv_backward", gpu_type) + +# # features = np.array([[b, m, k, n, gflops]]) +# return model.predict(features) + class RuntimeEstimator(TorchDispatchMode): """ Estimates the GPU runtime in milliseconds using various estimation methods under the ``FakeTensorMode``. @@ -141,6 +398,9 @@ class RuntimeEstimator(TorchDispatchMode): _no_fallback_kernel: Set[torch._ops._OpNamespace] = set() fake_mode: FakeTensorMode + gpu_types: Dict[int, str] = {} + count = {} + def __init__(self) -> None: super().__init__() self._estimate: Callable @@ -155,7 +415,10 @@ def __init__(self) -> None: self.mod_bw_post_order: List[str] = [] self.total_runtime: float = 0.0 - self.gpu_type = self.get_device_type() + gpu_id = torch.cuda.current_device() # Get the current GPU ID + if gpu_id not in RuntimeEstimator.gpu_types: + RuntimeEstimator.gpu_types[gpu_id] = self.get_device_type() # Initialize gpu_type for the GPU + self.gpu_type = RuntimeEstimator.gpu_types[gpu_id] # Assign gpu_type based on the current GPU def get_device_type(self) -> int: try: @@ -470,256 +733,26 @@ def _learned_estimate_predictor(cls, func_packet, args, kwargs, out, out_dtypes) out_dtypes: The output data types. Returns: - float: The estimated compute time in nanoseconds. + float: The estimated compute time in milliseconds. # TODO: comments. 
Note: for the prediction functions, we mimic the arguments for mm_flop. """ - def get_learned_model(op: str) -> Any: - if op not in _LEARNED_OPS_PREDICTORS: - base_dir = os.path.join(os.getcwd()) - path = os.path.join(base_dir, f"{cls.gpu_type}_models", f"{op}.joblib") - - _LEARNED_OPS_PREDICTORS[op] = joblib.load(path) - return _LEARNED_OPS_PREDICTORS[op] - - - from functools import wraps - from torch.utils._pytree import tree_map - - def get_shape(i): - if isinstance(i, torch.Tensor): - return i.shape - return i - - def shape_wrapper(f): - """ - Similar to flop_counter.shape_wrapper(), but also takes takes gflops. - """ - @wraps(f) - def nf(dtype, gflops, *args, out_val=None, **kwargs): - args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out_val)) - return f(dtype, gflops, *args, out_shape=out_shape, **kwargs) - return nf - - def register_timing_formula(targets, get_raw=False): - """ - Similar to flop_counter.register_flop_formula(). - """ - def register_fun(flop_formula): - if not get_raw: - flop_formula = shape_wrapper(flop_formula) - - def register(target): - if not isinstance(target, torch._ops.OpOverloadPacket): - raise ValueError( - f"register_flop_formula(targets): expected each target to be " - f"OpOverloadPacket (i.e. torch.ops.mylib.foo), got " - f"{target} which is of type {type(target)}") - if target in flop_registry: - raise RuntimeError(f"duplicate registrations for {target}") - flop_registry[target] = flop_formula - - # To handle allowing multiple aten_ops at once - torch.utils._pytree.tree_map_(register, targets) - - return flop_formula - - return register_fun - - - def convert_dtype(dtype) -> list[int]: - """ - To use dtype in a learned model, we convert them to one-hot encodings. - - Learned model supports the dtypes: - - torch.float16 - - torch.float32 - - torch.bfloat16 - """ - dtypes = [torch.float16, torch.float32, torch.bfloat16] - return [1 if dtype == d else 0 for d in dtypes] - - @register_timing_formula([aten.mm, aten.addmm]) - def mm_time(dtype, gflops, a_shape, b_shape, *args, out_shape=None, **kwargs) -> float: - model = get_learned_model("mm") - - m, n = a_shape - n2, p = b_shape - assert n == n2 - - features = np.array([[m, n, p, gflops] + convert_dtype(dtype)]) - return model.predict(features) - - @register_timing_formula([aten.bmm, aten.baddmm]) - def bmm_time(dtype, gflops, a_shape, b_shape, out_shape=None, **kwargs) -> float: - model = get_learned_model("bmm") - - b, m, n = a_shape - b2, n2, p = b_shape - assert b == b2 and n == n2 - - features = np.array([[b, m, n, p, gflops] + convert_dtype(dtype)]) - return model.predict(features) - - def is_causal_sdpa(args: tuple) -> bool: - """ - TODO: the way that flop_counter implements sdpa args/kwargs—namely, `is_causal`—should be updated. - This is a heuristic hackaround. 
- """ - if len(args) >= 2: - if args[0] is not None: - return True - elif len(args) > 2: - if isinstance(args[-1], bool): - return args[-1] - return False - - @register_timing_formula(aten._scaled_dot_product_cudnn_attention) - def sdpa_cudnn_time(dtype, gflops, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: - model = get_learned_model("sdpa") - - b, h, s_q, d_qk = query_shape - _b2, _h2, s_kv, _d2 = key_shape - _b3, _h3, _s3, d_v = value_shape - assert b == _b2 == _b3 and h == _h2 == _h3 and d_qk == _d2 and s_kv == _s3 and d_qk == _d2 - - backends_ohe = [1, 0, 0] - is_causal_ohe = [0, 1] if is_causal_sdpa(args) else [1, 0] - features = np.array([[b, h, s_q, s_kv, d_qk, d_v, gflops] + convert_dtype(dtype) + backends_ohe + is_causal_ohe]) - return model.predict(features) - - @register_timing_formula(aten._scaled_dot_product_efficient_attention) - def sdpa_efficient_time(dtype, gflops, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: - model = get_learned_model("sdpa") - - b, h, s_q, d_qk = query_shape - _b2, _h2, s_kv, _d2 = key_shape - _b3, _h3, _s3, d_v = value_shape - assert b == _b2 == _b3 and h == _h2 == _h3 and d_qk == _d2 and s_kv == _s3 and d_qk == _d2 - - backends_ohe = [0, 1, 0] - is_causal_ohe = [0, 1] if is_causal_sdpa(args) else [1, 0] - features = np.array([[b, h, s_q, s_kv, d_qk, d_v, gflops] + convert_dtype(dtype) + backends_ohe + is_causal_ohe]) - return model.predict(features) - - @register_timing_formula(aten._scaled_dot_product_flash_attention) - def sdpa_flash_time(dtype, gflops, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: - model = get_learned_model("sdpa") - - b, h, s_q, d_qk = query_shape - _b2, _h2, s_kv, _d2 = key_shape - _b3, _h3, _s3, d_v = value_shape - assert b == _b2 == _b3 and h == _h2 == _h3 and d_qk == _d2 and s_kv == _s3 and d_qk == _d2 - - backends_ohe = [0, 0, 1] - is_causal_ohe = [0, 1] if is_causal_sdpa(args) else [1, 0] - features = np.array([[b, h, s_q, s_kv, d_qk, d_v, gflops] + convert_dtype(dtype) + backends_ohe + is_causal_ohe]) - return model.predict(features) - - @register_timing_formula(aten._scaled_dot_product_cudnn_attention_backward) - def sdpa_backward_cudnn_time(dtype, gflops, grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: - model = get_learned_model("sdpa_backward") - - b, h, s_q, d_qk = query_shape - _b2, _h2, s_kv, _d2 = key_shape - _b3, _h3, _s3, d_v = value_shape - _b4, _h4, _s4, _d4 = grad_out_shape - assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_qk == _d2 - assert d_v == _d4 and s_kv == _s3 and s_q == _s4 - - backends_ohe = [1, 0, 0] - is_causal_ohe = [0, 1] if is_causal_sdpa(args) else [1, 0] - features = np.array([[b, h, s_q, s_kv, d_qk, d_v, gflops] + convert_dtype(dtype) + backends_ohe + is_causal_ohe]) - return model.predict(features) - - @register_timing_formula(aten._scaled_dot_product_efficient_attention_backward) - def sdpa_backward_efficient_time(dtype, gflops, grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: - model = get_learned_model("sdpa_backward") - - b, h, s_q, d_qk = query_shape - _b2, _h2, s_kv, _d2 = key_shape - _b3, _h3, _s3, d_v = value_shape - _b4, _h4, _s4, _d4 = grad_out_shape - assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_qk == _d2 - assert d_v == _d4 and s_kv == _s3 and s_q == _s4 - - backends_ohe = [0, 1, 0] - is_causal_ohe = [0, 1] if is_causal_sdpa(args) else [1, 0] - 
features = np.array([[b, h, s_q, s_kv, d_qk, d_v, gflops] + convert_dtype(dtype) + backends_ohe + is_causal_ohe]) - return model.predict(features) - - @register_timing_formula(aten._scaled_dot_product_flash_attention_backward) - def sdpa_backward_flash_time(dtype, gflops, grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: - model = get_learned_model("sdpa_backward") - - b, h, s_q, d_qk = query_shape - _b2, _h2, s_kv, _d2 = key_shape - _b3, _h3, _s3, d_v = value_shape - _b4, _h4, _s4, _d4 = grad_out_shape - assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_qk == _d2 - assert d_v == _d4 and s_kv == _s3 and s_q == _s4 - - backends_ohe = [0, 0, 1] - is_causal_ohe = [0, 1] if is_causal_sdpa(args) else [1, 0] - features = np.array([[b, h, s_q, s_kv, d_qk, d_v, gflops] + convert_dtype(dtype) + backends_ohe + is_causal_ohe]) - return model.predict(features) - - # @register_timing_formula([aten.convolution, aten._convolution]) - # def conv_time(dtype, gflops, x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> float: - # """ - # TODO: need to add support for higher dims. - # """ - # model = get_learned_model("conv") - - # # batch_size = x_shape[0] - # # conv_shape = (x_shape if transposed else out_shape)[2:] - # # c_out, c_in, *filter_size = w_shape - - # # features = np.array([[b, m, k, n, gflops]]) - # return model.predict(features) - - # @register_timing_formula(aten.convolution_backward) - # def conv_backward_time( - # dtype, - # gflops, - # grad_out_shape, - # x_shape, - # w_shape, - # _bias, - # _stride, - # _padding, - # _dilation, - # transposed, - # _output_padding, - # _groups, - # output_mask, - # out_shape) -> float: - # """ - # TODO: need to add support for higher dims. 
- # """ - # model = get_learned_model("conv_backward") - - # # features = np.array([[b, m, k, n, gflops]]) - # return model.predict(features) - + op_time = 0.0 if func_packet in _LEARNED_OPS: assert ( len(out_dtypes) == 1 ), f"Only support single out dtype got {out_dtypes} for {func_packet}" dtype = out_dtypes.pop() - + flop_count_func = flop_registry[func_packet] gflops = flop_count_func(*args, **kwargs, out_val=out) / 1e9 - predictor_func = _LEARNED_OPS[func_packet] - # Returns compute time in ms, so multiply by 1e6 to get nanoseconds - compute_time = predictor_func(dtype, gflops, *args, **kwargs, out_val=out) - compute_time *= 1e6 - else: - compute_time = cls._get_compute_time(func_packet, args, kwargs, out, out_dtypes) - return 0.0 + gpu_id = torch.cuda.current_device() + op_time = predictor_func(cls.gpu_types[gpu_id], dtype, gflops, *args, **kwargs, out_val=out) + cls.count[func_packet] = cls.count.get(func_packet, 0) + 1 + return op_time @classmethod def _learned_estimate(cls, func, args, kwargs) -> Tuple[Any, float]: # type: ignore[no-untyped-def] @@ -747,7 +780,6 @@ def _learned_estimate(cls, func, args, kwargs) -> Tuple[Any, float]: # type: ig if func_packet not in _IGNORE_OPS: flat_args_kwargs, args_spec = pytree.tree_flatten((args, kwargs)) flat_outs, out_spec = pytree.tree_flatten(out) - transfer_time = cls._get_transfer_time(flat_args_kwargs, flat_outs) out_dtypes = { t.dtype @@ -757,11 +789,17 @@ def _learned_estimate(cls, func, args, kwargs) -> Tuple[Any, float]: # type: ig args, kwargs = pytree.tree_unflatten(flat_args_kwargs, args_spec) out = pytree.tree_unflatten(flat_outs, out_spec) - - compute_time = cls._learned_estimate_predictor(func_packet, args, kwargs, out, out_dtypes) - # We get the estimated time as the max of the transfer time and - # compute time. We divide by 1e6 to get the time in ms - op_time = max(transfer_time, compute_time) / 1e6 + + if func_packet in _LEARNED_OPS: + op_time = cls._learned_estimate_predictor(func_packet, args, kwargs, out, out_dtypes) + else: + # Roofline estimate. + transfer_time = cls._get_transfer_time(flat_args_kwargs, flat_outs) + compute_time = cls._get_compute_time(func_packet, args, kwargs, out, out_dtypes) + + # We get the estimated time as the max of the transfer time and + # compute time. We divide by 1e6 to get the time in ms + op_time = max(transfer_time, compute_time) / 1e6 return (out, op_time) @@ -828,6 +866,8 @@ def __call__(self, estimate_mode_type: str) -> Self: self._estimate = RuntimeEstimator._benchmark_estimate elif estimate_mode_type == "operator-level-cost-model": self._estimate = RuntimeEstimator._roofline_estimate + elif estimate_mode_type == "operator-level-learned-model": + self._estimate = RuntimeEstimator._learned_estimate else: raise NotImplementedError( f"estimate_mode_type {estimate_mode_type} not supported" @@ -870,6 +910,7 @@ def __exit__(self, *args: Any) -> None: f"Estimated ({self._estimate_mode_type})" f"total_time: {self.total_runtime:.3f} ms" ) + print("count", self.count) if len(self._no_fallback_kernel) > 0: print("no_fallback_kernel: ", list(self._no_fallback_kernel)) super().__exit__(*args)