PyTorch UT regression (1/2) #2579

Open
guangyey opened this issue Oct 28, 2024 · 8 comments
@guangyey

[TL;DR]
The following case is a PyTorch UT reproducer. It regresses with commit b6cdccd compared with 91b14bf: with the newer commit b6cdccd, buf0 contains NaN values.
This blocks the PyTorch upstream update of the Triton commit pin; see pytorch/pytorch#137886. The original PyTorch Inductor UT failure log is here.
Please help take a look.
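For context, below is a hedged, eager-level sketch of the pattern this generated code reproduces (logcumsumexp over dim 0 and dim 1 under torch.compile with dynamic shapes); the shapes and function name are illustrative assumptions, not the exact UT body.

import torch

def logcumsumexp_both_dims(x):
    # The generated kernels below scan over dim 0 (logcumsumexp) and dim 1 (logcumsumexp_1).
    return torch.logcumsumexp(x, dim=0), torch.logcumsumexp(x, dim=1)

compiled = torch.compile(logcumsumexp_both_dims, dynamic=True)
x = torch.randn(16, 32, device="xpu", dtype=torch.float32)

out0, out1 = compiled(x)
ref0, ref1 = logcumsumexp_both_dims(x)

# With commit b6cdccd the compiled dim-0 result (buf0 below) contains NaNs, so this fails;
# with 91b14bf it passes.
torch.testing.assert_close(out0, ref0)
torch.testing.assert_close(out1, ref1)

The generated Inductor output follows: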

# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels, start_graph, end_graph
from torch._C import _xpu_getCurrentRawStream as get_raw_stream

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()


# kernel path: /tmp/torchinductor_pt-gpu/lh/clhs7ewi2suufpok3r46q5ad6gbpucosj75rpsv3xk4wix5akgqr.py
# Topologically Sorted Source Nodes: [logcumsumexp], Original ATen: [aten.logcumsumexp]
# Source node to ATen node mapping:
#   logcumsumexp => logcumsumexp
# Graph fragment:
#   %logcumsumexp : [num_users=1] = call_function[target=torch.ops.aten.logcumsumexp.default](args = (%arg2_1, 0), kwargs = {})
triton_red_fused_logcumsumexp_0 = async_compile.triton('triton_red_fused_logcumsumexp_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton.jit
def _triton_helper_fn_minimum_maximum_ne_isinf_bitwise_not_bitwise_or_sub_exp_log1p_add_where0(arg0_0, arg1_0):
    tmp0 = triton_helpers.minimum(arg0_0, arg1_0)
    tmp1 = triton_helpers.maximum(arg0_0, arg1_0)
    tmp2 = tmp0 != tmp1
    tmp3 = libdevice.isinf(tmp0).to(tl.int1)
    tmp4 = ~tmp3
    tmp5 = tmp2 | tmp4
    tmp6 = tmp0 - tmp1
    tmp7 = tl_math.exp(tmp6)
    tmp8 = libdevice.log1p(tmp7)
    tmp9 = tmp8 + tmp1
    tmp10 = tl.where(tmp5, tmp9, arg0_0)
    return tmp10

@triton_heuristics.reduction(
    size_hints=[32, 16],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='xpu', index=0, cc={'driver_version': '1.3.28202', 'gpu_eu_count': 512, 'gpu_subslice_count': 64, 'has_atomic64': True, 'has_bfloat16_conversions': True, 'has_fp16': True, 'has_fp64': True, 'has_subgroup_2d_block_io': False, 'has_subgroup_matrix_multiply_accumulate': True, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 512, 'max_num_sub_groups': 64, 'max_work_group_size': 1024, 'name': 'Intel(R) Data Center GPU Max 1550', 'platform_name': 'Intel(R) oneAPI Unified Runtime over Level-Zero', 'sub_group_sizes': [16, 32], 'total_memory': 68719476736, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '12.60.7'}, major=None, regs_per_multiprocessor=None, max_threads_per_multi_processor=None, multi_processor_count=None, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'tt.divisibility': (0, 1), 'tt.equal_to': ()})]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_logcumsumexp_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'CF7F53439E249A073E11F4DED3691C524E608B71CEAA76F5534798E8D20D7BC1', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_red_fused_logcumsumexp_0(in_ptr0, out_ptr0, ks0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp3 = tl.full([XBLOCK, 1], float('nan'), tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (ks0*r1)), rmask & xmask, eviction_policy='evict_last', other=0.0)
        tmp1 = tmp0.to(tl.float32)
        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
        tmp4, = tl.associative_scan((tmp2,), 1, _triton_helper_fn_minimum_maximum_ne_isinf_bitwise_not_bitwise_or_sub_exp_log1p_add_where0)
        tmp5 = triton_helpers.select_one((tmp4), rbase == (RBLOCK - 1), dim=-1, keep_dims=True)
        tmp6 = triton_helpers.minimum(tmp3, tmp5)
        tmp7 = triton_helpers.maximum(tmp3, tmp5)
        tmp8 = tmp6 != tmp7
        tmp9 = libdevice.isinf(tmp6).to(tl.int1)
        tmp10 = ~tmp9
        tmp11 = tmp8 | tmp10
        tmp12 = tmp6 - tmp7
        tmp13 = tl_math.exp(tmp12)
        tmp14 = libdevice.log1p(tmp13)
        tmp15 = tmp14 + tmp7
        tmp16 = tl.where(tmp11, tmp15, tmp3)
        tmp17 = triton_helpers.minimum(tmp3, tmp4)
        tmp18 = triton_helpers.maximum(tmp3, tmp4)
        tmp19 = tmp17 != tmp18
        tmp20 = libdevice.isinf(tmp17).to(tl.int1)
        tmp21 = ~tmp20
        tmp22 = tmp19 | tmp21
        tmp23 = tmp17 - tmp18
        tmp24 = tl_math.exp(tmp23)
        tmp25 = libdevice.log1p(tmp24)
        tmp26 = tmp25 + tmp18
        tmp27 = tl.where(tmp22, tmp26, tmp3)
        tmp28 = tl.where(roffset > 0, tmp27, tmp4)
        tmp3 = tl.where(roffset > 0, tmp16, tmp5)
        tl.store(out_ptr0 + (x0 + (ks0*r1)), tmp28, rmask & xmask)
''', device_str='xpu')


# kernel path: /tmp/torchinductor_pt-gpu/vs/cvs5bktsdarflc6clkplnxcdjmvscx5v7jv5f265zizw4qc25nuf.py
# Topologically Sorted Source Nodes: [logcumsumexp_1], Original ATen: [aten.logcumsumexp]
# Source node to ATen node mapping:
#   logcumsumexp_1 => logcumsumexp_1
# Graph fragment:
#   %logcumsumexp_1 : [num_users=1] = call_function[target=torch.ops.aten.logcumsumexp.default](args = (%arg2_1, 1), kwargs = {})
triton_red_fused_logcumsumexp_1 = async_compile.triton('triton_red_fused_logcumsumexp_1', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton.jit
def _triton_helper_fn_minimum_maximum_ne_isinf_bitwise_not_bitwise_or_sub_exp_log1p_add_where0(arg0_0, arg1_0):
    tmp0 = triton_helpers.minimum(arg0_0, arg1_0)
    tmp1 = triton_helpers.maximum(arg0_0, arg1_0)
    tmp2 = tmp0 != tmp1
    tmp3 = libdevice.isinf(tmp0).to(tl.int1)
    tmp4 = ~tmp3
    tmp5 = tmp2 | tmp4
    tmp6 = tmp0 - tmp1
    tmp7 = tl_math.exp(tmp6)
    tmp8 = libdevice.log1p(tmp7)
    tmp9 = tmp8 + tmp1
    tmp10 = tl.where(tmp5, tmp9, arg0_0)
    return tmp10

@triton_heuristics.reduction(
    size_hints=[16, 32],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='xpu', index=0, cc={'driver_version': '1.3.28202', 'gpu_eu_count': 512, 'gpu_subslice_count': 64, 'has_atomic64': True, 'has_bfloat16_conversions': True, 'has_fp16': True, 'has_fp64': True, 'has_subgroup_2d_block_io': False, 'has_subgroup_matrix_multiply_accumulate': True, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 512, 'max_num_sub_groups': 64, 'max_work_group_size': 1024, 'name': 'Intel(R) Data Center GPU Max 1550', 'platform_name': 'Intel(R) oneAPI Unified Runtime over Level-Zero', 'sub_group_sizes': [16, 32], 'total_memory': 68719476736, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '12.60.7'}, major=None, regs_per_multiprocessor=None, max_threads_per_multi_processor=None, multi_processor_count=None, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'tt.divisibility': (0, 1), 'tt.equal_to': ()})]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_logcumsumexp_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'CF7F53439E249A073E11F4DED3691C524E608B71CEAA76F5534798E8D20D7BC1', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_red_fused_logcumsumexp_1(in_ptr0, out_ptr0, ks0, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp3 = tl.full([XBLOCK, 1], float('nan'), tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (ks0*x0)), rmask & xmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tmp0.to(tl.float32)
        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
        tmp4, = tl.associative_scan((tmp2,), 1, _triton_helper_fn_minimum_maximum_ne_isinf_bitwise_not_bitwise_or_sub_exp_log1p_add_where0)
        tmp5 = triton_helpers.select_one((tmp4), rbase == (RBLOCK - 1), dim=-1, keep_dims=True)
        tmp6 = triton_helpers.minimum(tmp3, tmp5)
        tmp7 = triton_helpers.maximum(tmp3, tmp5)
        tmp8 = tmp6 != tmp7
        tmp9 = libdevice.isinf(tmp6).to(tl.int1)
        tmp10 = ~tmp9
        tmp11 = tmp8 | tmp10
        tmp12 = tmp6 - tmp7
        tmp13 = tl_math.exp(tmp12)
        tmp14 = libdevice.log1p(tmp13)
        tmp15 = tmp14 + tmp7
        tmp16 = tl.where(tmp11, tmp15, tmp3)
        tmp17 = triton_helpers.minimum(tmp3, tmp4)
        tmp18 = triton_helpers.maximum(tmp3, tmp4)
        tmp19 = tmp17 != tmp18
        tmp20 = libdevice.isinf(tmp17).to(tl.int1)
        tmp21 = ~tmp20
        tmp22 = tmp19 | tmp21
        tmp23 = tmp17 - tmp18
        tmp24 = tl_math.exp(tmp23)
        tmp25 = libdevice.log1p(tmp24)
        tmp26 = tmp25 + tmp18
        tmp27 = tl.where(tmp22, tmp26, tmp3)
        tmp28 = tl.where(roffset > 0, tmp27, tmp4)
        tmp3 = tl.where(roffset > 0, tmp16, tmp5)
        tl.store(out_ptr0 + (r1 + (ks0*x0)), tmp28, rmask & xmask)
''', device_str='xpu')


async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, arg1_1, arg2_1 = args
    args.clear()
    s1 = arg0_1
    s2 = arg1_1
    assert_size_stride(arg2_1, (s1, s2), (s2, 1))
    with torch.xpu._DeviceGuard(0):
        torch.xpu.set_device(0)
        buf0 = empty_strided_xpu((s1, s2), (s2, 1), torch.float32)
        # Topologically Sorted Source Nodes: [logcumsumexp], Original ATen: [aten.logcumsumexp]
        stream0 = get_raw_stream(0)
        triton_red_fused_logcumsumexp_0.run(arg2_1, buf0, s2, s2, s1, grid=grid(s2), stream=stream0)
        buf1 = empty_strided_xpu((s1, s2), (s2, 1), torch.float32)
        # Topologically Sorted Source Nodes: [logcumsumexp_1], Original ATen: [aten.logcumsumexp]
        triton_red_fused_logcumsumexp_1.run(arg2_1, buf1, s2, s1, s2, grid=grid(s1), stream=stream0)
        del arg2_1
    # buf0 has NaN
    print(buf0, )
    return (buf0, buf1, )


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = 16
    arg1_1 = 32
    arg2_1 = rand_strided((16, 32), (32, 1), device='xpu:0', dtype=torch.float32)
    fn = lambda: call([arg0_1, arg1_1, arg2_1])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
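If it helps triage, a small hedged check (not part of the generated output) can be appended to the end of this file to assert on the NaNs instead of inspecting the printed tensor; it assumes call() and the module-level imports above are in scope.

def check_buffers_for_nan():
    from torch._dynamo.testing import rand_strided
    # Same 16x32 input layout as benchmark_compiled_module.
    arg2_1 = rand_strided((16, 32), (32, 1), device='xpu:0', dtype=torch.float32)
    buf0, buf1 = call([16, 32, arg2_1])
    assert not torch.isnan(buf0).any(), "buf0 (logcumsumexp over dim 0) contains NaN"
    assert not torch.isnan(buf1).any(), "buf1 (logcumsumexp over dim 1) contains NaN"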
@alexbaden
Contributor

Bisected to this commit:

d651a8444fbf096253b2bc090cc6039be29784a8 is the first bad commit
commit d651a8444fbf096253b2bc090cc6039be29784a8
Author: Whitney Tsang <[email protected]>
Date:   Wed Oct 2 16:09:49 2024 -0400

    More code refactoring and sync from upstream (#2409)
    
    Signed-off-by: Whitney Tsang <[email protected]>

 .../lib/TritonIntelGPUToLLVM/PipelineManager.h     |   2 +
 .../lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp    |  52 ++++------
 .../lib/TritonIntelGPUToLLVM/ReduceScanCommon.h    | 105 +++++++++++++++++--
 .../lib/TritonIntelGPUToLLVM/SPMDOpToLLVM.cpp      |  35 +------
 .../lib/TritonIntelGPUToLLVM/ScanOpToLLVM.cpp      | 114 +++++++++------------
 .../lib/TritonIntelGPUToLLVM/ViewOpToLLVM.cpp      |  28 +++--
 6 files changed, 186 insertions(+), 150 deletions(-)

https://github.com/intel/intel-xpu-backend-for-triton/tree/d651a8444fbf096253b2bc090cc6039be29784a8

As you can tell, this change is from upstream. @guangyey, do you know if upstream Triton and upstream PyTorch have the same issue?

@guangyey
Author

Upstream PyTorch uses commit cf34004b8a67d290a962da166f5aa2fc66751326 as its CI commit pin. I looked through the commit history, and the bad commit does not appear to be included.

@guangyey guangyey changed the title from "PyTorch UT regression" to "PyTorch UT regression (1/2)" on Oct 29, 2024
@arunjose696
Contributor

@guangyey
I tried the reproducer from your first comment with upstream PyTorch and upstream Triton, and I can see the NaNs. However, I can't get the attached reproducer to work with the released PyTorch 2.5 to check what changed; it gives me a different error with PyTorch 2.5.

I installed PyTorch with

pip3 install torch==2.5.1 triton numpy matplotlib pandas --index-url https://download.pytorch.org/whl/test/xpu

When I run your reproducer I get the error below (the code fails at line 45, triton_red_fused_logcumsumexp_0 = async_compile.triton(..)).

Error
AttributeError: module 'torch._inductor.runtime.triton_helpers' has no attribute 'set_driver_to_gpu'

Could you check whether your reproducer works with torch==2.5.1 on your end?

@guangyey
Author

guangyey commented Oct 31, 2024

@arunjose696 It can't be reproduced with the 2.5.1 release branch because of internal API changes in Inductor. You have to use the main branch together with the two different Triton commits.
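For anyone who still wants to try it on an older wheel, a hedged workaround is to guard the specific call that 2.5.1 is missing; this is a minimal sketch that only covers the set_driver_to_gpu symbol from the error above, and other internal Inductor APIs may still differ.

from torch._inductor.runtime import triton_helpers

# Older releases (e.g. torch 2.5.1) may not ship set_driver_to_gpu, so only call
# it when present. This does not guarantee the rest of the generated code runs.
if hasattr(triton_helpers, "set_driver_to_gpu"):
    triton_helpers.set_driver_to_gpu()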

@guangyey
Author

guangyey commented Nov 5, 2024

May I know if there is any update?

@alexbaden
Contributor

This was both caused by and fixed by upstream:
triton-lang/triton#5033
triton-lang/triton#5075

(see the corresponding issue in the Inductor release thread: pytorch/pytorch#139348)

I have asked @anmyachev and @whitneywhtsang to prioritize syncing the relevant commits.

@whitneywhtsang
Contributor


The two corresponding commits are now merged in our Intel repo.

@alexbaden
Contributor

I have verified that this test passes:

» python ../pytorch/test/inductor/test_torchinductor_codegen_dynamic_shapes.py -k test_logcumsumexp_dynamic_shapes_xpu                                                                                                                    
[WARNING] Failed to create Level Zero tracer: 2013265921
[WARNING] Failed to create Level Zero tracer: 2013265921
/localdisk/abaden/Projects/intel-xpu-backend-for-triton/python/triton/testing.py:31: DeprecationWarning: The 'warn' function is deprecated, use 'warning' instead
  logging.warn("Wall time is used instead of elapsed_time (not supported). "
WARNING:root:Wall time is used instead of elapsed_time (not supported). The timing measurements could be innacurate.
/localdisk/abaden/Projects/pytorch/torch/utils/_config_module.py:321: UserWarning: Skipping serialization of skipfiles_inline_module_allowlist value {}
  warnings.warn(
/localdisk/abaden/Projects/pytorch/torch/utils/_config_module.py:321: UserWarning: Skipping serialization of skipfiles_inline_module_allowlist value {}
  warnings.warn(
inline_call []
stats [('calls_captured', 6), ('unique_graphs', 3)]
inductor [('benchmarking.TritonBenchmarker.benchmark_gpu', 13), ('fxgraph_cache_miss', 3), ('benchmarking.TritonBenchmarker.triton_do_bench', 1)]
aot_autograd [('total', 3), ('ok', 3)]
.
----------------------------------------------------------------------
Ran 1 test in 15.891s

OK

with Triton commit 0d9c0d3 and PyTorch commit pytorch/pytorch@78a8f7f. Based on that, I believe this ticket can be closed.
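As a quick sanity check before re-running the UT, the installed builds can be confirmed with a minimal sketch like the one below (the exact strings printed depend on how the wheels were built).

import torch
import triton

# PyTorch exposes the git hash it was built from; Triton exposes its package version.
print("torch:", torch.__version__, torch.version.git_version)
print("triton:", triton.__version__)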
