[AMD] Added instr.sched guards for the FA-like kernels

ravil-mobile committed Nov 15, 2024
1 parent 3c189dd commit 2f4d48f

Showing 13 changed files with 159 additions and 82 deletions.
14 changes: 7 additions & 7 deletions test/TritonGPU/amd/amd-instruction-sched.mlir
@@ -1,8 +1,8 @@
// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=iglp0' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0
// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=iglp1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=ck_v3' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_CKV3_GLOBAL_LOAD
// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints='variant=llvm_iglp_0' -triton-amdgpu-lower-insert-instruction-sched-hints -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0
// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints='variant=llvm_iglp_1' -triton-amdgpu-lower-insert-instruction-sched-hints -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints='variant=guard' -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints='variant=guard' -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints='variant=local_prefetch' -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' | FileCheck %s -check-prefix=LABELING_PS_1
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' | FileCheck %s -check-prefix=LABELING_PS_2

@@ -68,8 +68,8 @@ module {
// INSTR_COUNT_NS2-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>>
// INSTR_COUNT_NS2-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>>

// USE_CKV3_GLOBAL_LOAD: [lower-insert-instruction-sched-hints]
// USE_CKV3_GLOBAL_LOAD-SAME: Skipping instruction scheduling because `ck_v3` scheduling can be used only with `buffer_load` instructions.
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: [lower-insert-instruction-sched-hints]
// USE_LOCAL_PREFETCH_GLOBAL_LOAD-SAME: skipping `local-prefetch` scheduling given it needs `buffer_load` instructions.

// LABELING_PS_1: scf.for
// LABELING_PS_1: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>}
5 changes: 2 additions & 3 deletions third_party/amd/backend/compiler.py
@@ -224,7 +224,7 @@ def make_ttgir(mod, metadata, options):
prefetch = os.getenv("TRITON_HIP_STREAM_PREFETCH", "0") == "1"
amd.passes.ttgpuir.add_stream_pipelinev2(pm, options.num_stages, prefetch)
passes.common.add_canonicalizer(pm)
amd.passes.ttgpuir.insert_instruction_sched_hints(pm)
amd.passes.ttgpuir.insert_instruction_sched_hints(pm, options.instruction_sched_variant)
passes.ttgpuir.add_optimize_dot_operands(pm, True)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_reduce_data_duplication(pm)
@@ -275,8 +275,7 @@ def make_llir(src, metadata, options):
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages,
options.instruction_sched_variant)
amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.arch, options.num_stages)
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
passes.llvmir.add_di_scope(pm)
amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm, __HIP_FTZ)
third_party/amd/include/Dialect/TritonAMDGPU/IR/CMakeLists.txt
@@ -11,6 +11,8 @@ add_mlir_doc(TritonAMDGPUOps TritonAMDGPUOps dialects/ -gen-op-doc)
add_public_tablegen_target(TritonAMDGPUTableGen)

set(LLVM_TARGET_DEFINITIONS TritonAMDGPUAttrDefs.td)
mlir_tablegen(TritonAMDGPUEnums.h.inc -gen-enum-decls)
mlir_tablegen(TritonAMDGPUEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(TritonAMDGPUAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(TritonAMDGPUAttrDefs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(TritonAMDGPUAttrDefsIncGen)
1 change: 1 addition & 0 deletions third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h
@@ -33,6 +33,7 @@

// clang-format off
#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h.inc"
#include "amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUEnums.h.inc"
// clang-format on

#define GET_ATTRDEF_CLASSES
third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td
@@ -26,6 +26,7 @@

include "mlir/IR/AttrTypeBase.td"
include "TritonAMDGPUDialect.td"
include "mlir/IR/EnumAttr.td"

class TritonAMDGPU_Attr<string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Attribute">
@@ -59,4 +60,35 @@ def TritonAMDGPU_InstCounter : TritonAMDGPU_Attr<"InstCounter"> {
}


class TritonAMDGPU_I32Enum<string name, string description, list<I32EnumAttrCase> cases>
: I32EnumAttr<name, description, cases> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::triton::amdgpu";
}

class TritonAMDGPU_I32EnumAttr<string mnemonic, TritonAMDGPU_I32Enum enumInfo> :
EnumAttr<TritonAMDGPU_Dialect, enumInfo, mnemonic> {
let assemblyFormat = "`<` $value `>`";
let cppNamespace = "::mlir::triton::amdgpu";
}

def SchedHintCaseNone : I32EnumAttrCase<"none", 0>;
def SchedHintCaseLLVMIglp0 : I32EnumAttrCase<"llvm_iglp_0", 1>;
def SchedHintCaseLLVMIglp1 : I32EnumAttrCase<"llvm_iglp_1", 2>;
def SchedHintCaseLocalPrefetch : I32EnumAttrCase<"local_prefetch", 3>;
def SchedHintCaseGuard : I32EnumAttrCase<"guard", 4>;

def TritonAMDGPU_SchedHintsEnum : TritonAMDGPU_I32Enum<
"SchedHint", "Instruction Scheduling Hints for AMD GPUs", [
SchedHintCaseNone,
SchedHintCaseLLVMIglp0,
SchedHintCaseLLVMIglp1,
SchedHintCaseLocalPrefetch,
SchedHintCaseGuard
]>;

def TritonAMDGPU_SchedHintVariantAttr :
TritonAMDGPU_I32EnumAttr<"SchedHintVariant", TritonAMDGPU_SchedHintsEnum>;


#endif
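For context, the `-gen-enum-decls`/`-gen-enum-defs` tablegen backends invoked in the CMakeLists change conventionally emit `symbolizeSchedHint`/`stringifySchedHint` helpers alongside the `SchedHint` enum defined above. A minimal sketch of how a pass could map its textual `variant` option onto the enum (illustrative only, not code from this commit; the helper name assumes the standard MLIR convention):

// Illustrative sketch, not from this commit: map the textual pass option to
// the generated SchedHint enum, falling back to `none` for unknown values.
#include "Dialect/TritonAMDGPU/IR/Dialect.h"
#include "llvm/ADT/StringRef.h"

using mlir::triton::amdgpu::SchedHint;

static SchedHint parseSchedVariant(llvm::StringRef variant) {
  // symbolizeSchedHint is produced by -gen-enum-decls/-gen-enum-defs and
  // returns std::nullopt when the spelling does not match any enum case.
  if (auto hint = mlir::triton::amdgpu::symbolizeSchedHint(variant))
    return *hint;
  return SchedHint::none;
}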
third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td
@@ -58,6 +58,7 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
}];

let arguments = (ins
TritonAMDGPU_SchedHintVariantAttr:$schedVariant,
TritonAMDGPU_InstCounter:$numDsReadsA,
TritonAMDGPU_InstCounter:$numDsReadsB,
TritonAMDGPU_InstCounter:$numDsWritesA,
@@ -70,12 +71,12 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
);

let builders = [
OpBuilder<(ins), [{
OpBuilder<(ins "SchedHint":$variant), [{
auto ctx = $_state.getContext();
auto noneType = NoneType::get(ctx);
auto emptyAttr = amdgpu::InstCounterAttr::get(ctx, 0, noneType);
build($_builder, $_state, emptyAttr, emptyAttr, emptyAttr, emptyAttr,
emptyAttr, emptyAttr, false, false, emptyAttr);
build($_builder, $_state, variant, emptyAttr, emptyAttr, emptyAttr,
emptyAttr, emptyAttr, emptyAttr, false, false, emptyAttr);
}]>
];

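Because the op now carries a mandatory `schedVariant` and the updated builder takes the variant as its only explicit argument, call sites creating the hint reduce to something like the sketch below (the function name and insertion logic are illustrative, not taken from this commit):

// Illustrative sketch, not from this commit: insert a scheduling hint right
// after a dot op using the updated builder, which fills the instruction
// counters with empty placeholders.
#include "Dialect/TritonAMDGPU/IR/Dialect.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

static void insertSchedHintAfter(OpBuilder &builder, Operation *dotOp,
                                 triton::amdgpu::SchedHint variant) {
  builder.setInsertionPointAfter(dotOp);
  builder.create<triton::amdgpu::InstructionSchedHint>(dotOp->getLoc(),
                                                       variant);
}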
5 changes: 2 additions & 3 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.h
@@ -36,11 +36,10 @@ createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz);
std::unique_ptr<OperationPass<ModuleOp>>
createConvertBuiltinFuncToLLVMPass(bool ftz);
std::unique_ptr<OperationPass<ModuleOp>>
createTritonAMDGPUInsertInstructionSchedHintsPass();
createTritonAMDGPUInsertInstructionSchedHintsPass(StringRef variant);
std::unique_ptr<OperationPass<ModuleOp>>
createTritonAMDGPULowerInstructionSchedHintsPass(StringRef arch,
int32_t numStages,
StringRef variant);
int32_t numStages);

#define GEN_PASS_REGISTRATION
#include "TritonAMDGPUToLLVM/Passes.h.inc"
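Taken together, the revised declarations mean the variant is fixed when the hints are inserted and the lowering pass no longer needs it. A rough sketch of how the two passes could be scheduled with these constructors (example values for the variant and arch; not code from this commit):

// Illustrative sketch, not from this commit: wire both passes into a pass
// manager with the revised constructors declared above.
#include "TritonAMDGPUToLLVM/Passes.h"
#include "mlir/Pass/PassManager.h"

static void addSchedHintPasses(mlir::PassManager &pm) {
  // The insert pass now owns the variant choice...
  pm.addPass(mlir::triton::createTritonAMDGPUInsertInstructionSchedHintsPass(
      /*variant=*/"local_prefetch"));
  // ...while the lower pass only needs the target arch and pipeline depth
  // (it runs after the conversion to the LLVM dialect, as in the RUN lines).
  pm.addPass(mlir::triton::createTritonAMDGPULowerInstructionSchedHintsPass(
      /*arch=*/"gfx942", /*numStages=*/2));
}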
11 changes: 7 additions & 4 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
@@ -61,15 +61,20 @@ def ConvertBuiltinFuncToLLVM : Pass<"convert-builtin-func-to-llvm", "mlir::Modul

def TritonAMDGPUInsertInstructionSchedHints : Pass<"triton-amdgpu-insert-instruction-sched-hints", "mlir::ModuleOp"> {
let summary = "Insert instruction scheduling hints after the dot ops in the main loop";
let constructor = "mlir::triton::createTritonAMDGPUInsertInstructionSchedHintsPass()";
let constructor = "mlir::triton::createTritonAMDGPUInsertInstructionSchedHintsPass(/*variant=*/\"\")";

let dependentDialects = ["mlir::LLVM::LLVMDialect",
"mlir::triton::amdgpu::TritonAMDGPUDialect"];

let options = [
Option<"variant", "variant", "std::string", /*default*/"\"\"",
"instruction scheduling variant">,
];
}

def TritonAMDGPULowerInstructionSchedHints : Pass<"triton-amdgpu-lower-insert-instruction-sched-hints", "mlir::ModuleOp"> {
let summary = "Lower instruction scheduling hints to LLVM intrinsics";
let constructor = "mlir::triton::createTritonAMDGPULowerInstructionSchedHintsPass(/*arch=*/\"\",/*numStages=*/2, /*variant=*/\"\")";
let constructor = "mlir::triton::createTritonAMDGPULowerInstructionSchedHintsPass(/*arch=*/\"\",/*numStages=*/2)";

let dependentDialects = ["mlir::LLVM::LLVMDialect",
"mlir::ROCDL::ROCDLDialect",
@@ -80,8 +85,6 @@ def TritonAMDGPULowerInstructionSchedHints : Pass<"triton-amdgpu-lower-insert-in
"gfx target device architecture, e.g., gfx942">,
Option<"numStages", "num_stages", "int32_t", /*default*/"2",
"number of pipeline stages">,
Option<"variant", "variant", "std::string", /*default*/"\"none\"",
"instruction scheduling variant">,
];
}

2 changes: 2 additions & 0 deletions third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
@@ -48,6 +48,8 @@ void mlir::triton::amdgpu::TritonAMDGPUDialect::initialize() {
>();
}

#include "Dialect/TritonAMDGPU/IR/TritonAMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.cpp.inc"
