[4/N] Non-Tensor: Support layout, device and dtype for aten operations
EikanWang authored and pytorchmergebot committed Jul 23, 2024
1 parent 68c725a commit 3fe72e0
Showing 5 changed files with 251 additions and 42 deletions.
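For reference, the operator exercised by the new test below, aten::tril_indices, is a factory op whose dtype, layout and device arguments are exactly the non-Tensor inputs this change teaches the AOTI eager path to handle. A plain eager call, independent of the AOTI path and with arbitrary values, looks like this:

import torch

# Plain eager call of tril_indices with the non-Tensor arguments
# (dtype, layout, device) that the change below routes through AOTI eager.
idx = torch.tril_indices(
    row=4,
    col=4,
    offset=1,
    dtype=torch.int32,
    layout=torch.strided,
    device="cpu",
)
print(idx.shape, idx.dtype)  # torch.Size([2, 13]) torch.int32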
41 changes: 41 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -802,6 +802,47 @@ def fn(a, b):
),
)

    @skipCUDAIf(not SM80OrLater, "Requires sm80")
    @skip_if_halide  # aoti
    def test_aoti_eager_dtype_device_layout(self):
        ns = "aten"
        op_name = "tril_indices"
        dispatch_key = "CPU"
        device = "cpu"
        if self.device.lower() == "cuda":
            dispatch_key = "CUDA"
            device = "cuda"

        with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
            row = 128
            col = 256
            offset = 1
            dtype = torch.int32
            layout = torch.strided
            pin_memory = False
            ref = torch.tril_indices(
                row=row,
                col=col,
                offset=offset,
                dtype=dtype,
                layout=layout,
                pin_memory=pin_memory,
                device=device,
            )
            register_ops_with_aoti_compile(
                ns, [op_name], dispatch_key, torch_compile_op_lib_impl
            )
            res = torch.tril_indices(
                row=row,
                col=col,
                offset=offset,
                dtype=dtype,
                layout=layout,
                pin_memory=pin_memory,
                device=device,
            )
            self.assertEqual(ref, res)

    @skipCUDAIf(not SM80OrLater, "Requires sm80")
    @skip_if_halide  # aoti
    def test_aoti_eager_support_out(self):
124 changes: 94 additions & 30 deletions torch/_inductor/aoti_eager.py
@@ -1,4 +1,5 @@
import json
import logging
import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple
@@ -11,6 +12,9 @@
from .runtime.runtime_utils import cache_dir


log = logging.getLogger(__name__)


def aoti_eager_cache_dir(namespace: str, device: str) -> Path:
    return Path(cache_dir()) / "aoti_eager" / namespace / device

@@ -34,29 +38,48 @@ def load_aoti_eager_cache(
    if not op_conf.exists():
        return []

    with aoti_eager_op_conf_lock(op_func_name_with_overload):
        with open(op_conf) as f:
            json_data = json.load(f)
            for item in json_data:
                # Get absolute path for kernel library
                kernel_lib_abs_path = device_kernel_cache / item["kernel_path"]
                item["kernel_path"] = kernel_lib_abs_path.as_posix()

                # Check if the kernel library exists
                if not kernel_lib_abs_path.exists():
                    return []

                for metadata in item["meta_info"]:
                    if metadata.get("is_dynamic"):
                        raise NotImplementedError("Only support static shape for now")
                    if "device_type" in metadata and metadata["device_type"] == "cpu":
                        metadata["device_index"] = -1
                    if "dtype" in metadata:
                        metadata["dtype"] = getattr(
                            torch, metadata["dtype"].split(".")[-1]
                        )

            return json_data
    try:
        with aoti_eager_op_conf_lock(op_func_name_with_overload):
            with open(op_conf) as f:
                json_data = json.load(f)
                for item in json_data:
                    # Get absolute path for kernel library
                    kernel_lib_abs_path = device_kernel_cache / item["kernel_path"]
                    item["kernel_path"] = kernel_lib_abs_path.as_posix()

                    # Check if the kernel library exists
                    if not kernel_lib_abs_path.exists():
                        return []

                    for metadata in item["meta_info"]:
                        if metadata.get("is_dynamic"):
                            raise NotImplementedError(
                                "Only support static shape for now"
                            )
                        if (
                            "device_type" in metadata
                            and metadata["device_type"] == "cpu"
                        ):
                            metadata["device_index"] = -1
                        for dtype_key in ["dtype", "dtype_value"]:
                            if dtype_key in metadata:
                                metadata[dtype_key] = getattr(
                                    torch, metadata[dtype_key].split(".")[-1]
                                )
                        if "layout_value" in metadata:
                            metadata["layout_value"] = getattr(
                                torch, metadata["layout_value"].split(".")[-1]
                            )
                        if "memory_format_value" in metadata:
                            metadata["memory_format_value"] = getattr(
                                torch, metadata["memory_format_value"].split(".")[-1]
                            )

                return json_data
    except Exception as e:
        err_msg = f"Failed to load aoti eager cache: {e}"
        log.exception(err_msg)
        return []


def supported_builtin_dtype_torch_dtype() -> Dict[type, torch.dtype]:
@@ -120,6 +143,28 @@ def extract_string_metadata(input: str) -> Dict[str, Any]:
    return metadata


def extract_dtype_metadata(input: torch.dtype) -> Dict[str, Any]:
    assert isinstance(input, torch.dtype)
    metadata: Dict[str, Any] = {}
    metadata["dtype_value"] = f"{input}"
    return metadata


def extract_device_metadata(input: torch.device) -> Dict[str, Any]:
    assert isinstance(input, torch.device)
    metadata: Dict[str, Any] = {}
    metadata["device_type_value"] = f"{input.type}"
    metadata["device_index_value"] = input.index
    return metadata


def extract_layout_metadata(input: torch.layout) -> Dict[str, Any]:
    assert isinstance(input, torch.layout)
    metadata: Dict[str, Any] = {}
    metadata["layout_value"] = f"{input}"
    return metadata


def aoti_compile_with_persistent_cache(
    ns: str,
    op_func_name_with_overload: str,
@@ -140,20 +185,31 @@ def aoti_compile_with_persistent_cache(
    assert not dynamic, "Only support static shape for now"
    flattened_inputs = list(args) + list(kwargs.values())
    if not all(
        isinstance(input, (supported_scalar_types(), torch.Tensor, list, str))
        isinstance(
            input,
            (
                supported_scalar_types(),
                torch.Tensor,
                list,
                str,
                torch.dtype,
                torch.device,
                torch.layout,
            ),
        )
        for input in flattened_inputs
    ):
        raise NotImplementedError(
            "Only support tensor, tensor list, int, float, bool for now"
        )
        err_msg = f"Unsupported input types: {flattened_inputs}"
        log.exception(err_msg)
        raise NotImplementedError(err_msg)

    for input in flattened_inputs:
        if isinstance(input, list) and not all(
            isinstance(item, torch.Tensor) for item in input
        ):
            raise NotImplementedError(
                "Regarding list, _impl_with_aoti_compile only support tensor list now."
            )
            err_msg = f"_impl_with_aoti_compile encounters unsupported input types: {flattened_inputs}"
            log.exception(err_msg)
            raise NotImplementedError(err_msg)

    persistent_cache = aoti_eager_cache_dir(ns, device_type)
    if not persistent_cache.exists():
@@ -193,6 +249,12 @@ def aoti_compile_with_persistent_cache(
                metadata = extract_scalar_metadata(device_type, input)
            elif isinstance(input, str):
                metadata = extract_string_metadata(input)
            elif isinstance(input, torch.dtype):
                metadata = extract_dtype_metadata(input)
            elif isinstance(input, torch.device):
                metadata = extract_device_metadata(input)
            elif isinstance(input, torch.layout):
                metadata = extract_layout_metadata(input)
            else:
                raise NotImplementedError(f"Unsupported input type: {type(input)}")

@@ -231,4 +293,6 @@ def aoti_compile_with_persistent_cache(

        return kernel_lib_path
    except Exception as e:
        err_msg = f"Failed to compile {op_func_name_with_overload}: {e}"
        log.exception(err_msg)
        return ""