[AMD][BACKEND] Move BuiltinFuncToLLVM pass to the end of mlir passes and add needed barriers in atomic codegen (#3837)

zahimoud authored May 5, 2024
1 parent 2c3483f commit 4b94743
Showing 2 changed files with 7 additions and 12 deletions.
8 changes: 4 additions & 4 deletions third_party/amd/backend/compiler.py
@@ -158,16 +158,16 @@ def make_llir(src, metadata, options):
         passes.convert.add_cf_to_llvmir(pm)
         passes.convert.add_arith_to_llvmir(pm)
         passes.common.add_canonicalizer(pm)
+        passes.common.add_cse(pm)
+        passes.common.add_symbol_dce(pm)
+        if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
+            passes.llvmir.add_di_scope(pm)
         # This pass (`add_builtin_func_to_llvmir`) serves as a temporary workaround to address the issue of excessive basic block
         # count caused by predicated loads/stores. In certain kernels, the addition of these blocks can cause the MLIR
         # canonicalizer to never finish when attempting to merge blocks. The permanent solution under consideration
         # involves using MUBUF instructions that have built-in out-of-bounds checks, which would eliminate the need
         # for conditional branching around memory accesses.
         amd.passes.ttgpuir.add_builtin_func_to_llvmir(pm)
-        passes.common.add_cse(pm)
-        passes.common.add_symbol_dce(pm)
-        if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
-            passes.llvmir.add_di_scope(pm)
         pm.run(mod)

         # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
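The in-tree comment above is worth unpacking: every predicated (masked) load or store lowers to a conditional branch around the access, and each branch adds basic blocks that the MLIR canonicalizer then tries to merge. A minimal C++ sketch of the control flow a single masked element turns into (names and scalar types are illustrative only, not the actual lowering, which emits LLVM IR):

// Hypothetical scalar model of one predicated load. Each masked
// element produces a branch diamond like this, which is what
// inflates the basic-block count the comment describes.
float predicatedLoad(const float *ptr, bool mask, float other) {
  float result;
  if (mask)          // cond-br to a block that performs the load
    result = *ptr;   // the guarded memory access
  else               // separate block for the masked-off lane
    result = other;  // fall back to the `other` operand
  return result;     // join block where both paths merge
}

A MUBUF buffer access with its built-in out-of-bounds check would collapse the whole diamond into one straight-line instruction, which is the permanent fix the comment mentions.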
11 changes: 3 additions & 8 deletions third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -336,13 +336,6 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
   }
 };

-namespace {
-void createBarrier(ConversionPatternRewriter &rewriter, Location loc,
-                   int numCTAs) {
-  barrier();
-}
-} // namespace
-
 static LLVM::AtomicOrdering getMemoryOrdering(MemSemantic memOrdering) {
   switch (memOrdering) {
   case MemSemantic::RELAXED:
@@ -463,7 +456,6 @@ struct AtomicCASOpConversion
       auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
           loc, casPtr, casCmp, casVal, successOrdering, failureOrdering,
           StringRef("agent"));
-
       // Extract the new_loaded value from the pair.
       Value newLoaded = extract_val(valueElemTy, cmpxchg, 0);

@@ -479,6 +471,7 @@ struct AtomicCASOpConversion
         BuilderMemfenceLDS.launch(rewriter, loc, void_ty(ctx));
         barrier();
         Value ret = load(valueElemTy, atomPtr);
+        barrier();
         rewriter.replaceOp(op, {ret});
       }
     }
Expand Down Expand Up @@ -638,7 +631,9 @@ struct AtomicRMWOpConversion
} else {
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
store(retVal, atomPtr);
barrier();
Value ret = load(valueElemTy, atomPtr);
barrier();
rewriter.replaceOp(op, {ret});
}
}
Expand Down
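The barriers added above guard the LDS round trip that broadcasts an atomic's return value: the store must be visible to every thread before any of them loads, and the slot must not be reused until every thread has loaded. A hedged HIP-style sketch of the pattern (broadcastAtomicResult and atomSlot are invented names; the real conversion emits MLIR/LLVM ops, not source):

#include <hip/hip_runtime.h>

// Illustrative model only; atomSlot stands in for the scratch buffer
// returned by getSharedMemoryBase().
__device__ float broadcastAtomicResult(float *addr, float val) {
  __shared__ float atomSlot;
  if (threadIdx.x == 0)
    atomSlot = atomicAdd(addr, val); // the RMW plus store(retVal, atomPtr)
  __syncthreads();                   // barrier(): publish the store before
                                     // any thread performs the load
  float ret = atomSlot;              // load(valueElemTy, atomPtr)
  __syncthreads();                   // barrier() added by this commit: stops
                                     // a later reuse of the same LDS slot
                                     // from clobbering it while slower
                                     // threads are still reading
  return ret;
}

Without the trailing barrier, a subsequent op sharing the same shared-memory base could overwrite the slot before the last thread reads it, which appears to be the race these additions close.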
