Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMD] Added instr.sched guards for the FA-like kernels #5163

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions test/TritonGPU/amd/amd-instruction-sched.mlir
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=iglp0' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0
// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=iglp1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=ck_v3' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_CKV3_GLOBAL_LOAD
// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints='variant=llvm_iglp_0' -triton-amdgpu-lower-insert-instruction-sched-hints -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0
// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints='variant=llvm_iglp_1' -triton-amdgpu-lower-insert-instruction-sched-hints -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints='variant=guard' -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints='variant=guard' -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints='variant=local_prefetch' -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' | FileCheck %s -check-prefix=LABELING_PS_1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' | FileCheck %s -check-prefix=LABELING_PS_2

Expand Down Expand Up @@ -68,8 +68,8 @@ module {
// INSTR_COUNT_NS2-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>>
// INSTR_COUNT_NS2-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>>

// USE_CKV3_GLOBAL_LOAD: [lower-insert-instruction-sched-hints]
// USE_CKV3_GLOBAL_LOAD-SAME: Skipping instruction scheduling because `ck_v3` scheduling can be used only with `buffer_load` instructions.
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: [lower-insert-instruction-sched-hints]
// USE_LOCAL_PREFETCH_GLOBAL_LOAD-SAME: skipping `local-prefetch` scheduling given it needs `buffer_load` instructions.

// LABELING_PS_1: scf.for
// LABELING_PS_1: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>}
Expand Down
5 changes: 2 additions & 3 deletions third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def make_ttgir(mod, metadata, options):
prefetch = os.getenv("TRITON_HIP_STREAM_PREFETCH", "0") == "1"
amd.passes.ttgpuir.add_stream_pipelinev2(pm, options.num_stages, prefetch)
passes.common.add_canonicalizer(pm)
amd.passes.ttgpuir.insert_instruction_sched_hints(pm)
amd.passes.ttgpuir.insert_instruction_sched_hints(pm, options.instruction_sched_variant)
passes.ttgpuir.add_optimize_dot_operands(pm, True)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_reduce_data_duplication(pm)
Expand Down Expand Up @@ -275,8 +275,7 @@ def make_llir(src, metadata, options):
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages,
options.instruction_sched_variant)
amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages)
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
passes.llvmir.add_di_scope(pm)
amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ add_mlir_doc(TritonAMDGPUOps TritonAMDGPUOps dialects/ -gen-op-doc)
add_public_tablegen_target(TritonAMDGPUTableGen)

set(LLVM_TARGET_DEFINITIONS TritonAMDGPUAttrDefs.td)
mlir_tablegen(TritonAMDGPUEnums.h.inc -gen-enum-decls)
mlir_tablegen(TritonAMDGPUEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(TritonAMDGPUAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(TritonAMDGPUAttrDefs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(TritonAMDGPUAttrDefsIncGen)
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

// clang-format off
#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h.inc"
#include "amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUEnums.h.inc"
// clang-format on

#define GET_ATTRDEF_CLASSES
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

include "mlir/IR/AttrTypeBase.td"
include "TritonAMDGPUDialect.td"
include "mlir/IR/EnumAttr.td"

class TritonAMDGPU_Attr<string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Attribute">
Expand Down Expand Up @@ -59,4 +60,35 @@ def TritonAMDGPU_InstCounter : TritonAMDGPU_Attr<"InstCounter"> {
}


// Base class for 32-bit integer enums declared by the TritonAMDGPU dialect.
// It forwards `name`, `description` and `cases` unchanged to I32EnumAttr.
class TritonAMDGPU_I32Enum<string name, string description, list<I32EnumAttrCase> cases>
: I32EnumAttr<name, description, cases> {
// Do not let I32EnumAttr emit its own specialized attribute class; the
// attribute is declared separately via TritonAMDGPU_I32EnumAttr in this file.
let genSpecializedAttr = 0;
// C++ namespace that the generated enum is placed in.
let cppNamespace = "::mlir::triton::amdgpu";
}

// Dialect attribute wrapping a TritonAMDGPU_I32Enum.
//   mnemonic - keyword forwarded to EnumAttr as the attribute mnemonic.
//   enumInfo - the enum definition whose value the attribute carries.
class TritonAMDGPU_I32EnumAttr<string mnemonic, TritonAMDGPU_I32Enum enumInfo> :
EnumAttr<TritonAMDGPU_Dialect, enumInfo, mnemonic> {
// The enum value is written between angle brackets after the mnemonic.
let assemblyFormat = "`<` $value `>`";
// Same namespace as the enum declared through TritonAMDGPU_I32Enum.
let cppNamespace = "::mlir::triton::amdgpu";
}

// Cases of the scheduling-hint enum. Each case pairs its textual spelling with
// the integer value stored for it; values are consecutive starting at 0.
// NOTE(review): the spellings look like the strings accepted by the `variant`
// pass option (e.g. `variant=llvm_iglp_0`, `variant=guard`) — confirm against
// the pass implementation before renaming any of them.
def SchedHintCaseNone : I32EnumAttrCase<"none", 0>;
def SchedHintCaseLLVMIglp0 : I32EnumAttrCase<"llvm_iglp_0", 1>;
def SchedHintCaseLLVMIglp1 : I32EnumAttrCase<"llvm_iglp_1", 2>;
def SchedHintCaseLocalPrefetch : I32EnumAttrCase<"local_prefetch", 3>;
def SchedHintCaseGuard : I32EnumAttrCase<"guard", 4>;

// Enum named `SchedHint` that collects all scheduling-hint cases defined above.
// Being a TritonAMDGPU_I32Enum, it is generated in ::mlir::triton::amdgpu.
def TritonAMDGPU_SchedHintsEnum : TritonAMDGPU_I32Enum<
"SchedHint", "Instruction Scheduling Hints for AMD GPUs", [
SchedHintCaseNone,
SchedHintCaseLLVMIglp0,
SchedHintCaseLLVMIglp1,
SchedHintCaseLocalPrefetch,
SchedHintCaseGuard
]>;

// Attribute (mnemonic `SchedHintVariant`) carrying one value of the SchedHint
// enum declared by TritonAMDGPU_SchedHintsEnum.
def TritonAMDGPU_SchedHintVariantAttr :
TritonAMDGPU_I32EnumAttr<"SchedHintVariant", TritonAMDGPU_SchedHintsEnum>;


#endif
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
}];

let arguments = (ins
TritonAMDGPU_SchedHintVariantAttr:$schedVariant,
TritonAMDGPU_InstCounter:$numDsReadsA,
TritonAMDGPU_InstCounter:$numDsReadsB,
TritonAMDGPU_InstCounter:$numDsWritesA,
Expand All @@ -70,12 +71,12 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
);

let builders = [
OpBuilder<(ins), [{
OpBuilder<(ins "SchedHint":$variant), [{
auto ctx = $_state.getContext();
auto noneType = NoneType::get(ctx);
auto emptyAttr = amdgpu::InstCounterAttr::get(ctx, 0, noneType);
build($_builder, $_state, emptyAttr, emptyAttr, emptyAttr, emptyAttr,
emptyAttr, emptyAttr, false, false, emptyAttr);
build($_builder, $_state, variant, emptyAttr, emptyAttr, emptyAttr,
emptyAttr, emptyAttr, emptyAttr, false, false, emptyAttr);
}]>
];

Expand Down
5 changes: 2 additions & 3 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,10 @@ createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz);
std::unique_ptr<OperationPass<ModuleOp>>
createConvertBuiltinFuncToLLVMPass(bool ftz);
std::unique_ptr<OperationPass<ModuleOp>>
createTritonAMDGPUInsertInstructionSchedHintsPass();
createTritonAMDGPUInsertInstructionSchedHintsPass(StringRef variant);
std::unique_ptr<OperationPass<ModuleOp>>
createTritonAMDGPULowerInstructionSchedHintsPass(StringRef arch,
int32_t numStages,
StringRef variant);
int32_t numStages);

#define GEN_PASS_REGISTRATION
#include "TritonAMDGPUToLLVM/Passes.h.inc"
Expand Down
11 changes: 7 additions & 4 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,20 @@ def ConvertBuiltinFuncToLLVM : Pass<"convert-builtin-func-to-llvm", "mlir::Modul

def TritonAMDGPUInsertInstructionSchedHints : Pass<"triton-amdgpu-insert-instruction-sched-hints", "mlir::ModuleOp"> {
let summary = "Insert instruction scheduling hints after the dot ops in the main loop";
let constructor = "mlir::triton::createTritonAMDGPUInsertInstructionSchedHintsPass()";
let constructor = "mlir::triton::createTritonAMDGPUInsertInstructionSchedHintsPass(/*variant=*/\"\")";

let dependentDialects = ["mlir::LLVM::LLVMDialect",
"mlir::triton::amdgpu::TritonAMDGPUDialect"];

// Command-line options of the pass: `variant` is a string that defaults to the
// empty string.
// NOTE(review): the `variant` option this change removes from the lowering
// pass defaulted to "none"; confirm that an empty string is still mapped to a
// valid scheduling variant by the pass, or default to "none" here as well.
let options = [
Option<"variant", "variant", "std::string", /*default*/"\"\"",
"instruction scheduling variant">,
];
}

def TritonAMDGPULowerInstructionSchedHints : Pass<"triton-amdgpu-lower-insert-instruction-sched-hints", "mlir::ModuleOp"> {
let summary = "Lower instruction scheduling hints to LLVM intrinsics";
let constructor = "mlir::triton::createTritonAMDGPULowerInstructionSchedHintsPass(/*arch=*/\"\",/*numStages=*/2, /*variant=*/\"\")";
let constructor = "mlir::triton::createTritonAMDGPULowerInstructionSchedHintsPass(/*arch=*/\"\",/*numStages=*/2)";

let dependentDialects = ["mlir::LLVM::LLVMDialect",
"mlir::ROCDL::ROCDLDialect",
Expand All @@ -80,8 +85,6 @@ def TritonAMDGPULowerInstructionSchedHints : Pass<"triton-amdgpu-lower-insert-in
"gfx target device architecture, e.g., gfx942">,
Option<"numStages", "num_stages", "int32_t", /*default*/"2",
"number of pipeline stages">,
Option<"variant", "variant", "std::string", /*default*/"\"none\"",
"instruction scheduling variant">,
];
}

Expand Down
2 changes: 2 additions & 0 deletions third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ void mlir::triton::amdgpu::TritonAMDGPUDialect::initialize() {
>();
}

#include "Dialect/TritonAMDGPU/IR/TritonAMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.cpp.inc"

Expand Down
Loading