Add Symbolic Shape Hint to Triton Codegen Config (#20056)
Add a symbolic shape hint to the Triton codegen config so that unnecessary recompilation can be avoided when input shapes keep changing. The screenshot below shows that, with a proper configuration, training is sped up significantly by eliminating unnecessary recompilation.


![image](https://github.com/microsoft/onnxruntime/assets/11661208/699944d2-81cd-4c22-84e7-73a4fa0d2a28)
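
For reference, a hypothetical way to supply such a configuration (the env var `ORTMODULE_TRITON_CONFIG_FILE` and the config keys come from `get_config` in this diff; the dim_param patterns and hint values below are made-up examples and must match the dim_param names in your own exported ONNX graph):

```python
# Hypothetical usage sketch, not part of this commit. Write a Triton codegen
# config containing a symbolic_shape_hint and point ORTMODULE_TRITON_CONFIG_FILE
# at it before ORTModule compiles any Triton kernels.
import json
import os
import tempfile

triton_config = {
    # "ops", "initializer" and "min_nodes" are optional; get_config fills in
    # defaults for any key that is missing.
    "min_nodes": 2,
    "symbolic_shape_hint": {
        # regex over dim_param names -> power-of-2 hint embedded in the symbolic dim
        "batch.*": 64,     # e.g. a dim_param named "batch_size" (made-up name)
        "seq_len": 2048,   # e.g. a sequence-length dim_param (made-up name)
        "*": 16,           # wildcard: minimum size used when bucketing unmatched dims
    },
}

config_path = os.path.join(tempfile.gettempdir(), "ort_triton_config.json")
with open(config_path, "w", encoding="UTF-8") as f:
    json.dump(triton_config, f)

os.environ["ORTMODULE_TRITON_CONFIG_FILE"] = config_path
```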
centwang authored Mar 25, 2024
1 parent 4a196d1 commit d30c81d
Showing 2 changed files with 49 additions and 10 deletions.
2 changes: 2 additions & 0 deletions orttraining/orttraining/python/training/ort_triton/_cache.py
@@ -9,6 +9,7 @@
import getpass
import hashlib
import os
import sys
import tempfile
from types import ModuleType
from typing import Tuple
@@ -61,6 +62,7 @@ def load(cls, source_code) -> ModuleType:
mod.__file__ = path
mod.key = key
exec(code, mod.__dict__, mod.__dict__)
sys.modules[mod.__name__] = mod
# another thread might set this first
cls.cache.setdefault(key, mod)
return cls.cache[key]
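
The `sys.modules[mod.__name__] = mod` line is the only functional change in `_cache.py`. A stripped-down sketch of the surrounding pattern (illustrative only, not the ORT implementation; the names here are invented):

```python
# Minimal sketch of loading generated Triton source as an in-memory module:
# exec the source into a fresh ModuleType, register it in sys.modules so the
# generated code can later be found by name, and memoize it per source key.
import sys
import types

_module_cache: dict[str, types.ModuleType] = {}

def load_generated(key: str, source_code: str) -> types.ModuleType:
    if key in _module_cache:
        return _module_cache[key]
    mod = types.ModuleType(f"ort_triton_generated_{key}")  # invented naming scheme
    exec(compile(source_code, mod.__name__, "exec"), mod.__dict__)
    sys.modules[mod.__name__] = mod          # mirrors the line added in this commit
    _module_cache.setdefault(key, mod)       # another thread might have set it first
    return _module_cache[key]
```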
@@ -6,11 +6,13 @@
import functools
import json
import os
import re
import sys
from types import ModuleType
from typing import List, Tuple, Union

import onnx
from onnx import ModelProto
from torch._C import _from_dlpack
from torch.utils.dlpack import to_dlpack

@@ -41,18 +43,39 @@ class _ShapeCache:
"""

cache = dict() # noqa: RUF012
symbolic_shape_hint = None
min_symbolic_shape = 0
clear = staticmethod(cache.clear)

@classmethod
def get_shape(cls, onnx_key: int, shapes: List[List[int]]) -> List[List[Union[int, str]]]:
def set_symbolic_shape_hint(cls, symbolic_shape_hint_config):
for k, v in symbolic_shape_hint_config.items():
if k == "*":
cls.min_symbolic_shape = v
else:
if cls.symbolic_shape_hint is None:
cls.symbolic_shape_hint = dict()
cls.symbolic_shape_hint[k] = v

@classmethod
def get_shape(cls, onnx_key: int, model: ModelProto, shapes: List[List[int]]) -> List[List[Union[int, str]]]:
if onnx_key not in cls.cache:
if cls.symbolic_shape_hint is not None:
for i, input in enumerate(model.graph.input):
if input.type.tensor_type.HasField("shape"):
for j, dim in enumerate(input.type.tensor_type.shape.dim):
if dim.dim_param:
for k, v in cls.symbolic_shape_hint.items():
if re.fullmatch(k, dim.dim_param):
shapes[i][j] = f"i{i}_dim{j}_{v}"
break
cls.cache[onnx_key] = shapes
else:
changed = False
for i, shape in enumerate(shapes):
for j, dim in enumerate(shape):
if dim != cls.cache[onnx_key][i][j] and isinstance(cls.cache[onnx_key][i][j], int):
max_dim = max(dim, cls.cache[onnx_key][i][j])
if isinstance(cls.cache[onnx_key][i][j], int) and dim != cls.cache[onnx_key][i][j]:
max_dim = max(dim, cls.cache[onnx_key][i][j], cls.min_symbolic_shape)
shape[j] = f"i{i}_dim{j}_{next_power_of_2(max_dim)}"
changed = True
elif isinstance(cls.cache[onnx_key][i][j], str):
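
To make the two branches of the new `get_shape` concrete, a small self-contained illustration (not the ORT code; `next_power_of_2` is re-implemented here just for the example):

```python
# With a symbolic_shape_hint, a dim whose dim_param fullmatches a configured
# regex becomes symbolic on the very first call (value taken from the hint),
# so later size changes on that dim never force a recompile. Without a hint,
# a dim only becomes symbolic after it is observed to change, and its hint is
# the next power of 2 of the largest size seen so far, floored by the "*"
# entry (min_symbolic_shape).

def next_power_of_2(n: int) -> int:
    # Same idea as the helper used by the codegen: smallest power of 2 >= n.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

# No hint: input 0, dim 0 observed as 24 and then 37, with min_symbolic_shape = 16.
min_symbolic_shape = 16
max_dim = max(37, 24, min_symbolic_shape)        # 37
print(f"i0_dim0_{next_power_of_2(max_dim)}")     # i0_dim0_64

# With a hint {"batch.*": 64}: a dim_param named "batch_size" on input 0, dim 0
# is rewritten to "i0_dim0_64" immediately, before any shape change is seen.
```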
@@ -67,13 +90,12 @@ def get_shape(cls, onnx_key: int, shapes: List[List[int]]) -> List[List[Union[int, str]]]:
return cls.cache[onnx_key]


def _gen_key(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> int:
def _gen_key(onnx_key: int, model: ModelProto, shapes: List[List[Union[int, str]]]) -> int:
# pylint: disable=unused-argument
return hash(f"{onnx_key}|{str(shapes).replace(' ', '')}")


def _gen_module(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> Tuple[str, ModuleType]:
model = onnx.load_model_from_string(onnx_str)
def _gen_module(onnx_key: int, model: ModelProto, shapes: List[List[Union[int, str]]]) -> Tuple[str, ModuleType]:
sorted_graph = SortedGraph(model, [parse_shape(shape) for shape in shapes])
if _DEBUG_MODE:
os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True)
@@ -96,14 +118,28 @@ def get_config() -> str:
"scalar": only related scalar initializers will be added to subgraphs.
"all": all related initializers will be added to subgraphs.
The min_nodes is used to control the minimum number of non-no-op nodes in a subgraph.
User can also specify symbolic_shape_hint in the config, which is a dict to control the symbolic shape hint.
Each entry is a regex pattern to match the dim_param in ONNX model and the value is the power of 2 for the symbolic
shape. Each dim_param will be replaced by i{input_index}_dim{dim_index}_{power_of_2} in the symbolic shape.
"""

config = dict()
config_file = os.getenv("ORTMODULE_TRITON_CONFIG_FILE", "")
if config_file and os.path.exists(config_file):
with open(config_file, encoding="UTF-8") as f:
return f.read()
config = json.load(f)

if "ops" not in config:
config["ops"] = get_supported_ops()
if "initializer" not in config:
config["initializer"] = "scalar"
if "min_nodes" not in config:
config["min_nodes"] = 2

if "symbolic_shape_hint" in config and len(config["symbolic_shape_hint"]) > 0:
_ShapeCache.set_symbolic_shape_hint(config["symbolic_shape_hint"])
del config["symbolic_shape_hint"]

config = {"ops": get_supported_ops(), "initializer": "scalar", "min_nodes": 2}
return json.dumps(config)
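
Note that `get_shape` uses `re.fullmatch`, so a `symbolic_shape_hint` key must match the entire dim_param, not just a prefix. A quick check (the dim_param names are hypothetical):

```python
import re

# Hypothetical hint entries as they would appear under "symbolic_shape_hint".
hint_patterns = {"batch.*": 64, "seq_len": 2048}

for dim_param in ["batch_size", "seq_len", "seq_len_plus_1"]:
    hit = next((v for k, v in hint_patterns.items() if re.fullmatch(k, dim_param)), None)
    print(dim_param, "->", hit)
# batch_size     -> 64    ("batch.*" fullmatches the whole name)
# seq_len        -> 2048
# seq_len_plus_1 -> None  ("seq_len" does not fullmatch the longer name)
```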


@@ -136,8 +172,9 @@ def call_triton_by_onnx(onnx_key: int, onnx_str: bytes, *tensors):
assert all(tensor is not None for tensor in tensors)
torch_tensors = [_from_dlpack(tensor) for tensor in tensors]
concrete_shapes = [list(tensor.size()) for tensor in torch_tensors]
shapes = _ShapeCache.get_shape(onnx_key, concrete_shapes)
func_name, mod = ModuleCache.load(_gen_key, _gen_module, onnx_key, onnx_str, shapes)
model = onnx.load_model_from_string(onnx_str)
shapes = _ShapeCache.get_shape(onnx_key, model, concrete_shapes)
func_name, mod = ModuleCache.load(_gen_key, _gen_module, onnx_key, model, shapes)
func = getattr(mod, func_name)
output = func(*torch_tensors)
if isinstance(output, tuple):