diff --git a/.gitattributes b/.gitattributes index e904301752950..3a9521336f1e2 100644 --- a/.gitattributes +++ b/.gitattributes @@ -5,3 +5,11 @@ .github/scripts/gql_mocks.json linguist-generated=true third_party/LICENSES_BUNDLED.txt linguist-generated=true tools/build/bazel/requirements.txt linguist-generated=true +torch/distributed/_tools/a100_models/bmm.joblib filter=lfs diff=lfs merge=lfs -text +torch/distributed/_tools/a100_models/mm.joblib filter=lfs diff=lfs merge=lfs -text +torch/distributed/_tools/a100_models/sdpa.joblib filter=lfs diff=lfs merge=lfs -text +torch/distributed/_tools/a100_models/sdpa_backward.joblib filter=lfs diff=lfs merge=lfs -text +torch/distributed/_tools/h100_models/sdpa.joblib filter=lfs diff=lfs merge=lfs -text +torch/distributed/_tools/h100_models/sdpa_backward.joblib filter=lfs diff=lfs merge=lfs -text +torch/distributed/_tools/h100_models/bmm.joblib filter=lfs diff=lfs merge=lfs -text +torch/distributed/_tools/h100_models/mm.joblib filter=lfs diff=lfs merge=lfs -text diff --git a/test/distributed/_tools/test_runtime_estimator.py b/test/distributed/_tools/test_runtime_estimator.py index 741ba7b2e8a03..0d4e4782675cf 100644 --- a/test/distributed/_tools/test_runtime_estimator.py +++ b/test/distributed/_tools/test_runtime_estimator.py @@ -133,6 +133,7 @@ def _init_model_and_args( def test_transformer_runtime( self, ): + print("Transformer Test") """Runs a basic GPT-2 model""" vocab_size = 8192 bsz, seq_len = 8, 1024 @@ -155,22 +156,30 @@ def test_transformer_runtime( roofline_estimate = self._runtime_estimate( "operator-level-cost-model", self._train_step, fake_args ) + learned_estimate = self._runtime_estimate( + "operator-level-learned-model", self._train_step, fake_args + ) benchmark_accuracy = actual_runtime / benchmark_estimate roofline_accuracy = actual_runtime / roofline_estimate + learned_accuracy = actual_runtime / learned_estimate print( f"Actual: {actual_runtime} Benchmark Estimate: {benchmark_estimate} Accuracy: {benchmark_accuracy}" - f"\n Actual: {actual_runtime} Roofline Estimatee: {roofline_estimate} Accuracy: {roofline_accuracy}" + f"\nActual: {actual_runtime} Roofline Estimate: {roofline_estimate} Accuracy: {roofline_accuracy}" + f"\nActual: {actual_runtime} Learned Estimate: {learned_estimate} Accuracy: {learned_accuracy}" ) + # No accuracy check for benchmark in CI as it is highly variable # self.assertAlmostEqual(benchmark_accuracy, 1.0, delta=0.2) # self.assertAlmostEqual(roofline_accuracy, 1.0, delta=0.3) + @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653") @unittest.skipIf(not TEST_CUDA, "CUDA not available") def test_conv_model_runtime( self, ): """Runs a simple CNN model""" + print("CNN Test") num_classes = 100 bsz, img_sz = 256, 128 model_args = ConvArgs(img_sz, num_classes) @@ -184,11 +193,16 @@ def test_conv_model_runtime( roofline_estimate = self._runtime_estimate( "operator-level-cost-model", self._train_step, fake_args ) + learned_estimate = self._runtime_estimate( + "operator-level-learned-model", self._train_step, fake_args + ) benchmark_accuracy = actual_runtime / benchmark_estimate roofline_accuracy = actual_runtime / roofline_estimate + learned_accuracy = actual_runtime / learned_estimate print( - f"Actual: {actual_runtime} Benchmark Estimate: {benchmark_estimate} Accuracy: {benchmark_accuracy}\n" - f"Actual: {actual_runtime} Roofline Estimatee: {roofline_estimate} Accuracy: {roofline_accuracy}" + f"Actual: {actual_runtime} Benchmark Estimate: {benchmark_estimate} Accuracy: {benchmark_accuracy}" + 
f"\nActual: {actual_runtime} Roofline Estimate: {roofline_estimate} Accuracy: {roofline_accuracy}" + f"\nActual: {actual_runtime} Learned Estimate: {learned_estimate} Accuracy: {learned_accuracy}" ) # No accuracy check for benchmark in CI as it is highly variable # self.assertAlmostEqual(benchmark_accuracy, 1.0, delta=0.2) diff --git a/torch/distributed/_tools/a100_models/bmm.joblib b/torch/distributed/_tools/a100_models/bmm.joblib new file mode 100644 index 0000000000000..4c0d2073d622c --- /dev/null +++ b/torch/distributed/_tools/a100_models/bmm.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fe40334437e5ad1a4cd2d496c6eb157481b114418408d738e6a7b9a4c99ab530 +size 78938299 diff --git a/torch/distributed/_tools/a100_models/mm.joblib b/torch/distributed/_tools/a100_models/mm.joblib new file mode 100644 index 0000000000000..b0acfd47b1b04 --- /dev/null +++ b/torch/distributed/_tools/a100_models/mm.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:00c789a3e8a99dd76c59619dc245b367fecf5c59250c6aa919c8b12cde413338 +size 80274914 diff --git a/torch/distributed/_tools/a100_models/sdpa.joblib b/torch/distributed/_tools/a100_models/sdpa.joblib new file mode 100644 index 0000000000000..8bfc8fd5c9069 --- /dev/null +++ b/torch/distributed/_tools/a100_models/sdpa.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c58e2fb8637f68562ba0ad16e61fd9dd499b2ab56617b072c14957c86333e0f +size 146079412 diff --git a/torch/distributed/_tools/a100_models/sdpa_backward.joblib b/torch/distributed/_tools/a100_models/sdpa_backward.joblib new file mode 100644 index 0000000000000..5814fda0fa6ae --- /dev/null +++ b/torch/distributed/_tools/a100_models/sdpa_backward.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:35a6edecc6eb2297a34ee1661f7c796c3f6e68a02cd91cccdfb30c3e6ad6e26c +size 165674354 diff --git a/torch/distributed/_tools/h100_models/bmm.joblib b/torch/distributed/_tools/h100_models/bmm.joblib new file mode 100644 index 0000000000000..81d064a32d971 --- /dev/null +++ b/torch/distributed/_tools/h100_models/bmm.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:516ed420b674477cda3b3821b30dbe64c3251199712e418864a0307c8ef6d10f +size 83504981 diff --git a/torch/distributed/_tools/h100_models/mm.joblib b/torch/distributed/_tools/h100_models/mm.joblib new file mode 100644 index 0000000000000..ec87f7de15bf6 --- /dev/null +++ b/torch/distributed/_tools/h100_models/mm.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:130a7338e1e88fab9465f66ba4693e24290061908108cd37e8be1dd7faba7339 +size 97798975 diff --git a/torch/distributed/_tools/h100_models/sdpa.joblib b/torch/distributed/_tools/h100_models/sdpa.joblib new file mode 100644 index 0000000000000..9e0be5b6f9277 --- /dev/null +++ b/torch/distributed/_tools/h100_models/sdpa.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f9cc36c6e6734a28860fcb6ec18ed2901bef9740d9cfd5385df8eea80d88b53b +size 175298513 diff --git a/torch/distributed/_tools/h100_models/sdpa_backward.joblib b/torch/distributed/_tools/h100_models/sdpa_backward.joblib new file mode 100644 index 0000000000000..4ac20630f1d98 --- /dev/null +++ b/torch/distributed/_tools/h100_models/sdpa_backward.joblib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0d60cce92bdc450be21d3fe3980dd93012d218754ce0b505784b45b0a2837558 +size 185574443 diff --git a/torch/distributed/_tools/runtime_estimator.py 
b/torch/distributed/_tools/runtime_estimator.py
index 87f4d3f36b60e..3bd434864ba74 100644
--- a/torch/distributed/_tools/runtime_estimator.py
+++ b/torch/distributed/_tools/runtime_estimator.py
@@ -1,6 +1,11 @@
 # Owner(s): ["module: unknown"]
 import math
 import os
+import joblib
+import subprocess
+import numpy as np
+import pandas as pd
+import time
 from collections import defaultdict
 from typing import Any, Callable, Dict, List, Set, Tuple
 from typing_extensions import Self
@@ -73,9 +78,272 @@
 _IGNORE_OPS = _VIEW_OPS | _CREATE_OPS
 
+# Similar to `flop_registry`, stores the functions that make predictions
+_LEARNED_OPS: Dict[Any, Any] = {}
+
+# Caches the learned models that predict ops' runtimes.
+_LEARNED_OPS_PREDICTORS: Dict[str, Any] = {}
+
+
 __all__ = ["RuntimeEstimator"]
 
 
+def get_learned_model(op: str, gpu_type: str) -> Any:
+    if op not in _LEARNED_OPS_PREDICTORS:
+        base_dir = os.path.dirname(os.path.realpath(__file__))
+        path = os.path.join(base_dir, f"{gpu_type}_models", f"{op}.joblib")
+
+        _LEARNED_OPS_PREDICTORS[op] = joblib.load(path)
+    return _LEARNED_OPS_PREDICTORS[op]
+
+from functools import wraps
+from torch.utils._pytree import tree_map
+
+def get_shape(i):
+    if isinstance(i, torch.Tensor):
+        return i.shape
+    return i
+
+def shape_wrapper(f):
+    """
+    Similar to flop_counter.shape_wrapper(), but also takes gflops.
+    """
+    @wraps(f)
+    def nf(gpu_type, dtype, gflops, *args, out_val=None, **kwargs):
+        args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out_val))
+        return f(gpu_type, dtype, gflops, *args, out_shape=out_shape, **kwargs)
+    return nf
+
+def register_timing_formula(targets, get_raw=False):
+    """
+    Similar to flop_counter.register_flop_formula().
+    """
+    def register_fun(time_formula):
+        if not get_raw:
+            time_formula = shape_wrapper(time_formula)
+
+        def register(target):
+            if not isinstance(target, torch._ops.OpOverloadPacket):
+                raise ValueError(
+                    f"register_timing_formula(targets): expected each target to be "
+                    f"OpOverloadPacket (i.e. torch.ops.mylib.foo), got "
+                    f"{target} which is of type {type(target)}")
+            if target in _LEARNED_OPS:
+                raise RuntimeError(f"duplicate registrations for {target}")
+            _LEARNED_OPS[target] = time_formula
+
+        # To handle allowing multiple aten_ops at once
+        torch.utils._pytree.tree_map_(register, targets)
+
+        return time_formula
+
+    return register_fun
+
+def convert_dtype(dtype) -> Dict[str, int]:
+    """
+    Convert dtype to a one-hot encoding as a dict.
+ + Learned model supports the dtypes: + - torch.float16 + - torch.float32 + - torch.bfloat16 + """ + dtypes = [torch.float16, torch.float32, torch.bfloat16] + dtype_one_hot = [1 if dtype == d else 0 for d in dtypes] + dtype_names = ["dtype_16", "dtype_32", "dtype_b16"] + return dict(zip(dtype_names, dtype_one_hot)) + +@register_timing_formula(aten.mm) +def mm_time(gpu_type, dtype, gflops, a_shape, b_shape, *args, out_shape=None, **kwargs) -> float: + model = get_learned_model("mm", gpu_type) + + m, n = a_shape + n2, p = b_shape + assert n == n2 + + dtypes = convert_dtype(dtype) + features = { + "n": n, + "m": m, + "p": p, + "gflops": gflops, + "dtype_16": dtypes["dtype_16"], + "dtype_32": dtypes["dtype_32"], + "dtype_b16": dtypes["dtype_b16"], + } + features_df = pd.DataFrame([features]) + return float(model.predict(features_df)[0]) + +@register_timing_formula(aten.addmm) +def addmm_time(gpu_type, dtype, gflops, self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> float: + return mm_time(gpu_type, dtype, gflops, a_shape, b_shape) + +@register_timing_formula(aten.bmm) +def bmm_time(gpu_type, dtype, gflops, a_shape, b_shape, out_shape=None, **kwargs) -> float: + model = get_learned_model("bmm", gpu_type) + + b, m, n = a_shape + b2, n2, p = b_shape + assert b == b2 and n == n2 + + dtypes = convert_dtype(dtype) + features = { + "b": b, + "n": n, + "m": m, + "p": p, + "gflops": gflops, + "dtype_16": dtypes["dtype_16"], + "dtype_32": dtypes["dtype_32"], + "dtype_b16": dtypes["dtype_b16"], + } + features_df = pd.DataFrame([features]) + return float(model.predict(features_df)[0]) + +@register_timing_formula(aten.baddbmm) +def baddbmm_time(gpu_type, dtype, gflops, self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> float: + return bmm_time(gpu_type, dtype, gflops, a_shape, b_shape) + +def is_causal_sdpa(args: tuple) -> bool: + """ + TODO: the way that flop_counter implements sdpa args/kwargs—namely, `is_causal`—should be updated. + This is a heuristic hackaround. 
+ """ + if len(args) >= 2 and args[0] is not None: + return True + if len(args) > 2 and isinstance(args[-1], bool): + return args[-1] + return False + +def build_sdpa_features(b, h, s_q, s_kv, d_qk, d_v, gflops, dtype, backend, is_causal: bool) -> pd.DataFrame: + if backend == "cudnn": + backends_ohe = [1, 0, 0] + elif backend == "efficient": + backends_ohe = [0, 1, 0] + elif backend == "flash": + backends_ohe = [0, 0, 1] + else: + raise ValueError(f"Unknown backend: {backend}") + + dtypes = convert_dtype(dtype) + is_causal_ohe = [0, 1] if is_causal else [1, 0] + + features = { + "b": b, + "h": h, + "s_q": s_q, + "s_kv": s_kv, + "d_qk": d_qk, + "d_v": d_v, + "gflops": gflops, + "dtype_16": dtypes["dtype_16"], + "dtype_32": dtypes["dtype_32"], + "dtype_b16": dtypes["dtype_b16"], + "backend_cudnn": backends_ohe[0], + "backend_efficient": backends_ohe[1], + "backend_flash": backends_ohe[2], + "is_causal_0": is_causal_ohe[0], + "is_causal_1": is_causal_ohe[1] + } + return pd.DataFrame([features]) + + +def check_sdpa_shapes(query_shape, key_shape, value_shape): + b, h, s_q, d_qk = query_shape + _b2, _h2, s_kv, _d2 = key_shape + _b3, _h3, _s3, d_v = value_shape + assert b == _b2 == _b3 and h == _h2 == _h3 and d_qk == _d2 and s_kv == _s3 and d_qk == _d2 + return b, h, s_q, s_kv, d_qk, d_v + +def check_sdpa_shapes_backward(query_shape, key_shape, value_shape, grad_out_shape): + b, h, s_q, d_qk = query_shape + _b2, _h2, s_kv, _d2 = key_shape + _b3, _h3, _s3, d_v = value_shape + _b4, _h4, _s4, _d4 = grad_out_shape + assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_qk == _d2 and d_v == _d4 and s_kv == _s3 and s_q == _s4 + return b, h, s_q, s_kv, d_qk, d_v + +@register_timing_formula(aten._scaled_dot_product_cudnn_attention) +def sdpa_cudnn_time(gpu_type, dtype, gflops, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: + model = get_learned_model("sdpa", gpu_type) + b, h, s_q, s_kv, d_qk, d_v = check_sdpa_shapes(query_shape, key_shape, value_shape) + features = build_sdpa_features(b, h, s_q, s_kv, d_qk, d_v, gflops, dtype, "cudnn", is_causal=is_causal_sdpa(args)) + return float(model.predict(features)[0]) + +@register_timing_formula(aten._scaled_dot_product_efficient_attention) +def sdpa_efficient_time(gpu_type, dtype, gflops, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: + model = get_learned_model("sdpa", gpu_type) + b, h, s_q, s_kv, d_qk, d_v = check_sdpa_shapes(query_shape, key_shape, value_shape) + features = build_sdpa_features(b, h, s_q, s_kv, d_qk, d_v, gflops, dtype, "efficient", is_causal=is_causal_sdpa(args)) + return float(model.predict(features)[0]) + +@register_timing_formula(aten._scaled_dot_product_flash_attention) +def sdpa_flash_time(gpu_type, dtype, gflops, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: + model = get_learned_model("sdpa", gpu_type) + b, h, s_q, s_kv, d_qk, d_v = check_sdpa_shapes(query_shape, key_shape, value_shape) + features = build_sdpa_features(b, h, s_q, s_kv, d_qk, d_v, gflops, dtype, "flash", is_causal=is_causal_sdpa(args)) + return float(model.predict(features)[0]) + +@register_timing_formula(aten._scaled_dot_product_cudnn_attention_backward) +def sdpa_backward_cudnn_time(gpu_type, dtype, gflops, grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: + model = get_learned_model("sdpa_backward", gpu_type) + b, h, s_q, s_kv, d_qk, d_v = check_sdpa_shapes_backward(query_shape, key_shape, value_shape, 
grad_out_shape) + features = build_sdpa_features(b, h, s_q, s_kv, d_qk, d_v, gflops, dtype, "cudnn", is_causal=is_causal_sdpa(args)) + return float(model.predict(features)[0]) + +@register_timing_formula(aten._scaled_dot_product_efficient_attention_backward) +def sdpa_backward_efficient_time(gpu_type, dtype, gflops, grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: + model = get_learned_model("sdpa_backward", gpu_type) + b, h, s_q, s_kv, d_qk, d_v = check_sdpa_shapes_backward(query_shape, key_shape, value_shape, grad_out_shape) + features = build_sdpa_features(b, h, s_q, s_kv, d_qk, d_v, gflops, dtype, "efficient", is_causal=is_causal_sdpa(args)) + return float(model.predict(features)[0]) + +@register_timing_formula(aten._scaled_dot_product_flash_attention_backward) +def sdpa_backward_flash_time(gpu_type, dtype, gflops, grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> float: + model = get_learned_model("sdpa_backward", gpu_type) + b, h, s_q, s_kv, d_qk, d_v = check_sdpa_shapes_backward(query_shape, key_shape, value_shape, grad_out_shape) + features = build_sdpa_features(b, h, s_q, s_kv, d_qk, d_v, gflops, dtype, "flash", is_causal=is_causal_sdpa(args)) + return float(model.predict(features)[0]) + +# @register_timing_formula([aten.convolution, aten._convolution]) +# def conv_time(gpu_type, dtype, gflops, x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> float: +# """ +# TODO: need to add support for higher dims. +# """ +# model = get_learned_model("conv", gpu_type) + +# # batch_size = x_shape[0] +# # conv_shape = (x_shape if transposed else out_shape)[2:] +# # c_out, c_in, *filter_size = w_shape + +# # features = np.array([[b, m, k, n, gflops]]) +# return model.predict(features) + +# @register_timing_formula(aten.convolution_backward) +# def conv_backward_time( +# gpu_type, +# dtype, +# gflops, +# grad_out_shape, +# x_shape, +# w_shape, +# _bias, +# _stride, +# _padding, +# _dilation, +# transposed, +# _output_padding, +# _groups, +# output_mask, +# out_shape) -> float: +# """ +# TODO: need to add support for higher dims. +# """ +# model = get_learned_model("conv_backward", gpu_type) + +# # features = np.array([[b, m, k, n, gflops]]) +# return model.predict(features) + class RuntimeEstimator(TorchDispatchMode): """ Estimates the GPU runtime in milliseconds using various estimation methods under the ``FakeTensorMode``. 
@@ -130,6 +398,9 @@ class RuntimeEstimator(TorchDispatchMode):
     _no_fallback_kernel: Set[torch._ops._OpNamespace] = set()
     fake_mode: FakeTensorMode
 
+    gpu_types: Dict[int, str] = {}
+    count: Dict[Any, int] = {}
+
     def __init__(self) -> None:
         super().__init__()
         self._estimate: Callable
@@ -144,6 +415,25 @@ def __init__(self) -> None:
         self.mod_bw_post_order: List[str] = []
         self.total_runtime: float = 0.0
 
+        gpu_id = torch.cuda.current_device()  # Get the current GPU ID
+        if gpu_id not in RuntimeEstimator.gpu_types:
+            RuntimeEstimator.gpu_types[gpu_id] = self.get_device_type()  # Initialize gpu_type for the GPU
+        self.gpu_type = RuntimeEstimator.gpu_types[gpu_id]  # Assign gpu_type based on the current GPU
+
+    def get_device_type(self) -> str:
+        try:
+            result = subprocess.check_output(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'])
+            gpu_name = result.decode('utf-8').strip()
+
+            if "A100" in gpu_name:
+                return "a100"
+            elif "H100" in gpu_name:
+                return "h100"
+            else:
+                raise ValueError("GPU type not supported")
+        except subprocess.CalledProcessError as e:
+            raise ValueError("Error retrieving GPU name") from e
+
     # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_subclasses/fake_tensor.py#L1969 # noqa: PGH004,B950
     # NB: returns fake tensors
     @classmethod
@@ -275,7 +565,7 @@ def _benchmark_estimate(cls, func, args, kwargs) -> Tuple[Any, float]:  # type:
 
     # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_inductor/scheduler.py#L589 # noqa: PGH004,B950
     @classmethod
-    def _roofline_estimate(cls, func, args, kwargs) -> Tuple[Any, float]:  # type: ignore[no-untyped-def]
+    def _get_transfer_time(cls, flat_args_kwargs, flat_outs) -> float:  # type: ignore[no-untyped-def]
         """
         Estimates the runtime of a function using a roofline cost model.
 
@@ -309,62 +599,72 @@ def get_num_bytes(t: torch.Tensor) -> int:
             )
             return mem_consumed
 
-        def get_compute_time(func_packet, args, kwargs, out, out_dtypes) -> float:  # type: ignore[no-untyped-def]
-            """
-            Estimates the compute time of an aten operator.
+        gpu_memory_bandwidth = get_gpu_dram_gbps()
+        read_bytes = sum(
+            get_num_bytes(t)
+            for t in flat_args_kwargs
+            if isinstance(t, torch.Tensor)
+        )
+        write_bytes = sum(
+            get_num_bytes(t) for t in flat_outs if isinstance(t, torch.Tensor)
+        )
+        counted_bytes = read_bytes + write_bytes
+        # The GPU memory bandwidth is in GB/s so the transfer time is in nanoseconds
+        transfer_time = counted_bytes / gpu_memory_bandwidth
+        return transfer_time
 
-            Args:
-                func_packet: The operator overload packet.
-                args: The arguments to the operator.
-                kwargs: The keyword arguments to the operator.
-                out: The output of the operator.
-                out_dtypes: The output data types.
+    @classmethod
+    def _get_compute_time(cls, func_packet, args, kwargs, out, out_dtypes) -> float:  # type: ignore[no-untyped-def]
+        """
+        Estimates the compute time of an aten operator.
 
-            Returns:
-                float: The estimated compute time in nanoseconds.
- """ - if func_packet in flop_registry: - assert ( - len(out_dtypes) == 1 - ), f"Only support single out dtype got {out_dtypes} for {func_packet}" - dtype = out_dtypes.pop() - # This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s - peak_gpu_flops = get_device_tflops(dtype) * 1e15 - # We can expect to achieve 75% of theoretical peak flops - factor = 0.75 - peak_empirical_flops = factor * peak_gpu_flops - flop_count_func = flop_registry[func_packet] - # We divide by a factor of 2 to get the MACs (multiply and accumulate) - flop_count = flop_count_func(*args, **kwargs, out_val=out) / 2 - # We multiply by 1e9 to get the time in nano seconds - compute_time = (flop_count / peak_empirical_flops) * 1e9 - return compute_time - return 0.0 - - def get_transfer_time(flat_args_kwargs, flat_outs) -> float: # type: ignore[no-untyped-def] - """ - Estimates the memory transfer time of input and output tensors. + Args: + func_packet: The operator overload packet. + args: The arguments to the operator. + kwargs: The keyword arguments to the operator. + out: The output of the operator. + out_dtypes: The output data types. - Args: - flat_args_kwargs (List[torch.Tensor]): The flat list of arguments and keyword arguments. - flat_outs (List[torch.Tensor]): The flat list of outputs. + Returns: + float: The estimated compute time in nanoseconds. + """ + if func_packet in flop_registry: + assert ( + len(out_dtypes) == 1 + ), f"Only support single out dtype got {out_dtypes} for {func_packet}" + dtype = out_dtypes.pop() + # This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s + peak_gpu_flops = get_device_tflops(dtype) * 1e15 + # We can expect to achieve 75% of theoretical peak flops + factor = 0.75 + peak_empirical_flops = factor * peak_gpu_flops + flop_count_func = flop_registry[func_packet] + # We divide by a factor of 2 to get the MACs (multiply and accumulate) + flop_count = flop_count_func(*args, **kwargs, out_val=out) / 2 + # We multiply by 1e9 to get the time in nano seconds + compute_time = (flop_count / peak_empirical_flops) * 1e9 + return compute_time + return 0.0 - Returns: - float: The estimated memory transfer time in nanoseconds. - """ - gpu_memory_bandwidth = get_gpu_dram_gbps() - read_bytes = sum( - get_num_bytes(t) - for t in flat_args_kwargs - if isinstance(t, torch.Tensor) - ) - write_bytes = sum( - get_num_bytes(t) for t in flat_outs if isinstance(t, torch.Tensor) - ) - counted_bytes = read_bytes + write_bytes - # The GPU memory bandwidth is in GB/s so the transfer time is in nanoseconds - transfer_time = counted_bytes / gpu_memory_bandwidth - return transfer_time + # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_inductor/scheduler.py#L589 # noqa: PGH004,B950 + @classmethod + def _roofline_estimate(cls, func, args, kwargs) -> Tuple[Any, float]: # type: ignore[no-untyped-def] + """ + Estimates the runtime of a function using a roofline cost model. + + Args: + func: The function to estimate. + args: The arguments to pass to the function. + kwargs: The keyword arguments to pass to the function. + out: The output of the function. + + Returns: + Tuple[Any, float]: A tuple containing the result of the function and + the mean operation time in milliseconds. 
+        """
+        assert (
+            torch.cuda.is_available()
+        ), "Roofline estimation needs to access CUDA capabilities to make estimations"
 
         # Roofline Cost Model Explanation
@@ -397,7 +697,7 @@ def get_transfer_time(flat_args_kwargs, flat_outs) -> float:  # type: ignore[no-
         if func_packet not in _IGNORE_OPS:
             flat_args_kwargs, args_spec = pytree.tree_flatten((args, kwargs))
             flat_outs, out_spec = pytree.tree_flatten(out)
-            transfer_time = get_transfer_time(flat_args_kwargs, flat_outs)
+            transfer_time = cls._get_transfer_time(flat_args_kwargs, flat_outs)
 
             out_dtypes = {
                 t.dtype
@@ -408,12 +708,100 @@ def get_transfer_time(flat_args_kwargs, flat_outs) -> float:  # type: ignore[no-
             args, kwargs = pytree.tree_unflatten(flat_args_kwargs, args_spec)
             out = pytree.tree_unflatten(flat_outs, out_spec)
 
-            compute_time = get_compute_time(func_packet, args, kwargs, out, out_dtypes)
+            compute_time = cls._get_compute_time(func_packet, args, kwargs, out, out_dtypes)
             # We get the estimated time as the max of the transfer time and
             # compute time. We divide by 1e6 to get the time in ms
             op_time = max(transfer_time, compute_time) / 1e6
         return (out, op_time)
+
+    @classmethod
+    def _learned_estimate_predictor(cls, func_packet, args, kwargs, out, out_dtypes) -> float:  # type: ignore[no-untyped-def]
+        """
+        Estimates the compute time of an aten operator using a learned model.
+
+        The prediction functions mimic the argument convention of the flop
+        formulas (e.g. mm_flop).
+
+        TODO: finalize the order of the features and where the models are stored.
+
+        Args:
+            func_packet: The operator overload packet.
+            args: The arguments to the operator.
+            kwargs: The keyword arguments to the operator.
+            out: The output of the operator.
+            out_dtypes: The output data types.
+
+        Returns:
+            float: The estimated compute time in milliseconds.
+        """
+        op_time = 0.0
+        if func_packet in _LEARNED_OPS:
+            assert (
+                len(out_dtypes) == 1
+            ), f"Only support single out dtype got {out_dtypes} for {func_packet}"
+            dtype = out_dtypes.pop()
+
+            flop_count_func = flop_registry[func_packet]
+            gflops = flop_count_func(*args, **kwargs, out_val=out) / 1e9
+            predictor_func = _LEARNED_OPS[func_packet]
+            gpu_id = torch.cuda.current_device()
+            op_time = predictor_func(cls.gpu_types[gpu_id], dtype, gflops, *args, **kwargs, out_val=out)
+            cls.count[func_packet] = cls.count.get(func_packet, 0) + 1
+        return op_time
+
+    @classmethod
+    def _learned_estimate(cls, func, args, kwargs) -> Tuple[Any, float]:  # type: ignore[no-untyped-def]
+        """
+        Estimates the runtime of a function using a learned estimator.
+
+        Args:
+            func: The function to estimate.
+            args: The arguments to pass to the function.
+            kwargs: The keyword arguments to pass to the function.
+            res: The result of the function.
+
+        Returns:
+            Tuple[Any, float]: A tuple containing the result of the function and
+                the mean operation time in milliseconds.
+ """ + assert ( + torch.cuda.is_available() + ), "Learned estimator needs to access CUDA capabilities to make estimations" + + kwargs = kwargs if kwargs else {} + out = func(*args, **kwargs) + op_time = 0.0 + func_packet = func._overloadpacket + if func_packet not in _IGNORE_OPS: + flat_args_kwargs, args_spec = pytree.tree_flatten((args, kwargs)) + flat_outs, out_spec = pytree.tree_flatten(out) + + out_dtypes = { + t.dtype + for t in flat_outs + if isinstance(t, torch.Tensor) and t.dtype in cls._float_types + } + + args, kwargs = pytree.tree_unflatten(flat_args_kwargs, args_spec) + out = pytree.tree_unflatten(flat_outs, out_spec) + + if func_packet in _LEARNED_OPS: + op_time = cls._learned_estimate_predictor(func_packet, args, kwargs, out, out_dtypes) + else: + # Roofline estimate. + transfer_time = cls._get_transfer_time(flat_args_kwargs, flat_outs) + compute_time = cls._get_compute_time(func_packet, args, kwargs, out, out_dtypes) + + # We get the estimated time as the max of the transfer time and + # compute time. We divide by 1e6 to get the time in ms + op_time = max(transfer_time, compute_time) / 1e6 + + return (out, op_time) def display_modulewise_stats(self, depth: int = 2) -> None: """ @@ -478,6 +866,8 @@ def __call__(self, estimate_mode_type: str) -> Self: self._estimate = RuntimeEstimator._benchmark_estimate elif estimate_mode_type == "operator-level-cost-model": self._estimate = RuntimeEstimator._roofline_estimate + elif estimate_mode_type == "operator-level-learned-model": + self._estimate = RuntimeEstimator._learned_estimate else: raise NotImplementedError( f"estimate_mode_type {estimate_mode_type} not supported" @@ -520,6 +910,7 @@ def __exit__(self, *args: Any) -> None: f"Estimated ({self._estimate_mode_type})" f"total_time: {self.total_runtime:.3f} ms" ) + print("count", self.count) if len(self._no_fallback_kernel) > 0: print("no_fallback_kernel: ", list(self._no_fallback_kernel)) super().__exit__(*args)
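A minimal usage sketch for trying the new "operator-level-learned-model" mode locally, mirroring the pattern exercised in test_runtime_estimator.py. It assumes an A100 or H100 visible to nvidia-smi (so get_device_type() resolves a model directory) and the joblib/pandas dependencies; the TinyMLP module, its sizes, and the optimizer choice are illustrative, not part of this patch.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._tools.runtime_estimator import RuntimeEstimator


class TinyMLP(torch.nn.Module):
    # Hypothetical toy workload; any module built under FakeTensorMode works.
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.Linear(1024, 4096)
        self.fc2 = torch.nn.Linear(4096, 1024)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.relu(self.fc1(x)))


def main() -> None:
    estimator = RuntimeEstimator()
    with FakeTensorMode():
        model = TinyMLP().cuda().to(torch.bfloat16)
        optim = torch.optim.Adam(model.parameters())
        inp = torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16)
        # mm/bmm/sdpa ops are dispatched to the learned predictors registered
        # in _LEARNED_OPS; all other ops fall back to the roofline estimate.
        with estimator(estimate_mode_type="operator-level-learned-model"):
            loss = model(inp).sum()
            loss.backward()
            optim.step()
            optim.zero_grad()
    estimator.display_modulewise_stats(depth=2)


if __name__ == "__main__":
    main()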