From d527c3f47b395eb886c0d9980b458d9d1735da01 Mon Sep 17 00:00:00 2001
From: Thomas Raoux
Date: Tue, 28 May 2024 20:08:27 -0700
Subject: [PATCH] [BACKEND] Add memory space to memdesc type. (#4027)

Currently only shared memory is supported, but this will allow supporting
different kinds of local memory (such as private memory) or other memory
spaces.
---
 .../triton/Dialect/Triton/IR/TritonTypes.td   |  11 +-
 .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td |   6 +
 .../Dialect/TritonGPU/IR/TritonGPUOps.td      |   6 +
 lib/Analysis/Alias.cpp                        |   7 +-
 lib/Analysis/Allocation.cpp                   |   3 +-
 .../DecomposeUnsupportedConversions.cpp       |   5 +-
 .../TritonGPUToLLVM/MemoryOpToLLVM.cpp        |   2 +
 lib/Dialect/Triton/IR/Ops.cpp                 |   6 +-
 lib/Dialect/Triton/IR/Types.cpp               |  14 +-
 .../TritonGPU/Transforms/AccelerateMatmul.cpp |   7 +-
 .../Transforms/OptimizeDotOperands.cpp        |   7 +-
 .../Pipeliner/MatmulLoopPipeline.cpp          |  32 +-
 .../Pipeliner/TMAStoresPipeline.cpp           |   8 +-
 lib/Dialect/TritonGPU/Transforms/Prefetch.cpp |   5 +-
 .../Transforms/ReduceDataDuplication.cpp      |   6 +-
 .../Transforms/TMALowering.cpp                |  13 +-
 .../test/unit/hopper/test_experimental_tma.py |  14 +-
 python/test/unit/language/test_core.py        |   8 +-
 test/Analysis/test-alias.mlir                 |  86 ++---
 test/Analysis/test-allocation.mlir            | 294 +++++++--------
 test/Analysis/test-membar.mlir                | 337 +++++++---------
 .../decompose-unsupported-conversions.mlir    |   2 +-
 test/Conversion/tritongpu_to_llvm.mlir        |  86 ++---
 test/TritonGPU/loop-pipeline-hopper.mlir      |  58 +--
 test/TritonGPU/loop-pipeline.mlir             |  50 +--
 test/TritonGPU/reduce-data-duplication.mlir   |   2 +-
 test/TritonNvidiaGPU/membar.mlir              |  30 +-
 .../TritonAMDGPUTransforms/StreamPipeline.cpp |   5 +-
 .../SharedToDotOperandMMAv2.cpp               |   3 +-
 .../DecomposeUnsupportedConversions.cpp       |   7 +-
 30 files changed, 594 insertions(+), 526 deletions(-)

diff --git a/include/triton/Dialect/Triton/IR/TritonTypes.td b/include/triton/Dialect/Triton/IR/TritonTypes.td
index fd5af9cc8ea3..e5e43f806ad4 100644
--- a/include/triton/Dialect/Triton/IR/TritonTypes.td
+++ b/include/triton/Dialect/Triton/IR/TritonTypes.td
@@ -106,12 +106,13 @@ def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]>
     ArrayRefParameter<"int64_t">:$shape,
     "Type":$elementType,
     "Attribute":$encoding,
+    "Attribute":$memorySpace,
     "bool":$mutable_memory
   );
   let extraClassDeclaration = [{
     MemDescType cloneWith(std::optional<ArrayRef<int64_t>> shape,
                           Type elementType) const {
-      return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding());
+      return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding(), getMemorySpace(), getMutableMemory());
     }

     bool hasRank() const { return true; }
@@ -120,17 +121,19 @@ def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]>
     TypeBuilderWithInferredContext<(ins
       "llvm::ArrayRef<int64_t>":$shape,
       "Type":$elementType,
-      "Attribute":$encoding
+      "Attribute":$encoding,
+      "Attribute":$memorySpace
     ), [{
-      return $_get(elementType.getContext(), shape, elementType, encoding, /*mutableMemory=*/false);
+      return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, /*mutableMemory=*/false);
     }]>,
     TypeBuilderWithInferredContext<(ins
       "llvm::ArrayRef<int64_t>":$shape,
       "Type":$elementType,
       "Attribute":$encoding,
+      "Attribute":$memorySpace,
       "bool":$mutableMemory
     ), [{
-      return $_get(elementType.getContext(), shape, elementType, encoding, mutableMemory);
+      return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory);
     }]>
   ];
   let hasCustomAssemblyFormat = 1;
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 1099e9bf69c4..bf8f2dc18ec6 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -1298,4 +1298,10 @@ elements along the K dim, or they use all elements of the tensor along the K dim }]; } +def TTG_SharedMemorySpace : AttrDef { + let mnemonic = "shared_memory"; + let description = [{ + Attribute to indicate that the memory descriptor points to shared memory. + }]; +} #endif diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 2530009cbd87..061085de6a31 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -144,6 +144,12 @@ def TTG_LocalAllocOp : TTG_Op<"local_alloc", [DeclareOpInterfaceMethods:$src); + let extraClassDeclaration = [{ + bool isSharedMemoryAlloc() { + return getType().getMemorySpace() && + isa(getType().getMemorySpace()); + } + }]; let assemblyFormat = [{$src attr-dict `:` functional-type(operands, results)}]; let results = (outs TT_MemDescType:$result); diff --git a/lib/Analysis/Alias.cpp b/lib/Analysis/Alias.cpp index 5b3910013b7e..52082ddf7021 100644 --- a/lib/Analysis/Alias.cpp +++ b/lib/Analysis/Alias.cpp @@ -26,8 +26,13 @@ void SharedMemoryAliasAnalysis::visitOperation( ArrayRef *> results) { AliasInfo aliasInfo; bool pessimistic = true; - // These ops may allocate a new shared memory buffer. auto result = op->getResult(0); + // skip ops that return memdesc in a different memory space. + if (auto memdescTy = dyn_cast(result.getType())) { + if (!isa_and_nonnull( + memdescTy.getMemorySpace())) + return; + } // Only LocalAllocOp creates a new buffer. if (isa(op)) { diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index a129cb1947c6..cae55efac1ed 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -210,7 +210,8 @@ class AllocationAnalysis { // XXX(Keren): Why this hard-coded alignment? size_t kAlignment = 8; for (Value result : op->getResults()) { - if (auto alloc = result.getDefiningOp()) { + auto alloc = result.getDefiningOp(); + if (alloc && alloc.isSharedMemoryAlloc()) { // Bytes could be a different value once we support padding or other // allocation policies. 
auto allocType = alloc.getType(); diff --git a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp index 690155ee54b5..a35b32c9c4ef 100644 --- a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -95,12 +95,15 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) { auto dstDotOp = dyn_cast(dstType.getEncoding()); if (srcBlocked && dstDotOp) { + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext()); auto tmpType = MemDescType::get( dstType.getShape(), dstType.getElementType(), triton::gpu::SharedEncodingAttr::get( module.getContext(), dstDotOp, srcType.getShape(), srcBlocked.getOrder(), srcBlocked.getCTALayout(), - srcType.getElementType())); + srcType.getElementType()), + sharedMemorySpace); auto tmp = builder.create( cvtOp.getLoc(), tmpType, cvtOp.getSrc()); addAttrs(tmp, cvtOp->getAttrs()); diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 12ab6684c3b3..38ecb288af1a 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -51,6 +51,8 @@ struct LocalAllocOpConversion LogicalResult matchAndRewrite(triton::gpu::LocalAllocOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + if (!op.isSharedMemoryAlloc()) + return failure(); Location loc = op->getLoc(); Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation()); diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index ce4f97336919..4b26babef93e 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -224,9 +224,9 @@ LogicalResult TransOp::inferReturnTypes( return failure(); } } - if (isa(argTy)) { - inferredReturnTypes.push_back( - MemDescType::get(retShape, retEltTy, retEncoding)); + if (auto memDescTy = dyn_cast(argTy)) { + inferredReturnTypes.push_back(MemDescType::get( + retShape, retEltTy, retEncoding, memDescTy.getMemorySpace())); } else { inferredReturnTypes.push_back( RankedTensorType::get(retShape, retEltTy, retEncoding)); diff --git a/lib/Dialect/Triton/IR/Types.cpp b/lib/Dialect/Triton/IR/Types.cpp index 0e1df5b744cd..f9356cb73791 100644 --- a/lib/Dialect/Triton/IR/Types.cpp +++ b/lib/Dialect/Triton/IR/Types.cpp @@ -70,16 +70,24 @@ Type MemDescType::parse(AsmParser &parser) { return Type(); } bool mutableMemory = false; + Attribute memorySpace; if (succeeded(parser.parseOptionalComma())) { + if (failed(parser.parseOptionalKeyword(kMutableMemory))) { + if (parser.parseAttribute(memorySpace)) + return Type(); + } else { + mutableMemory = true; + } + } + if (mutableMemory == false && succeeded(parser.parseOptionalComma())) { if (parser.parseOptionalKeyword(kMutableMemory)) return Type(); mutableMemory = true; } if (parser.parseGreater()) return Type(); - return MemDescType::get(parser.getContext(), dimensions, elementType, - encoding, mutableMemory); + encoding, memorySpace, mutableMemory); } void MemDescType::print(AsmPrinter &printer) const { @@ -89,6 +97,8 @@ void MemDescType::print(AsmPrinter &printer) const { printer << getElementType(); if (getEncoding()) printer << ", " << getEncoding(); + if (getMemorySpace()) + printer << ", " << getMemorySpace(); if (getMutableMemory()) printer << ", " << kMutableMemory; printer << ">"; diff --git 
a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index df84c4e628ea..854c96d46dbf 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -208,12 +208,15 @@ class BlockedToMMA : public mlir::RewritePattern { } } + Attribute SharedMemorySpace = + SharedMemorySpaceAttr::get(argType.getContext()); auto CTALayout = getCTALayout(argType.getEncoding()); auto newLayout = SharedEncodingAttr::get(argType.getContext(), argType.getShape(), newOrder, CTALayout, argType.getElementType()); - auto newType = MemDescType::get(argType.getShape(), - argType.getElementType(), newLayout); + auto newType = + MemDescType::get(argType.getShape(), argType.getElementType(), + newLayout, SharedMemorySpace); rewriter.setInsertionPointAfterValue(arg); return rewriter.create(arg.getLoc(), newType, arg); } diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 4a30bf9f37c8..08cda421b36a 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -59,12 +59,12 @@ class SwizzleShmemConvert : public OpRewritePattern { srcTy.getElementType(), /*needTrans=*/true); if (newInnerCvtEnc == cvtEncoding) return failure(); - rewriter.setInsertionPoint(trans); + auto sharedMemorySpace = SharedMemorySpaceAttr::get(getContext()); auto alloc = rewriter.create( trans.getLoc(), MemDescType::get(srcTy.getShape(), srcTy.getElementType(), - newInnerCvtEnc), + newInnerCvtEnc, sharedMemorySpace), trans.getSrc()); auto newTrans = rewriter.create(trans.getLoc(), alloc, ArrayRef({1, 0})); @@ -254,7 +254,8 @@ class FuseTransHopper : public OpRewritePattern { allocEncoding.getCTALayout(), srcTy.getElementType()); MemDescType innerTy = - MemDescType::get(srcTy.getShape(), srcTy.getElementType(), newInnerEnc); + MemDescType::get(srcTy.getShape(), srcTy.getElementType(), newInnerEnc, + allocType.getMemorySpace()); auto newAlloc = rewriter.create(allocOp.getLoc(), innerTy, trans.getSrc()); rewriter.replaceOpWithNewOp(allocOp, newAlloc, diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 457d42f4eef7..a407f5f29e61 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -245,9 +245,11 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, tt::MemDescType allocTy = cast(alloc.getType()); SmallVector copyOffsets(allocTy.getRank(), zero); copyOffsets[0] = insertIdx; + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); tt::MemDescType subviewTy = tt::MemDescType::get( allocTy.getShape().drop_front(), allocTy.getElementType(), - allocTy.getEncoding(), /*mutableMemory=*/true); + allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); auto view = builder.create(loc, subviewTy, alloc, copyOffsets); Operation *copy = builder.create( @@ -316,6 +318,8 @@ static void createTMAAsyncCopy( llvm::MapVector &loadToInfo, int numStages) { assert(phase && "Phase value is required for TMA async copy."); OpBuilder builder(forOp); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); Value zero = builder.create(forOp.getLoc(), 0, 32); builder.setInsertionPoint(loadOp); 
Location loc = loadOp.getLoc(); @@ -324,7 +328,7 @@ static void createTMAAsyncCopy( copyOffsets[0] = insertIdx; tt::MemDescType subviewTy = tt::MemDescType::get( allocTy.getShape().drop_front(), allocTy.getElementType(), - allocTy.getEncoding(), /*mutableMemory=*/true); + allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); auto view = builder.create(loc, subviewTy, alloc, copyOffsets); @@ -906,11 +910,14 @@ static void scheduleRemainingToLastStage(scf::ForOp forOp, static Value createAlloc(scf::ForOp &forOp, Operation *loadOp, ttg::SharedEncodingAttr sharedEnc, unsigned distance) { OpBuilder builder(forOp); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); auto ty = cast(loadOp->getResultTypes()[0]); SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); bufferShape.insert(bufferShape.begin(), distance); Type memdescType = mlir::triton::MemDescType::get( - bufferShape, ty.getElementType(), sharedEnc, /*mutableMemory*/ true); + bufferShape, ty.getElementType(), sharedEnc, sharedMemorySpace, + /*mutableMemory*/ true); Value alloc = builder.create( loadOp->getLoc(), memdescType, Value()); return alloc; @@ -919,6 +926,8 @@ static Value createAlloc(scf::ForOp &forOp, Operation *loadOp, // Create an allocation to hold the mbarriers. static Value createBarrierAlloc(scf::ForOp &forOp, unsigned distance) { OpBuilder builder(forOp); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext()); Location loc = forOp.getLoc(); auto context = forOp.getContext(); auto barrierCTALayout = @@ -926,11 +935,12 @@ static Value createBarrierAlloc(scf::ForOp &forOp, unsigned distance) { /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); auto barrierEncoding = ttg::SharedEncodingAttr::get(context, 1, 1, 1, {0}, barrierCTALayout); - Type barrierMemDescType = - tt::MemDescType::get({distance}, builder.getI64Type(), barrierEncoding, - /*mutableMemory=*/true); - Type singleBarrierMemDescType = tt::MemDescType::get( - {1}, builder.getI64Type(), barrierEncoding, /*mutableMemory=*/true); + Type barrierMemDescType = tt::MemDescType::get( + {distance}, builder.getI64Type(), barrierEncoding, sharedMemorySpace, + /*mutableMemory=*/true); + Type singleBarrierMemDescType = + tt::MemDescType::get({1}, builder.getI64Type(), barrierEncoding, + sharedMemorySpace, /*mutableMemory=*/true); Value barrierAlloc = builder.create( loc, barrierMemDescType, Value()); for (unsigned i = 0; i < distance; i++) { @@ -1026,9 +1036,12 @@ static void createTMABarrierAndWait( barriers.push_back(barrierAlloc); Location loc = forOp.getLoc(); OpBuilder builder(forOp); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(builder.getContext()); tt::MemDescType barrierTy = tt::MemDescType::get( {1}, builder.getI64Type(), cast(barrierAlloc.getType()).getEncoding(), + sharedMemorySpace, /*mutableMemory=*/true); builder.setInsertionPoint(group[0]->loadOp); Value barrier = builder.create( @@ -1167,6 +1180,8 @@ createAsyncOps(scf::ForOp &forOp, CoarseSchedule &schedule, static void invalidateBarriers(OpBuilder &builder, SmallVector &barriers) { + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(builder.getContext()); for (Value barrier : barriers) { int numBarriers = cast(barrier.getType()).getShape()[0]; for (int i = 0; i < numBarriers; i++) { @@ -1174,6 +1189,7 @@ static void invalidateBarriers(OpBuilder &builder, tt::MemDescType barrierTy = tt::MemDescType::get( {1}, builder.getI64Type(), 
cast(barrier.getType()).getEncoding(), + sharedMemorySpace, /*mutableMemory=*/true); Value barrierView = builder.create( barrier.getLoc(), barrierTy, barrier, idx); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp index 6318b178d39f..798329a6e983 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp @@ -39,9 +39,11 @@ static Value createAlloc(scf::ForOp &forOp, encoding = ttg::SharedEncodingAttr::get( ty.getContext(), ty.getShape(), order, ctaLayout, ty.getElementType()); } - - Type memdescType = tt::MemDescType::get(ty.getShape(), ty.getElementType(), - encoding, /*mutableMemory*/ true); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(ty.getContext()); + Type memdescType = + tt::MemDescType::get(ty.getShape(), ty.getElementType(), encoding, + sharedMemorySpace, /*mutableMemory*/ true); Value alloc = builder.create(storeOp->getLoc(), memdescType, Value()); return alloc; diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp index bb086e03ee22..a078366ab91e 100644 --- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -136,8 +136,9 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, builder.create(v.getLoc(), off, 32)); Value newSmem = builder.create( v.getLoc(), - triton::MemDescType::get(shape, elementType, type.getEncoding()), v, - offsetsVal); + triton::MemDescType::get(shape, elementType, type.getEncoding(), + type.getMemorySpace()), + v, offsetsVal); auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get( builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8); diff --git a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp index c0b586d6016a..8c1f18e459c5 100644 --- a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp +++ b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp @@ -70,12 +70,14 @@ class TritonGPUReduceDataDuplicationPass } else { sharedOrder = srcOrder; } + auto sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext()); auto tmpType = triton::MemDescType::get( dstType.getShape(), dstType.getElementType(), triton::gpu::SharedEncodingAttr::get( mod.getContext(), dstDotOp, srcType.getShape(), sharedOrder, - triton::gpu::getCTALayout(srcEncoding), - srcType.getElementType())); + triton::gpu::getCTALayout(srcEncoding), srcType.getElementType()), + sharedMemorySpace); auto tmp = builder.create( cvtOp.getLoc(), tmpType, cvtOp.getSrc()); auto newConvert = builder.create(cvtOp.getLoc(), diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp index 58e2888b717d..6f1bd728db03 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -23,6 +23,8 @@ class TMALoadLowering : public OpRewritePattern { LogicalResult matchAndRewrite(ExperimentalDescriptorLoadOp op, PatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(ctx); auto loc = op.getLoc(); auto tensorType = op.getResult().getType(); auto order = getOrder(tensorType.getEncoding()); @@ -36,15 +38,16 @@ class TMALoadLowering : 
public OpRewritePattern { } MemDescType memDescType = MemDescType::get(tensorType.getShape(), tensorType.getElementType(), - encoding, /*mutableMemory=*/true); + encoding, sharedMemorySpace, /*mutableMemory=*/true); Value alloc = rewriter.create(loc, memDescType, Value()); auto barrierCTALayout = CTALayoutAttr::get( /*context=*/tensorType.getContext(), /*CTAsPerCGA=*/{1}, /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); auto barrierEncoding = SharedEncodingAttr::get(tensorType.getContext(), 1, 1, 1, {0}, barrierCTALayout); - MemDescType barrierMemDescType = MemDescType::get( - {1}, rewriter.getI64Type(), barrierEncoding, /*mutableMemory=*/true); + MemDescType barrierMemDescType = + MemDescType::get({1}, rewriter.getI64Type(), barrierEncoding, + sharedMemorySpace, /*mutableMemory=*/true); Value barrierAlloc = rewriter.create(loc, barrierMemDescType, Value()); rewriter.create(loc, barrierAlloc, 1); @@ -70,6 +73,8 @@ class TMAStoreLowering LogicalResult matchAndRewrite(ExperimentalDescriptorStoreOp op, PatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(ctx); auto loc = op.getLoc(); auto tensorType = op.getSrc().getType(); auto order = getOrder(tensorType.getEncoding()); @@ -83,7 +88,7 @@ class TMAStoreLowering } MemDescType memDescType = MemDescType::get(tensorType.getShape(), tensorType.getElementType(), - encoding, /*mutableMemory=*/true); + encoding, sharedMemorySpace, /*mutableMemory=*/true); Value alloc = rewriter.create(loc, memDescType, op.getSrc()); rewriter.create(loc, false); rewriter.create( diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py index b20f75bc53a6..c3c8cb7f035e 100644 --- a/python/test/unit/hopper/test_experimental_tma.py +++ b/python/test/unit/hopper/test_experimental_tma.py @@ -26,14 +26,14 @@ def test_descriptor_load_ttgir(): tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{ %c0_i32 = arith.constant 0 : i32 %0 = tt.make_range {{end = {SIZE} : i32, start = 0 : i32}} : tensor<{SIZE}xi32, #blocked> - %1 = triton_gpu.local_alloc : () -> !tt.memdesc<{SIZE}xf32, #shared, mutable> - %2 = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64, #shared, mutable> - triton_nvidia_gpu.init_barrier %2, 1 : <1xi64, #shared, mutable> + %1 = triton_gpu.local_alloc : () -> !tt.memdesc<{SIZE}xf32, #shared, #triton_gpu.shared_memory, mutable> + %2 = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64, #shared, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.init_barrier %2, 1 : <1xi64, #shared, #triton_gpu.shared_memory, mutable> %true = arith.constant 1 : i1 - triton_nvidia_gpu.barrier_expect %2, {size_in_bytes}, %true : <1xi64, #shared, mutable> - triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0_i32] %1, %2, %true : , <1xi64, #shared, mutable> -> <{SIZE}xf32, #shared, mutable> - triton_nvidia_gpu.wait_barrier %2, %c0_i32 : <1xi64, #shared, mutable> - %3 = triton_gpu.local_load %1 : !tt.memdesc<{SIZE}xf32, #shared, mutable> -> tensor<{SIZE}xf32, #blocked> + triton_nvidia_gpu.barrier_expect %2, {size_in_bytes}, %true : <1xi64, #shared, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.async_tma_copy_global_to_local %arg1[%c0_i32] %1, %2, %true : , <1xi64, #shared, #triton_gpu.shared_memory, mutable> -> <{SIZE}xf32, #shared, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.wait_barrier %2, %c0_i32 : <1xi64, 
#shared, #triton_gpu.shared_memory, mutable> + %3 = triton_gpu.local_load %1 : !tt.memdesc<{SIZE}xf32, #shared, #triton_gpu.shared_memory, mutable> -> tensor<{SIZE}xf32, #blocked> %4 = tt.splat %arg0 : !tt.ptr -> tensor<{SIZE}x!tt.ptr, #blocked> %5 = tt.addptr %4, %0 : tensor<{SIZE}x!tt.ptr, #blocked>, tensor<{SIZE}xi32, #blocked> tt.store %5, %3 : tensor<{SIZE}x!tt.ptr, #blocked> diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 5972c93d7f98..abcb1bf0b50f 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -4852,10 +4852,10 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): %12 = triton_gpu.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> %13 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> """ if interm_layout is None else f""" - %15 = triton_gpu.local_alloc %9 : (tensor<{M}x{N}xi32, #src>) -> !tt.memdesc<{M}x{N}xi32, #interm> - %16 = triton_gpu.local_load %15 : !tt.memdesc<{M}x{N}xi32, #interm> -> tensor<{M}x{N}xi32, #src> - %17 = triton_gpu.local_alloc %11 : (tensor<{M}x{N}xf16, #src>) -> !tt.memdesc<{M}x{N}xf16, #interm> - %18 = triton_gpu.local_load %17 : !tt.memdesc<{M}x{N}xf16, #interm> -> tensor<{M}x{N}xf16, #src> + %15 = triton_gpu.local_alloc %9 : (tensor<{M}x{N}xi32, #src>) -> !tt.memdesc<{M}x{N}xi32, #interm, #triton_gpu.shared_memory> + %16 = triton_gpu.local_load %15 : !tt.memdesc<{M}x{N}xi32, #interm, #triton_gpu.shared_memory> -> tensor<{M}x{N}xi32, #src> + %17 = triton_gpu.local_alloc %11 : (tensor<{M}x{N}xf16, #src>) -> !tt.memdesc<{M}x{N}xf16, #interm, #triton_gpu.shared_memory> + %18 = triton_gpu.local_load %17 : !tt.memdesc<{M}x{N}xf16, #interm, #triton_gpu.shared_memory> -> tensor<{M}x{N}xf16, #src> %12 = triton_gpu.convert_layout %16 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst> %13 = triton_gpu.convert_layout %18 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst> diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir index 2f73e0880f05..8c4e65daabfa 100644 --- a/test/Analysis/test-alias.mlir +++ b/test/Analysis/test-alias.mlir @@ -41,7 +41,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK-LABEL: alloc tt.func @alloc(%A : !tt.ptr) { // CHECK: %0 -> %0 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, mutable> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> tt.return } @@ -49,40 +49,40 @@ tt.func @alloc(%A : !tt.ptr) { tt.func @alloc_init(%A : !tt.ptr) { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> // CHECK: %0 -> %0 - %cst1 = triton_gpu.local_alloc %cst0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc %cst0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return } // CHECK-LABEL: trans tt.func @trans(%A : !tt.ptr) { // CHECK: %0 -> %0 - %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf16, #A_SHARED> + %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: %1 -> %0 - %b = tt.trans %tensor {order=array} : !tt.memdesc<16x32xf16, #A_SHARED> -> !tt.memdesc<32x16xf16, #A_SHARED_T> + %b = tt.trans %tensor {order=array} : !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<32x16xf16, #A_SHARED_T, 
#triton_gpu.shared_memory> tt.return } // CHECK-LABEL: subview -tt.func @subview(%A : !tt.memdesc<1x16x16xf16, #A_SHARED>) { +tt.func @subview(%A : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) { %index = arith.constant 0 : i32 // CHECK: %0 -> %0 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %1 -> %0 - %cst1 = triton_gpu.memdesc_subview %a[%index, %index, %index] : !tt.memdesc<1x16x16xf16, #A_SHARED> -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst1 = triton_gpu.memdesc_subview %a[%index, %index, %index] : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return } // CHECK-LABEL: if_alias tt.func @if_alias(%i1 : i1) { // CHECK: %0 -> %0 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: %1 -> %1 - %b = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %b = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %2 -> %0,%1 - %cst2 = scf.if %i1 -> !tt.memdesc<16x16xf16, #A_SHARED> { - scf.yield %a : !tt.memdesc<16x16xf16, #A_SHARED> + %cst2 = scf.if %i1 -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> { + scf.yield %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } else { - scf.yield %b : !tt.memdesc<16x16xf16, #A_SHARED> + scf.yield %b : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } tt.return } @@ -90,11 +90,11 @@ tt.func @if_alias(%i1 : i1) { // CHECK-LABEL: for tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { // CHECK: %0 -> %0 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: %1 -> %1 - %b = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %b = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: %2 -> %2 - %c = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %c = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %arg6 -> %0 // CHECK-NEXT: %arg7 -> %1 // CHECK-NEXT: %arg8 -> %2 @@ -102,8 +102,8 @@ tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !t // CHECK-NEXT: %3#1 -> %0,%1 // CHECK-NEXT: %3#2 -> %0,%1,%2 %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a, %b_shared = %b, %c_shared = %c) -> - (!tt.memdesc<16x16xf16, #A_SHARED>, !tt.memdesc<16x16xf16, #A_SHARED>, !tt.memdesc<16x16xf16, #A_SHARED>) { - scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<16x16xf16, #A_SHARED>, !tt.memdesc<16x16xf16, #A_SHARED>, !tt.memdesc<16x16xf16, #A_SHARED> + (!tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) { + scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } tt.return } @@ -111,11 +111,11 @@ tt.func @for(%lb : index, %ub : index, %step 
: index, %A : !tt.ptr, %B : !t // CHECK-LABEL: for_if tt.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: %0 -> %0 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %1 -> %1 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %2 -> %2 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %arg7 -> %0 // CHECK-NEXT: %arg8 -> %1 // CHECK-NEXT: %arg9 -> %2 @@ -123,14 +123,14 @@ tt.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : // CHECK-NEXT: %3#1 -> %0,%1 // CHECK-NEXT: %3#2 -> %0,%1,%2 %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> - (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { + (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { scf.if %i1 { %index = arith.constant 8 : i32 // CHECK-NEXT: %4 -> %0,%1 - %cst0 = triton_gpu.memdesc_subview %a_shared[%index, %index] : !tt.memdesc<128x32xf16, #A_SHARED> -> !tt.memdesc<32xf16, #A_SHARED> + %cst0 = triton_gpu.memdesc_subview %a_shared[%index, %index] : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<32xf16, #A_SHARED, #triton_gpu.shared_memory> scf.yield } - scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } tt.return } @@ -138,11 +138,11 @@ tt.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : // CHECK-LABEL: for_for_if tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: %0 -> %0 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %1 -> %1 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %2 -> %2 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %arg7 -> %0 // CHECK-NEXT: %arg8 -> %1 // CHECK-NEXT: %arg9 -> %2 @@ -150,23 +150,23 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK-NEXT: %3#1 -> %1 // CHECK-NEXT: %3#2 -> %2,%6,%6 %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = 
%a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> - (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { + (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK-NEXT: %arg11 -> %2,%6,%6 // CHECK-NEXT: %4 -> %2,%6,%6 - %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED>) { + %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK-NEXT: %5 -> %6,%6 - %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED> { + %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> { // CHECK-NEXT: %6 -> %6 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } else { // CHECK-NEXT: %6 -> %6 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } - scf.yield %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } - scf.yield %a_shared, %b_shared, %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %a_shared, %b_shared, %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } tt.return } @@ -175,29 +175,29 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, tt.func @cf_for(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr, %arg4: !tt.ptr) { %idx = arith.constant 0 : i32 // CHECK: %0 -> %0 - %cst = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %cst = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %1 -> %1 - %cst_0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %cst_0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %2 -> %0 - %0 = triton_gpu.memdesc_subview %cst[%idx, %idx] : !tt.memdesc<128x32xf16, #A_SHARED> -> !tt.memdesc<128x32xf16, #A_SHARED> + %0 = triton_gpu.memdesc_subview %cst[%idx, %idx] : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> gpu.barrier // CHECK-NEXT: %3 -> %3 - %cst_1 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %cst_1 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: %5 -> %0,%1,%3 // CHECK-NEXT: %6 -> %0,%1,%3 // CHECK-NEXT: %7 -> %0,%1,%3 - cf.br ^bb1(%arg0, %cst, %cst_0, %cst_1 : index, !tt.memdesc<128x32xf16, 
#A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) -^bb1(%1: index, %2: !tt.memdesc<128x32xf16, #A_SHARED>, %3: !tt.memdesc<128x32xf16, #A_SHARED>, %4: !tt.memdesc<128x32xf16, #A_SHARED>): // 2 preds: ^bb0, ^bb2 + cf.br ^bb1(%arg0, %cst, %cst_0, %cst_1 : index, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) +^bb1(%1: index, %2: !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, %3: !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, %4: !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>): // 2 preds: ^bb0, ^bb2 %5 = arith.cmpi slt, %1, %arg1 : index cf.cond_br %5, ^bb2, ^bb3 ^bb2: // pred: ^bb1 gpu.barrier %8 = arith.addi %1, %arg2 : index - cf.br ^bb1(%8, %4, %2, %3 : index, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) + cf.br ^bb1(%8, %4, %2, %3 : index, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) ^bb3: // pred: ^bb1 gpu.barrier // CHECK-NEXT: %10 -> %0 - %9 = triton_gpu.memdesc_subview %0[%idx, %idx] : !tt.memdesc<128x32xf16, #A_SHARED> -> !tt.memdesc<128x32xf16, #A_SHARED> + %9 = triton_gpu.memdesc_subview %0[%idx, %idx] : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return } diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index f7926de3a24c..76a6340d7aef 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -80,47 +80,47 @@ tt.func @reusable(%A : !tt.ptr) { // CHECK-LABEL: preallocate tt.func @preallocate(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1024, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 2048, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 3072, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 4096, size = 1024 - %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> - triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED> + triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 0, size = 1024 - %c = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %c = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> - triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED> + 
triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1024, size = 1024 - %cst4 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %cst4 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 6144, size = 2048 - %e = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED> - triton_gpu.local_dealloc %a : !tt.memdesc<32x16xf16, #A_SHARED> + %e = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %a : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 8192, size = 2048 - %d = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED> - triton_gpu.local_dealloc %b : !tt.memdesc<32x16xf16, #A_SHARED> + %d = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %b : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 10240, size = 2048 - %f = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst4 : !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %c : !tt.memdesc<32x16xf16, #A_SHARED> + %f = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst4 : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %c : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 0, size = 2048 - %cst5 = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED> + %cst5 = triton_gpu.local_alloc : () -> !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 2048, size = 4096 - %g = triton_gpu.local_alloc : () -> !tt.memdesc<128x16xf16, #A_SHARED> - triton_gpu.local_dealloc %e : !tt.memdesc<64x16xf16, #A_SHARED> + %g = triton_gpu.local_alloc : () -> !tt.memdesc<128x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %e : !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 2048, size = 4096 - %h = triton_gpu.local_alloc : () -> !tt.memdesc<128x16xf16, #A_SHARED> - triton_gpu.local_dealloc %d : !tt.memdesc<64x16xf16, #A_SHARED> + %h = triton_gpu.local_alloc : () -> !tt.memdesc<128x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %d : !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 2048, size = 4096 - %i = triton_gpu.local_alloc : () -> !tt.memdesc<128x16xf16, #A_SHARED> - triton_gpu.local_dealloc %f : !tt.memdesc<64x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst5 : !tt.memdesc<64x16xf16, #A_SHARED> + %i = triton_gpu.local_alloc : () -> !tt.memdesc<128x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %f : !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst5 : !tt.memdesc<64x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return // CHECK-NEXT: size = 12288 } @@ -130,11 +130,11 @@ tt.func @preallocate(%A : !tt.ptr) { tt.func @unused(%A : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #AL> // CHECK: offset = 0, size = 1024 - %cst0 = triton_gpu.local_alloc %cst : (tensor<32x16xf16, #AL>) -> !tt.memdesc<32x16xf16, #A_SHARED> + %cst0 = 
triton_gpu.local_alloc %cst : (tensor<32x16xf16, #AL>) -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 0, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 0, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return // CHECK: size = 1024 } @@ -143,33 +143,33 @@ tt.func @unused(%A : !tt.ptr) { // CHECK-LABEL: longlive tt.func @longlive(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1024, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 2048, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 3072, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1024, size = 512 - %cst3 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst3 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 2048, size = 512 - %cst4 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst4 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 3072, size = 1024 - %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 3072, size = 512 - %cst5 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst5 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 3072, size = 512 - %cst6 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst6 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 3072, size = 1024 - %c = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst3 : !tt.memdesc<1x16x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst4 : !tt.memdesc<1x16x16xf16, #A_SHARED> + %c = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst3 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst4 : !tt.memdesc<1x16x16xf16, #A_SHARED, 
#triton_gpu.shared_memory> // CHECK-NEXT: offset = 1024, size = 1024 - %d = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED> + %d = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return // CHECK-NEXT: size = 4096 } @@ -178,43 +178,43 @@ tt.func @longlive(%A : !tt.ptr) { // CHECK-LABEL: multi_color tt.func @multi_color(%A : !tt.ptr) { // CHECK: offset = 0, size = 64 - %cst = triton_gpu.local_alloc : () -> !tt.memdesc<4x8xf16, #A_SHARED> + %cst = triton_gpu.local_alloc : () -> !tt.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1536, size = 32 - %cst_0 = triton_gpu.local_alloc : () -> !tt.memdesc<4x4xf16, #A_SHARED> + %cst_0 = triton_gpu.local_alloc : () -> !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1664, size = 128 - %cst_1 = triton_gpu.local_alloc : () -> !tt.memdesc<16x4xf16, #A_SHARED> + %cst_1 = triton_gpu.local_alloc : () -> !tt.memdesc<16x4xf16, #A_SHARED, #triton_gpu.shared_memory> %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK-NEXT: scratch offset = 128, size = 1152 %0 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #AL> - %1 = triton_gpu.local_load %cst : !tt.memdesc<4x8xf16, #A_SHARED> -> tensor<4x8xf16, #AL> + %1 = triton_gpu.local_load %cst : !tt.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<4x8xf16, #AL> // CHECK-NEXT: offset = 0, size = 128 - %cst_3 = triton_gpu.local_alloc : () -> !tt.memdesc<4x16xf16, #A_SHARED> - %2 = triton_gpu.local_load %cst_0 : !tt.memdesc<4x4xf16, #A_SHARED> -> tensor<4x4xf16, #AL> + %cst_3 = triton_gpu.local_alloc : () -> !tt.memdesc<4x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %2 = triton_gpu.local_load %cst_0 : !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<4x4xf16, #AL> // CHECK-NEXT: scratch offset = 0, size = 1152 %3 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #AL> // CHECK-NEXT: offset = 0, size = 256 - %cst_4 = triton_gpu.local_alloc : () -> !tt.memdesc<4x32xf16, #A_SHARED> + %cst_4 = triton_gpu.local_alloc : () -> !tt.memdesc<4x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 256, size = 64 - %cst_5 = triton_gpu.local_alloc : () -> !tt.memdesc<4x8xf16, #A_SHARED> - %4 = triton_gpu.local_load %cst_5 : !tt.memdesc<4x8xf16, #A_SHARED> -> tensor<4x8xf16, #AL> - %5 = triton_gpu.local_load %cst_5 : !tt.memdesc<4x8xf16, #A_SHARED> -> tensor<4x8xf16, #AL> + %cst_5 = triton_gpu.local_alloc : () -> !tt.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory> + %4 = triton_gpu.local_load %cst_5 : !tt.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<4x8xf16, #AL> + %5 = triton_gpu.local_load %cst_5 : !tt.memdesc<4x8xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<4x8xf16, #AL> // CHECK-NEXT: offset = 1024, size = 512 - %cst_6 = triton_gpu.local_alloc : () -> !tt.memdesc<8x32xf16, #A_SHARED> + %cst_6 = triton_gpu.local_alloc : () -> !tt.memdesc<8x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1792, size = 128 - %cst_7 = triton_gpu.local_alloc : () -> !tt.memdesc<2x32xf16, #A_SHARED> - %6 = triton_gpu.local_load %cst_0 : !tt.memdesc<4x4xf16, #A_SHARED> -> tensor<4x4xf16, #AL> + %cst_7 = triton_gpu.local_alloc : () -> !tt.memdesc<2x32xf16, #A_SHARED, 
#triton_gpu.shared_memory> + %6 = triton_gpu.local_load %cst_0 : !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<4x4xf16, #AL> // CHECK-NEXT: offset = 1024, size = 512 - %cst_8 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst_8 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 256, size = 32 - %cst_9 = triton_gpu.local_alloc : () -> !tt.memdesc<4x4xf16, #A_SHARED> + %cst_9 = triton_gpu.local_alloc : () -> !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1024, size = 512 - %cst_10 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> - %7 = triton_gpu.local_load %cst_1 : !tt.memdesc<16x4xf16, #A_SHARED> -> tensor<16x4xf16, #AL> - %8 = triton_gpu.local_load %cst_4 : !tt.memdesc<4x32xf16, #A_SHARED> -> tensor<4x32xf16, #AL> + %cst_10 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %7 = triton_gpu.local_load %cst_1 : !tt.memdesc<16x4xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x4xf16, #AL> + %8 = triton_gpu.local_load %cst_4 : !tt.memdesc<4x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<4x32xf16, #AL> // CHECK-NEXT: scratch offset = 0, size = 1152 %9 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #AL> %cst_11 = arith.constant dense<0.000000e+00> : tensor<4x4xf16, #AL> - %10 = triton_gpu.local_load %cst_7 : !tt.memdesc<2x32xf16, #A_SHARED> -> tensor<2x32xf16, #AL> + %10 = triton_gpu.local_load %cst_7 : !tt.memdesc<2x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<2x32xf16, #AL> %cst_12 = arith.constant dense<0.000000e+00> : tensor<4x16xf16, #AL> %cst_13 = arith.constant dense<0.000000e+00> : tensor<8x32xf16, #AL> // CHECK-NEXT: size = 1920 @@ -225,25 +225,25 @@ tt.func @multi_color(%A : !tt.ptr) { // CHECK-LABEL: multi_color_multi_rounds tt.func @multi_color_multi_rounds(%arg0: !tt.ptr) { // CHECK: offset = 0, size = 32 - %cst = triton_gpu.local_alloc : () -> !tt.memdesc<4x4xf16, #A_SHARED> + %cst = triton_gpu.local_alloc : () -> !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1280, size = 128 - %cst_0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x4xf16, #A_SHARED> + %cst_0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x4xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 2048, size = 8192 - %cst_1 = triton_gpu.local_alloc : () -> !tt.memdesc<1024x4xf16, #A_SHARED> + %cst_1 = triton_gpu.local_alloc : () -> !tt.memdesc<1024x4xf16, #A_SHARED, #triton_gpu.shared_memory> %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK-NEXT: scratch offset = 128, size = 1152 %0 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #AL> - %1 = triton_gpu.local_load %cst : !tt.memdesc<4x4xf16, #A_SHARED> -> tensor<4x4xf16, #AL> + %1 = triton_gpu.local_load %cst : !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<4x4xf16, #AL> // CHECK-NEXT: offset = 1152, size = 128 - %cst_3 = triton_gpu.local_alloc : () -> !tt.memdesc<2x32xf16, #A_SHARED> - %2 = triton_gpu.local_load %cst : !tt.memdesc<4x4xf16, #A_SHARED> -> tensor<4x4xf16, #AL> + %cst_3 = triton_gpu.local_alloc : () -> !tt.memdesc<2x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %2 = triton_gpu.local_load %cst : !tt.memdesc<4x4xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<4x4xf16, #AL> // CHECK-NEXT: offset = 0, size = 512 - %cst_4 = triton_gpu.local_alloc 
: () -> !tt.memdesc<1x16x16xf16, #A_SHARED> - %3 = triton_gpu.local_load %cst_0 : !tt.memdesc<16x4xf16, #A_SHARED> -> tensor<16x4xf16, #AL> - %4 = triton_gpu.local_load %cst_1 : !tt.memdesc<1024x4xf16, #A_SHARED> -> tensor<1024x4xf16, #AL> + %cst_4 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %3 = triton_gpu.local_load %cst_0 : !tt.memdesc<16x4xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x4xf16, #AL> + %4 = triton_gpu.local_load %cst_1 : !tt.memdesc<1024x4xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<1024x4xf16, #AL> // CHECK-NEXT: scratch offset = 0, size = 1152 %5 = triton_gpu.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #AL> - %6 = triton_gpu.local_load %cst_3 : !tt.memdesc<2x32xf16, #A_SHARED> -> tensor<2x32xf16, #AL> + %6 = triton_gpu.local_load %cst_3 : !tt.memdesc<2x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<2x32xf16, #AL> // CHECK-NEXT: size = 10240 tt.return } @@ -252,10 +252,10 @@ tt.func @multi_color_multi_rounds(%arg0: !tt.ptr) { // CHECK-LABEL: alloc tt.func @alloc(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK-NEXT: offset = 0, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return // CHECK-NEXT: size = 512 } @@ -264,10 +264,10 @@ tt.func @alloc(%A : !tt.ptr) { // CHECK-LABEL: dealloc tt.func @dealloc(%A : !tt.ptr) { // CHECK: offset = 0, size = 1024 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: offset = 1024, size = 1024 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst0 : !tt.memdesc<32x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst0 : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return // CHECK-NEXT: size = 2048 } @@ -288,8 +288,8 @@ tt.func @scratch() { // CHECK-LABEL: trans tt.func @trans(%A : !tt.ptr) { // CHECK: offset = 0, size = 1024 - %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf16, #A_SHARED> - %b = tt.trans %tensor {order=array} : !tt.memdesc<16x32xf16, #A_SHARED> -> !tt.memdesc<32x16xf16, #A_SHARED_T> + %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %b = tt.trans %tensor {order=array} : !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<32x16xf16, #A_SHARED_T, #triton_gpu.shared_memory> tt.return } @@ -297,9 +297,9 @@ tt.func @trans(%A : !tt.ptr) { // CHECK-LABEL: extract_slice tt.func @extract_slice(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> %index = arith.constant 0 : i32 - %cst1 = triton_gpu.memdesc_subview %cst0[%index, %index, %index] : !tt.memdesc<1x16x16xf16, #A_SHARED> -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst1 = triton_gpu.memdesc_subview %cst0[%index, %index, 
%index] : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return // CHECK-NEXT: size = 512 } @@ -309,25 +309,25 @@ tt.func @extract_slice(%A : !tt.ptr) { // CHECK-LABEL: if tt.func @if(%i1 : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1024, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> scf.if %i1 { // CHECK-NEXT: offset = 2048, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 2048, size = 1024 - %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED> + %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } // CHECK-NEXT: offset = 0, size = 512 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1024, size = 512 - %cst3 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst3 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 2048, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst3 : !tt.memdesc<1x16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst3 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return // CHECK-NEXT: size = 3072 } @@ -337,28 +337,28 @@ tt.func @if(%i1 : i1) { // CHECK-LABEL: if_else tt.func @if_else(%i1 : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 1024, size = 512 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> scf.if %i1 { // CHECK-NEXT: offset = 2048, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 2048, size = 1024 - %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %b = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> } else { // CHECK-NEXT: offset = 2048, size = 512 - 
%cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 3072, size = 512 - %cst3 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst3 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 4096, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst3 : !tt.memdesc<1x16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst2 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst3 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } // CHECK-NEXT: offset = 2048, size = 1024 - %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED> - triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst0 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + triton_gpu.local_dealloc %cst1 : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return // CHECK-NEXT: size = 5120 } @@ -368,13 +368,13 @@ tt.func @if_else(%i1 : i1) { // CHECK-LABEL: for tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { // CHECK: offset = 0, size = 8192 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 8192, size = 8192 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 16384, size = 8192 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { - scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } tt.return // CHECK-NEXT: size = 24576 @@ -383,18 +383,18 @@ tt.func @for(%lb : index, %ub : index, %step : index, 
%A : !tt.ptr, %B : !t // CHECK-LABEL: for_if_slice tt.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: offset = 0, size = 8192 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 8192, size = 8192 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 16384, size = 8192 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { + %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { scf.if %i1 { %index = arith.constant 8 : i32 - %cst0 = triton_gpu.memdesc_subview %a_shared[%index, %index] : !tt.memdesc<128x32xf16, #A_SHARED> -> !tt.memdesc<32xf16, #A_SHARED> + %cst0 = triton_gpu.memdesc_subview %a_shared[%index, %index] : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<32xf16, #A_SHARED, #triton_gpu.shared_memory> scf.yield } - scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } tt.return // CHECK-NEXT: size = 24576 @@ -404,16 +404,16 @@ tt.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr // CHECK-LABEL: for_use_ancestor tt.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: offset = 0, size = 8192 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 8192, size = 8192 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 16384, size = 8192 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared, %b_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { - %c0 = tt.trans %c_shared_init {order=array} : !tt.memdesc<128x32xf16, #A_SHARED> -> !tt.memdesc<32x128xf16, #A_SHARED_T> + %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, 
#A_SHARED, #triton_gpu.shared_memory> + %a_shared, %b_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + %c0 = tt.trans %c_shared_init {order=array} : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<32x128xf16, #A_SHARED_T, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 24576, size = 8192 - %c1 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield %b_shared, %a_shared: !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + %c1 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %b_shared, %a_shared: !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } tt.return // CHECK-NEXT: size = 32768 @@ -424,28 +424,28 @@ tt.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr< // CHECK-LABEL: for_for_if tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: offset = 0, size = 8192 - %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 8192, size = 8192 - %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %b_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 16384, size = 8192 - %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { - %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED>) { - %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED> { + %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> { // CHECK-NEXT: offset = 24576, size = 8192 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } else { // CHECK-NEXT: offset = 32768, size = 8192 - %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield 
%cst1 : !tt.memdesc<128x32xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %cst1 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } - scf.yield %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } - scf.yield %a_shared, %b_shared, %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %a_shared, %b_shared, %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } // CHECK-NEXT: offset = 0, size = 8192 - %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED> + %cst2 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return // CHECK-NEXT: size = 40960 } @@ -457,7 +457,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: alloc1 tt.func @alloc1(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return // CHECK-NEXT: size = 512 } @@ -465,7 +465,7 @@ tt.func @alloc1(%A : !tt.ptr) { // CHECK-LABEL: alloc2 tt.func @alloc2(%A : !tt.ptr) { // CHECK: offset = 0, size = 1024 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return // CHECK-NEXT: size = 1024 } @@ -474,10 +474,10 @@ tt.func @alloc2(%A : !tt.ptr) { tt.func @alloc3(%cond : i1) { scf.if %cond { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } else { // CHECK-NEXT: offset = 0, size = 1024 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory> } tt.return // CHECK-NEXT: size = 1024 @@ -499,7 +499,7 @@ tt.func @alloc4(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: single_call tt.func @single_call(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK-NEXT: virtual offset = 0, size = 512 tt.call @alloc1(%A) : (!tt.ptr) -> () @@ -510,7 +510,7 @@ tt.func @single_call(%A : !tt.ptr) { // CHECK-LABEL: multiple_calls tt.func @multiple_calls(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: virtual offset = 0, size = 512 tt.call @alloc1(%A) : (!tt.ptr) -> () %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> @@ -525,9 +525,9 @@ tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> scf.if %cond { // CHECK: offset = 0, size = 512 - %cst0 = 
triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: offset = 0, size = 1024 - %cst1 = triton_gpu.local_alloc %cst : (tensor<16x32xf16, #AL>) -> !tt.memdesc<16x32xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc %cst : (tensor<16x32xf16, #AL>) -> !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: virtual offset = 0, size = 512 tt.call @alloc1(%A) : (!tt.ptr) -> () } else { @@ -542,7 +542,7 @@ tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: for_calls tt.func @for_calls(%A : !tt.ptr, %cond : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> %lb = arith.constant 0 : index %ub = arith.constant 10 : index @@ -558,7 +558,7 @@ tt.func @for_calls(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: call_graph_1 tt.func @call_graph_1(%A : !tt.ptr, %cond : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: virtual offset = 0, size = 1024 tt.call @alloc3(%cond) : (i1) -> () tt.return @@ -568,7 +568,7 @@ tt.func @call_graph_1(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: call_graph_2 tt.func @call_graph_2(%A : !tt.ptr, %cond : i1) { // CHECK: offset = 0, size = 512 - %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK-NEXT: virtual offset = 0, size = 1024 tt.call @alloc4(%A, %cond) : (!tt.ptr, i1) -> () tt.return diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index 747a63959942..6f657cda5555 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -5,7 +5,6 @@ #BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> #A_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> -#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> #A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> #B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> @@ -47,10 +46,10 @@ tt.func @raw_single_block(%A : !tt.ptr) { %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> %0 = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> %1 = tt.load %0, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> + %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> tt.return } @@ -60,14 +59,14 @@ tt.func @war_single_block(%A : !tt.ptr) 
{ %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> %0 = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> %1 = tt.load %0, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: triton_gpu.local_alloc // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> + %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> // CHECK: gpu.barrier // CHECK-NEXT: %4 = triton_gpu.local_alloc - %4 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %4 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return } @@ -77,25 +76,25 @@ tt.func @war_single_block_local_store(%A : !tt.ptr) { %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> %0 = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> %1 = tt.load %0, %cst1, %cst2 : tensor<128x32x!tt.ptr, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %2 = triton_gpu.local_alloc %1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: triton_gpu.local_alloc // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> + %3 = triton_gpu.local_load %2 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_store - triton_gpu.local_store %1, %2 : tensor<128x32xf16, #AL> -> !tt.memdesc<128x32xf16, #A_SHARED> + triton_gpu.local_store %1, %2 : tensor<128x32xf16, #AL> -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return } // CHECK-LABEL: scratch tt.func @scratch(%arg: tensor<16x16xf16, #AL>) { - %cst0 = triton_gpu.local_alloc %arg : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %arg : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load // CHECK: gpu.barrier // CHECK: tt.reduce - %1 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> %2 = "tt.reduce" (%1) ({ ^bb0(%arg1: f16, %arg2: f16): %add = arith.addf %arg1, %arg2 : f16 @@ -106,34 +105,34 @@ tt.func @scratch(%arg: tensor<16x16xf16, #AL>) { // CHECK-LABEL: async_wait tt.func @async_wait(%arg: tensor<32x16xf16, #AL>) { - %cst0 = triton_gpu.local_alloc %arg : (tensor<32x16xf16, #AL>) -> !tt.memdesc<32x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %arg : (tensor<32x16xf16, #AL>) -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: triton_gpu.async_wait triton_gpu.async_wait {num = 4 : i32} // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %cst0 : !tt.memdesc<32x16xf16, #A_SHARED> -> tensor<32x16xf16, #AL> + %1 = triton_gpu.local_load %cst0 : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> 
tensor<32x16xf16, #AL> tt.return } // CHECK-LABEL: subview tt.func @subview() { %cst0 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #AL> - %a = triton_gpu.local_alloc %cst0 : (tensor<32x16xf16, #AL>) -> !tt.memdesc<32x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc %cst0 : (tensor<32x16xf16, #AL>) -> !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> %index = arith.constant 0 : i32 - %0 = triton_gpu.memdesc_subview %a[%index, %index] : !tt.memdesc<32x16xf16, #A_SHARED> -> !tt.memdesc<16x16xf16, #A_SHARED> + %0 = triton_gpu.memdesc_subview %a[%index, %index] : !tt.memdesc<32x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc - %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return } // CHECK-LABEL: trans -tt.func @trans(%a: !tt.memdesc<16x32xf16, #A_SHARED>) { +tt.func @trans(%a: !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK-NOT: gpu.barrier - %b = tt.trans %a {order=array} : !tt.memdesc<16x32xf16, #A_SHARED> -> !tt.memdesc<32x16xf16, #A_SHARED_T> + %b = tt.trans %a {order=array} : !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<32x16xf16, #A_SHARED_T, #triton_gpu.shared_memory> tt.return } @@ -143,31 +142,31 @@ tt.func @async_copy_global_to_local(%A : !tt.ptr, %i1 : i1) { %a_ptr = tt.splat %A : !tt.ptr -> tensor<16x16x!tt.ptr, #AL> %mask = tt.splat %i1 : i1 -> tensor<16x16xi1, #AL> %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %alloc = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, mutable> - %subview = triton_gpu.memdesc_subview %alloc[%index, %index, %index] : !tt.memdesc<1x16x16xf16, #A_SHARED, mutable> -> !tt.memdesc<16x16xf16, #A_SHARED, mutable> - %1 = triton_gpu.async_copy_global_to_local %a_ptr, %subview : tensor<16x16x!tt.ptr, #AL> -> !tt.memdesc<16x16xf16, #A_SHARED, mutable> + %alloc = triton_gpu.local_alloc : () -> !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %subview = triton_gpu.memdesc_subview %alloc[%index, %index, %index] : !tt.memdesc<1x16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> + %1 = triton_gpu.async_copy_global_to_local %a_ptr, %subview : tensor<16x16x!tt.ptr, #AL> -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %4 = triton_gpu.local_load %subview : !tt.memdesc<16x16xf16, #A_SHARED, mutable> -> tensor<16x16xf16, #AL> + %4 = triton_gpu.local_load %subview : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> tensor<16x16xf16, #AL> tt.return } // If branch inserted a barrier for %cst0, but else didn't, then the barrier should be inserted in the parent region // CHECK-LABEL: multi_blocks tt.func @multi_blocks(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, 
#A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> scf.if %i1 { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> scf.yield } else { - %cst1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> scf.yield } // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %2 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -175,21 +174,21 @@ tt.func @multi_blocks(%i1 : i1) { // CHECK-LABEL: multi_blocks_join_barrier tt.func @multi_blocks_join_barrier(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> scf.if %i1 { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> scf.yield } else { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> scf.yield } // CHECK-NOT: gpu.barrier // CHECK: tt.return - %a_ = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %a_ = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -197,25 +196,25 @@ tt.func @multi_blocks_join_barrier(%i1 : i1) { // CHECK-LABEL: multi_blocks_yield tt.func @multi_blocks_yield(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - %a = scf.if %i1 -> (!tt.memdesc<16x16xf16, #A_SHARED>) { + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %a = scf.if %i1 -> (!tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> - %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - scf.yield %1 : !tt.memdesc<16x16xf16, #A_SHARED> + %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } else { // CHECK: 
gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> - %3 = triton_gpu.local_alloc %2 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - scf.yield %3 : !tt.memdesc<16x16xf16, #A_SHARED> + %2 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + %3 = triton_gpu.local_alloc %2 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %3 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } - %a_ = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %a_ = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> // CHECK: triton_gpu.local_load // CHECK-NEXT: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %4 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %4 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -223,27 +222,27 @@ tt.func @multi_blocks_yield(%i1 : i1) { // CHECK-LABEL: multi_blocks_entry_no_shared tt.func @multi_blocks_entry_no_shared(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - %a = scf.if %i1 -> (!tt.memdesc<16x16xf16, #A_SHARED>) { + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %a = scf.if %i1 -> (!tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc // CHECK-NEXT: gpu.barrier // CHECK-NEXT: triton_gpu.local_load // CHECK-NEXT: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc - %cst1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - %0 = triton_gpu.local_load %cst1 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> - %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - scf.yield %1 : !tt.memdesc<16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %0 = triton_gpu.local_load %cst1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } else { // CHECK-NOT: gpu.barrier // CHECK: triton_gpu.local_alloc - %cst1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - scf.yield %cst1 : !tt.memdesc<16x16xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %cst1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %2 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -251,16 +250,16 @@ tt.func @multi_blocks_entry_no_shared(%i1 : i1) { 
// CHECK-LABEL: multi_blocks_noelse tt.func @multi_blocks_noelse(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> scf.if %i1 { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %0 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> scf.yield } // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_load %cst0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -268,39 +267,39 @@ tt.func @multi_blocks_noelse(%i1 : i1) { // CHECK-LABEL: multi_blocks_nested_scf tt.func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> scf.if %i1 { scf.if %i2 { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> + %0 = triton_gpu.local_load %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> scf.yield } scf.yield } else { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> + %1 = triton_gpu.local_load %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> scf.yield } // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> + %2 = triton_gpu.local_load %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> tt.return } // CHECK-LABEL: for tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { + %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, 
#triton_gpu.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %a0 = triton_gpu.local_load %a_shared : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %b0 = triton_gpu.local_load %b_shared : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + %a0 = triton_gpu.local_load %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %b0 = triton_gpu.local_load %b_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + scf.yield %b_shared, %a_shared, %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } tt.return } @@ -310,24 +309,24 @@ tt.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !t // CHECK-LABEL: for_alias tt.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %a0 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %b0 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { + %a0 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %b0 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, 
%c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %a1 = triton_gpu.local_load %a_shared : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %b1 = triton_gpu.local_load %b_shared : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - scf.yield %c_shared, %a_shared, %b_shared : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + %a1 = triton_gpu.local_load %a_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %b1 = triton_gpu.local_load %b_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + scf.yield %c_shared, %a_shared, %b_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %r = triton_gpu.local_load %0 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> + %r = triton_gpu.local_load %0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> tt.return } @@ -336,63 +335,63 @@ tt.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr, % // CHECK-LABEL: for_reuse tt.func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %a0 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %b0 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { + %a0 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %b0 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared, %b_shared, 
%c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc - %a1 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %b1 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %1 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %a1 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %b1 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %1 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc - %a2 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %b2 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %2 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield %c_shared, %a_shared, %b_shared : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + %a2 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %b2 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %2 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %c_shared, %a_shared, %b_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %r = triton_gpu.local_load %0 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> + %r = triton_gpu.local_load %0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> tt.return } // CHECK-LABEL: for_reuse_nested tt.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %a0 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %b0 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - 
%0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { + %a0 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %b0 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %0 = triton_gpu.local_alloc %a0 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc - %a1 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %b1 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %1 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared_next, %b_shared_next, %c_shared_next = scf.for %ivv = %lb to %ub step %step iter_args(%a_shared_nested = %a_shared_init, %b_shared_nested = %b_shared_init, %c_shared_nested = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { + %a1 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %b1 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %1 = triton_gpu.local_alloc %a1 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared_next, %b_shared_next, %c_shared_next = scf.for %ivv = %lb to %ub step %step iter_args(%a_shared_nested = %a_shared_init, %b_shared_nested = %b_shared_init, %c_shared_nested = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc - %a2 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %b2 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - %2 = triton_gpu.local_alloc %a2 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield %c_shared_nested, %a_shared_nested, %b_shared_nested : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + %a2 = triton_gpu.local_load %a_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> 
tensor<128x32xf16, #AL> + %b2 = triton_gpu.local_load %b_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + %2 = triton_gpu.local_alloc %a2 : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %c_shared_nested, %a_shared_nested, %b_shared_nested : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } - scf.yield %c_shared, %a_shared, %b_shared : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %c_shared, %a_shared, %b_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %r = triton_gpu.local_load %0 : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> + %r = triton_gpu.local_load %0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> tt.return } @@ -400,25 +399,25 @@ tt.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr< // CHECK-LABEL: for_for_if tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { - %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED>) { - %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED> { + %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + %c_shared_next = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc - %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, 
#A_SHARED> - scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } else { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc - %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } - scf.yield %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } - scf.yield %a_shared, %b_shared, %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %a_shared, %b_shared, %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } tt.return } @@ -427,30 +426,30 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK-LABEL: for_if_for tt.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> - %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> + %a_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %b_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + %c_shared_init = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier - %c_blocked = triton_gpu.local_load %c_shared_init : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> + %c_blocked = triton_gpu.local_load %c_shared_init : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> - %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>) { - %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED> { + %a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { + %c_shared_next_next = scf.if %i1 -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc - %cst0 = triton_gpu.local_alloc %cst : 
(tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED> - scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<128x32xf16, #AL>) -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> + scf.yield %cst0 : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } else { - %c_shared_ = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED>) { + %c_shared_ = scf.for %jv = %lb to %ub step %step iter_args(%c_shared_next = %c_shared) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %c_blocked_next = triton_gpu.local_load %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - scf.yield %c_shared : !tt.memdesc<128x32xf16, #A_SHARED> + %c_blocked_next = triton_gpu.local_load %c_shared_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + scf.yield %c_shared : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } - scf.yield %c_shared_ : !tt.memdesc<128x32xf16, #A_SHARED> + scf.yield %c_shared_ : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } // CHECK-NOT: gpu.barrier - %b_blocked_next = triton_gpu.local_load %b_shared: !tt.memdesc<128x32xf16, #A_SHARED> -> tensor<128x32xf16, #AL> - scf.yield %a_shared, %b_shared, %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED>, !tt.memdesc<128x32xf16, #A_SHARED> + %b_blocked_next = triton_gpu.local_load %b_shared: !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<128x32xf16, #AL> + scf.yield %a_shared, %b_shared, %c_shared_next_next : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory> } tt.return } @@ -458,63 +457,63 @@ tt.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK-LABEL: cf_if tt.func @cf_if(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> cf.cond_br %i1, ^bb1, ^bb2 ^bb1: // pred: ^bb0 // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %0 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> cf.br ^bb2 ^bb2: // 2 preds: ^bb0, ^bb1 // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %1 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return } tt.func @cf_if_else(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> cf.cond_br %i1, ^bb1, ^bb2 ^bb1: // pred: ^bb0 // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> 
tensor<16x16xf16, #AL> - %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - cf.br ^bb3(%1 : !tt.memdesc<16x16xf16, #A_SHARED>) + %0 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_alloc %0 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + cf.br ^bb3(%1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) ^bb2: // pred: ^bb0 // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> - %3 = triton_gpu.local_alloc %2 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - cf.br ^bb3(%3 : !tt.memdesc<16x16xf16, #A_SHARED>) -^bb3(%arg: !tt.memdesc<16x16xf16, #A_SHARED>): // 2 preds: ^bb1, ^bb2 + %2 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + %3 = triton_gpu.local_alloc %2 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + cf.br ^bb3(%3 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>) +^bb3(%arg: !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory>): // 2 preds: ^bb1, ^bb2 cf.br ^bb4 ^bb4: // pred: ^bb3 // CHECK: triton_gpu.local_load - %4 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %4 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %5 = triton_gpu.local_load %arg : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %5 = triton_gpu.local_load %arg : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return } tt.func @cf_if_else_return(%i1 : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> - %b = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %a = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %b = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> cf.cond_br %i1, ^bb1, ^bb2 ^bb1: // pred: ^bb0 // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %0 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> - %1 = triton_gpu.local_load %b : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %0 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + %1 = triton_gpu.local_load %b : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return ^bb2: // pred: ^bb0 // CHECK: gpu.barrier // CHECK-NEXT: triton_gpu.local_load - %2 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> - %3 = triton_gpu.local_load %b : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %2 = triton_gpu.local_load %a : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + %3 = triton_gpu.local_load %b : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return } @@ -525,38 +524,38 @@ module attributes 
{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : // CHECK-LABEL: convert_layout1 tt.func @convert_layout1(%A : !tt.ptr) { // CHECK-NOT: gpu.barrier - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> - %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return } // CHECK-LABEL: convert_layout2 tt.func @convert_layout2(%A : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> - %1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> + %1 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: triton_gpu.local_load // CHECK-NEXT: gpu.barrier // CHECK: triton_gpu.local_load - %3 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> - %4 = triton_gpu.local_load %1 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> + %3 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + %4 = triton_gpu.local_load %1 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> tt.return } // CHECK-LABEL: convert_layout3 tt.func @convert_layout3(%cond : i1) { scf.if %cond { - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x64xf16, #A_SHARED> + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x64xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: triton_gpu.local_load // CHECK-NOT: gpu.barrier - %1 = triton_gpu.local_load %0 : !tt.memdesc<16x64xf16, #A_SHARED> -> tensor<16x64xf16, #AL> + %1 = triton_gpu.local_load %0 : !tt.memdesc<16x64xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x64xf16, #AL> } else { - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED> + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: triton_gpu.local_load // CHECK-NEXT: gpu.barrier // CHECK-NEXT: triton_gpu.local_alloc - %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED> -> tensor<16x16xf16, #AL> - %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> -> tensor<16x16xf16, #AL> + %2 = triton_gpu.local_alloc %1 : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> } tt.return } @@ -595,7 +594,7 @@ tt.func @single_call_no_sync(%A : !tt.ptr) { // CHECK-LABEL: multiple_calls tt.func @multiple_calls(%A : !tt.ptr) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.call @convert_layout1(%A) : (!tt.ptr) -> () %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> tt.call @convert_layout2(%A) : (!tt.ptr) -> () @@ -607,12 +606,12 @@ tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) { 
scf.if %cond { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> %cst_ = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: tt.call // CHECK-NEXT: gpu.barrier tt.call @convert_layout1(%A) : (!tt.ptr) -> () - %cst1 = triton_gpu.local_alloc %cst_ : (tensor<16x32xf16, #AL>) -> !tt.memdesc<16x32xf16, #A_SHARED> + %cst1 = triton_gpu.local_alloc %cst_ : (tensor<16x32xf16, #AL>) -> !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory> } else { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> // CHECK: tt.call @@ -625,7 +624,7 @@ tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: for_calls tt.func @for_calls(%A : !tt.ptr, %cond : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> %lb = arith.constant 0 : index %ub = arith.constant 10 : index @@ -641,7 +640,7 @@ tt.func @for_calls(%A : !tt.ptr, %cond : i1) { // CHECK-LABEL: call_graph_1 tt.func @call_graph_1(%A : !tt.ptr, %cond : i1) { %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> // CHECK: gpu.barrier + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> // CHECK: gpu.barrier // CHECK-NEXT: tt.call tt.call @convert_layout3(%cond) : (i1) -> () tt.return @@ -653,7 +652,7 @@ tt.func @call_graph_2(%A : !tt.ptr, %cond : i1) { tt.call @convert_layout4(%A, %cond) : (!tt.ptr, i1) -> () // CHECK: tt.call // CHECK-NEXT: gpu.barrier - %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED> + %cst0 = triton_gpu.local_alloc %cst : (tensor<16x16xf16, #AL>) -> !tt.memdesc<16x16xf16, #A_SHARED, #triton_gpu.shared_memory> tt.return } diff --git a/test/Conversion/amd/decompose-unsupported-conversions.mlir b/test/Conversion/amd/decompose-unsupported-conversions.mlir index b5fc5b72ce47..a6d39bdfb4b1 100644 --- a/test/Conversion/amd/decompose-unsupported-conversions.mlir +++ b/test/Conversion/amd/decompose-unsupported-conversions.mlir @@ -7,7 +7,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { tt.func @wmma_to_wmma_dot_op(%arg0: tensor<16x16xf16, #mma>) { // CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<16x16xf16, #[[WMMA]]> -> tensor<16x16xf16, #[[BLOCKED]]> - // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<16x16xf16, #[[SHARED]]> + // CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<16x16xf16, #[[SHARED]], #triton_gpu.shared_memory> // CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA]], kWidth = 16}>> %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> 
tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> tt.return diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 838c884b2a26..9c207f7a95e3 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -445,7 +445,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.mlir.addressof @global_smem // CHECK-NEXT: llvm.getelementptr // CHECK-NEXT: llvm.mlir.constant - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #shared0> + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> tt.return } } @@ -475,8 +475,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-NEXT: llvm.getelementptr %index = arith.constant 1 : i32 %zero = arith.constant 0 : i32 - %0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x16x32xf32, #shared0> - %1 = triton_gpu.memdesc_subview %0[%index, %zero, %zero] : !tt.memdesc<128x16x32xf32, #shared0> -> !tt.memdesc<16x32xf32, #shared0> + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<128x16x32xf32, #shared0, #triton_gpu.shared_memory> + %1 = triton_gpu.memdesc_subview %0[%index, %zero, %zero] : !tt.memdesc<128x16x32xf32, #shared0, #triton_gpu.shared_memory> -> !tt.memdesc<16x32xf32, #shared0, #triton_gpu.shared_memory> tt.return } } @@ -506,7 +506,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : %24 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #slice1d0> %59 = tt.addptr %58, %24 : tensor<64x!tt.ptr, #slice1d0>, tensor<64xi32, #slice1d0> %66 = tt.addptr %59, %cst_2 : tensor<64x!tt.ptr, #slice1d0>, tensor<64xi32, #slice1d0> - %71 = triton_gpu.local_alloc : () -> !tt.memdesc<2x64xi64, #shared> + %71 = triton_gpu.local_alloc : () -> !tt.memdesc<2x64xi64, #shared, #triton_gpu.shared_memory> // CHECK: llvm.inline_asm has_side_effects asm_dialect = att // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 // CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 @@ -517,7 +517,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 // CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8 // CHECK: cp.async.commit_group - %73 = triton_gpu.async_copy_global_to_local %66, %71 : tensor<64x!tt.ptr, #slice1d0> -> !tt.memdesc<2x64xi64, #shared> + %73 = triton_gpu.async_copy_global_to_local %66, %71 : tensor<64x!tt.ptr, #slice1d0> -> !tt.memdesc<2x64xi64, #shared, #triton_gpu.shared_memory> triton_gpu.async_commit_group %73 tt.return } @@ -550,7 +550,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x64xi32, #AL> %a_init = tt.splat %arg0 : !tt.ptr -> tensor<16x64x!tt.ptr, #AL> %a_ptr = tt.addptr %a_init, %off : tensor<16x64x!tt.ptr, #AL>, tensor<16x64xi32, #AL> - %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x64xf32, #A> + %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x64xf32, #A, #triton_gpu.shared_memory> %index = arith.constant 1 : i32 // CHECK: llvm.inline_asm has_side_effects asm_dialect = att @@ -559,7 +559,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-SAME: cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, 0x10 // CHECK: 
llvm.inline_asm has_side_effects asm_dialect = att // CHECK-SAME: cp.async.commit_group - %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<16x64x!tt.ptr, #AL> -> !tt.memdesc<16x64xf32, #A> + %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<16x64x!tt.ptr, #AL> -> !tt.memdesc<16x64xf32, #A, #triton_gpu.shared_memory> triton_gpu.async_commit_group tt.return } @@ -592,7 +592,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<16x32xi32, #AL> %a_init = tt.splat %arg0 : !tt.ptr -> tensor<16x32x!tt.ptr, #AL> %a_ptr = tt.addptr %a_init, %off : tensor<16x32x!tt.ptr, #AL>, tensor<16x32xi32, #AL> - %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf32, #A> + %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf32, #A, #triton_gpu.shared_memory> %index = arith.constant 1 : i32 // CHECK: llvm.inline_asm @@ -605,7 +605,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm // CHECK-SAME: cp.async.commit_group - %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<16x32x!tt.ptr, #AL> -> !tt.memdesc<16x32xf32, #A> + %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<16x32x!tt.ptr, #AL> -> !tt.memdesc<16x32xf32, #A, #triton_gpu.shared_memory> triton_gpu.async_commit_group tt.return } @@ -637,7 +637,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %off = arith.addi %broadcast_off0, %broadcast_off1 : tensor<32x32xi32, #AL> %a_init = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #AL> %a_ptr = tt.addptr %a_init, %off : tensor<32x32x!tt.ptr, #AL>, tensor<32x32xi32, #AL> - %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<32x32xf32, #A> + %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<32x32xf32, #A, #triton_gpu.shared_memory> %index = arith.constant 1 : i32 // CHECK: llvm.mlir.constant(0 : i32) : i32 @@ -662,7 +662,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4 // CHECK: llvm.inline_asm // CHECK-SAME: cp.async.commit_group - %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<32x32x!tt.ptr, #AL> -> !tt.memdesc<32x32xf32, #A> + %a = triton_gpu.async_copy_global_to_local %a_ptr, %tensor : tensor<32x32x!tt.ptr, #AL> -> !tt.memdesc<32x32xf32, #A, #triton_gpu.shared_memory> triton_gpu.async_commit_group tt.return } @@ -806,14 +806,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: convert_dot tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) { - %AA = triton_gpu.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !tt.memdesc<16x16xf16, #shared0> - %BB = triton_gpu.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !tt.memdesc<16x16xf16, #shared0> + %AA = triton_gpu.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> + %BB = triton_gpu.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> // CHECK: llvm.inline_asm // CHECK: ldmatrix.sync.aligned.m8n8.x4 // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4 - 
%AA_DOT = triton_gpu.local_load %AA : !tt.memdesc<16x16xf16, #shared0> -> tensor<16x16xf16, #dot_operand_a> - %BB_DOT = triton_gpu.local_load %BB : !tt.memdesc<16x16xf16, #shared0> -> tensor<16x16xf16, #dot_operand_b> + %AA_DOT = triton_gpu.local_load %AA : !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> -> tensor<16x16xf16, #dot_operand_a> + %BB_DOT = triton_gpu.local_load %BB : !tt.memdesc<16x16xf16, #shared0, #triton_gpu.shared_memory> -> tensor<16x16xf16, #dot_operand_b> %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0> // CHECK: llvm.inline_asm @@ -906,7 +906,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK-SAME: !llvm.ptr<3> // CHECK: llvm.store // CHECK-SAME: !llvm.ptr<3> - %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0> + %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory> tt.return } } @@ -963,11 +963,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=2}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, - %a:!tt.memdesc<128x32xf16, #shared>, %b:!tt.memdesc<32x256xf16, #shared>) { + %a:!tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory>, %b:!tt.memdesc<32x256xf16, #shared, #triton_gpu.shared_memory>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 - %a_mat = triton_gpu.local_load %a : !tt.memdesc<128x32xf16, #shared> -> tensor<128x32xf16, #dot_operand_a> - %b_mat = triton_gpu.local_load %b : !tt.memdesc<32x256xf16, #shared> -> tensor<32x256xf16, #dot_operand_b> + %a_mat = triton_gpu.local_load %a : !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x32xf16, #dot_operand_a> + %b_mat = triton_gpu.local_load %b : !tt.memdesc<32x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<32x256xf16, #dot_operand_b> %28 = tt.dot %a_mat, %b_mat, %cst : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma> %38 = triton_gpu.convert_layout %28 : tensor<128x256xf32, #mma> -> tensor<128x256xf32, #blocked> @@ -989,11 +989,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, - %a:!tt.memdesc<32x64xf16, #shared0>, %b:!tt.memdesc<64x64xf16, #shared1>) { + %a:!tt.memdesc<32x64xf16, #shared0, #triton_gpu.shared_memory>, %b:!tt.memdesc<64x64xf16, #shared1, #triton_gpu.shared_memory>) { %cst = arith.constant dense<0.000000e+00> : tensor<32x64xf32, #mma> // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 - %a_mat = triton_gpu.local_load %a : !tt.memdesc<32x64xf16, #shared0> -> tensor<32x64xf16, #dot_operand_a> - %b_mat = triton_gpu.local_load %b : !tt.memdesc<64x64xf16, #shared1> -> tensor<64x64xf16, #dot_operand_b> + %a_mat = triton_gpu.local_load %a : !tt.memdesc<32x64xf16, #shared0, #triton_gpu.shared_memory> -> tensor<32x64xf16, #dot_operand_a> + %b_mat = triton_gpu.local_load %b : !tt.memdesc<64x64xf16, #shared1, #triton_gpu.shared_memory> -> 
tensor<64x64xf16, #dot_operand_b> %28 = tt.dot %a_mat, %b_mat, %cst : tensor<32x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<32x64xf32, #mma> %38 = triton_gpu.convert_layout %28 : tensor<32x64xf32, #mma> -> tensor<32x64xf32, #blocked> @@ -1012,11 +1012,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { tt.func @matmul_fmadot(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, - %a:!tt.memdesc<32x16xf32, #shared>, %b:!tt.memdesc<16x32xf32, #shared>) { + %a:!tt.memdesc<32x16xf32, #shared, #triton_gpu.shared_memory>, %b:!tt.memdesc<16x32xf32, #shared, #triton_gpu.shared_memory>) { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked> // CHECK: llvm.intr.fmuladd - %a_mat = triton_gpu.local_load %a : !tt.memdesc<32x16xf32, #shared> -> tensor<32x16xf32, #dot_operand_a> - %b_mat = triton_gpu.local_load %b : !tt.memdesc<16x32xf32, #shared> -> tensor<16x32xf32, #dot_operand_b> + %a_mat = triton_gpu.local_load %a : !tt.memdesc<32x16xf32, #shared, #triton_gpu.shared_memory> -> tensor<32x16xf32, #dot_operand_a> + %b_mat = triton_gpu.local_load %b : !tt.memdesc<16x32xf32, #shared, #triton_gpu.shared_memory> -> tensor<16x32xf32, #dot_operand_b> %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = ieee : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #blocked> %30 = tt.splat %ptr : !tt.ptr -> tensor<32x1x!tt.ptr, #blocked> @@ -1036,7 +1036,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: matmul_tf32dot tt.func @matmul_tf32dot(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, - %a:!tt.memdesc<32x16xf32, #shared>, %b:!tt.memdesc<16x32xf32, #shared>) { + %a:!tt.memdesc<32x16xf32, #shared, #triton_gpu.shared_memory>, %b:!tt.memdesc<16x32xf32, #shared, #triton_gpu.shared_memory>) { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16 @@ -1044,8 +1044,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.inline_asm // CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16 // CHECK-SAME: (i32, i32, i32, i32) - %a_mat = triton_gpu.local_load %a : !tt.memdesc<32x16xf32, #shared> -> tensor<32x16xf32, #dot_operand_a> - %b_mat = triton_gpu.local_load %b : !tt.memdesc<16x32xf32, #shared> -> tensor<16x32xf32, #dot_operand_b> + %a_mat = triton_gpu.local_load %a : !tt.memdesc<32x16xf32, #shared, #triton_gpu.shared_memory> -> tensor<32x16xf32, #dot_operand_a> + %b_mat = triton_gpu.local_load %b : !tt.memdesc<16x32xf32, #shared, #triton_gpu.shared_memory> -> tensor<16x32xf32, #dot_operand_b> // CHECK: llvm.inline_asm // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 @@ -1240,8 +1240,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK-LABEL: test_base_index_cache tt.func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) { // CHECK: nvvm.read.ptx.sreg.tid.x - %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0> - %1 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0> + %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, 
#blocked0>) -> !tt.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory> + %1 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory> tt.return } } @@ -1253,10 +1253,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK-LABEL: test_index_cache_different_block tt.func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) { // CHECK: nvvm.read.ptx.sreg.tid.x - %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0> + %0 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory> cf.cond_br %arg1, ^bb1, ^bb2 ^bb1: // pred: ^bb0 - %1 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0> + %1 = triton_gpu.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !tt.memdesc<128x32xf32, #shared0, #triton_gpu.shared_memory> cf.br ^bb2 ^bb2: // 2 preds: ^bb0, ^bb1 tt.return @@ -1529,8 +1529,8 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // CHECK: llvm.load // CHECK-SAME: {alignment = 8 : i64} : !llvm.ptr<3> -> vector<8xi8> // CHECK-NOT: llvm.load - tt.func public @vectorize_shmem_load(%shmem : !tt.memdesc<16x16xi8, #shared>) { - %0 = triton_gpu.local_load %shmem : !tt.memdesc<16x16xi8, #shared> -> tensor<16x16xi8, #blocked> + tt.func public @vectorize_shmem_load(%shmem : !tt.memdesc<16x16xi8, #shared, #triton_gpu.shared_memory>) { + %0 = triton_gpu.local_load %shmem : !tt.memdesc<16x16xi8, #shared, #triton_gpu.shared_memory> -> tensor<16x16xi8, #blocked> tt.return } } @@ -1545,7 +1545,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // CHECK-SAME: {alignment = 64 : i64} : vector<16xi32>, !llvm.ptr<3> // CHECK-NOT: llvm.store tt.func public @vectorize_shmem_store(%block : tensor<64x64xi32, #blocked>) { - %0 = triton_gpu.local_alloc %block : (tensor<64x64xi32, #blocked>) -> !tt.memdesc<64x64xi32, #shared> + %0 = triton_gpu.local_alloc %block : (tensor<64x64xi32, #blocked>) -> !tt.memdesc<64x64xi32, #shared, #triton_gpu.shared_memory> tt.return } } @@ -1570,9 +1570,9 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : // CHECK: llvm.extractelement {{.*}} : vector<8xbf16> tt.func public @test_local_load_bf16() { %c0_i32 = arith.constant 0 : i32 - %19 = triton_gpu.local_alloc : () -> !tt.memdesc<1x1x2048xbf16, #shared, mutable> - %22 = triton_gpu.memdesc_subview %19[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x1x2048xbf16, #shared, mutable> -> !tt.memdesc<1x2048xbf16, #shared, mutable> - %39 = triton_gpu.local_load %22 : !tt.memdesc<1x2048xbf16, #shared, mutable> -> tensor<1x2048xbf16, #blocked> + %19 = triton_gpu.local_alloc : () -> !tt.memdesc<1x1x2048xbf16, #shared, #triton_gpu.shared_memory, mutable> + %22 = triton_gpu.memdesc_subview %19[%c0_i32, %c0_i32, %c0_i32] : !tt.memdesc<1x1x2048xbf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<1x2048xbf16, #shared, #triton_gpu.shared_memory, mutable> + %39 = triton_gpu.local_load %22 : !tt.memdesc<1x2048xbf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<1x2048xbf16, #blocked> %40 = arith.extf %39 : tensor<1x2048xbf16, #blocked> to tensor<1x2048xf32, #blocked> tt.return } @@ -1586,8 +1586,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.store tt.func public 
@test_local_store(%arg0: tensor<1xf32, #blocked>) { %c0_i32 = arith.constant 0 : i32 - %0 = triton_gpu.local_alloc {allocation.offset = 0 : i32} : () -> !tt.memdesc<1xf32, #shared, mutable> - triton_gpu.local_store %arg0, %0 : tensor<1xf32, #blocked> -> !tt.memdesc<1xf32, #shared, mutable> + %0 = triton_gpu.local_alloc {allocation.offset = 0 : i32} : () -> !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %arg0, %0 : tensor<1xf32, #blocked> -> !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> tt.return } } @@ -1600,9 +1600,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: llvm.store tt.func public @test_local_store_subview(%arg0: tensor<1xf32, #blocked>) { %c0_i32 = arith.constant 0 : i32 - %0 = triton_gpu.local_alloc {allocation.offset = 0 : i32} : () -> !tt.memdesc<1xf32, #shared, mutable> - %sv = triton_gpu.memdesc_subview %0[%c0_i32] : !tt.memdesc<1xf32, #shared, mutable> -> !tt.memdesc<1xf32, #shared, mutable> - triton_gpu.local_store %arg0, %sv : tensor<1xf32, #blocked> -> !tt.memdesc<1xf32, #shared, mutable> + %0 = triton_gpu.local_alloc {allocation.offset = 0 : i32} : () -> !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> + %sv = triton_gpu.memdesc_subview %0[%c0_i32] : !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> + triton_gpu.local_store %arg0, %sv : tensor<1xf32, #blocked> -> !tt.memdesc<1xf32, #shared, #triton_gpu.shared_memory, mutable> tt.return } } diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index 5c9cac004c20..4a6af2d905e4 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -18,11 +18,11 @@ // CHECK: %[[BBUFFER:.*]] = triton_gpu.local_alloc // CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]] -// CHECK-DAG: %[[ASUB:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] : !tt.memdesc<2x128x32xf16, #shared, mutable> -> !tt.memdesc<128x32xf16, #shared, mutable> -// CHECK: %[[T_A0:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[ASUB]] mask %[[LOOP_COND_0_SPLAT_A]] : tensor<128x32x!tt.ptr, #blocked1> -> <128x32xf16, #shared, mutable> +// CHECK-DAG: %[[ASUB:.*]] = triton_gpu.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] : !tt.memdesc<2x128x32xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x32xf16, #shared, #triton_gpu.shared_memory, mutable> +// CHECK: %[[T_A0:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[ASUB]] mask %[[LOOP_COND_0_SPLAT_A]] : tensor<128x32x!tt.ptr, #blocked1> -> <128x32xf16, #shared, #triton_gpu.shared_memory, mutable> // CHECK-DAG: %[[LOOP_COND_0_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_0]] // CHECK-DAG: %[[BSUB:.*]] = triton_gpu.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: %[[T_B0:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[BSUB]] mask %[[LOOP_COND_0_SPLAT_B]] other %{{.*}} : tensor<32x128x!tt.ptr, #blocked> -> <32x128xf16, #shared1, mutable> +// CHECK: %[[T_B0:.*]] = triton_gpu.async_copy_global_to_local %{{.*}}, %[[BSUB]] mask %[[LOOP_COND_0_SPLAT_B]] other %{{.*}} : tensor<32x128x!tt.ptr, #blocked> -> <32x128xf16, #shared1, #triton_gpu.shared_memory, mutable> // CHECK-DAG: %[[IV_1:.*]] = arith.addi 
%[[LB]], %[[STEP:.*]] // CHECK-DAG: %[[LOOP_COND_1:.*]] = arith.cmpi slt, %[[IV_1]], %[[UB]] // CHECK-DAG: %[[LOOP_COND_1_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_1]] @@ -332,8 +332,8 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // %a = tt.load %a_tileptr : !tt.ptr, 1> // %b = tt.load %b_tileptr : !tt.ptr, 1> // -// %sa = triton_gpu.local_alloc %a : (tensor<128x32xf16, #BA>) -> !tt.memdesc<128x32xf16, #SA> -// %sb = triton_gpu.local_alloc %b : (tensor<32x128xf16, #BB>) -> !tt.memdesc<32x128xf16, #SB> +// %sa = triton_gpu.local_alloc %a : (tensor<128x32xf16, #BA>) -> !tt.memdesc<128x32xf16, #SA, #triton_gpu.shared_memory> +// %sb = triton_gpu.local_alloc %b : (tensor<32x128xf16, #BB>) -> !tt.memdesc<32x128xf16, #SB, #triton_gpu.shared_memory> // %c = triton_gpu_nvidia.dot_async %sa, %sb, %prev_c : tensor<128x32xf16, #SA> * tensor<32x128xf16, #SB> -> tensor<128x128xf32, #C> // // %a_tileptr_next = tt.advance %a_tileptr, [%c0, %c32_i32] : !tt.ptr, 1> @@ -392,13 +392,13 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // CHECK: scf.yield %17:2 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_3, %arg5 = %16) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>) : i32 { %18 = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> - %21 = tt.dot %19, %20, %cst_2 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> + %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %21 = tt.dot %19, %20, %cst_2 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> %22 = arith.truncf %21 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> - %23 = tt.trans %20 {order=array} : !tt.memdesc<64x16xf16, #shared1> -> !tt.memdesc<16x64xf16, #shared> + %23 = tt.trans %20 {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> %24 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma1> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> - %25 = tt.dot %24, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared> -> tensor<128x64xf32, #mma> + %25 = tt.dot %24, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %25, %26 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked> } @@ -444,9 +444,9 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %16, %arg6 = %8) -> (tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>) : i32 { %9 = tt.load %arg6 : tensor<128x64x!tt.ptr, #blocked1> %18 = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %19 = 
triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> - %acc = tt.dot %19, %20, %arg4 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> + %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %acc = tt.dot %19, %20, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> %cnd = arith.cmpi slt, %arg3, %ext : i32 %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1> @@ -501,8 +501,8 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> + %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> // CHECK: %[[ALLOC1:.+]] = triton_gpu.local_alloc // CHECK: %[[ALLOC2:.+]] = triton_gpu.local_alloc // CHECK: %[[R:.+]]:{{.+}} = scf.for @@ -514,11 +514,11 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // CHECK: scf.yield // CHECK: %{{.*}}:2 = triton_nvidia_gpu.dot_wait %[[R]]#{{.+}}, %[[R]]#{{.+}} {pendings = 0 : i32} : tensor<128x16xf32, #{{.*}}>, tensor<128x64xf32, #{{.*}}> %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_3, %arg5 = %16, %arg6 = %cst_2) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>) : i32 { - %21 = tt.dot %19, %20, %arg6 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> + %21 = tt.dot %19, %20, %arg6 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> %l = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> - %23 = tt.trans %c {order=array} : !tt.memdesc<64x16xf16, #shared1> -> !tt.memdesc<16x64xf16, #shared> - %25 = tt.dot %cst_4, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared> -> tensor<128x64xf32, #mma> + %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %23 = tt.trans %c {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> + %25 = tt.dot %cst_4, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> 
tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %25, %26, %21 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1> } @@ -576,13 +576,13 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %22:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst, %arg5 = %12, %arg6 = %21) -> (tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr, #blocked>, tensor<64x256x!tt.ptr, #blocked1>) : i32 { %35 = tt.load %arg5 : tensor<128x64x!tt.ptr, #blocked> %36 = tt.load %arg6 : tensor<64x256x!tt.ptr, #blocked1> - %37 = triton_gpu.local_alloc %35 : (tensor<128x64xf8E5M2, #blocked>) -> !tt.memdesc<128x64xf8E5M2, #shared> - %38 = triton_gpu.local_alloc %36 : (tensor<64x256xf8E5M2, #blocked1>) -> !tt.memdesc<64x256xf8E5M2, #shared1> + %37 = triton_gpu.local_alloc %35 : (tensor<128x64xf8E5M2, #blocked>) -> !tt.memdesc<128x64xf8E5M2, #shared, #triton_gpu.shared_memory> + %38 = triton_gpu.local_alloc %36 : (tensor<64x256xf8E5M2, #blocked1>) -> !tt.memdesc<64x256xf8E5M2, #shared1, #triton_gpu.shared_memory> // CHECK: triton_gpu.local_alloc // CHECK: scf.for // CHECK: triton_nvidia_gpu.dot_async // CHECK-NEXT: triton_nvidia_gpu.dot_wait - %39 = tt.dot %37, %38, %arg4 {maxNumImpreciseAcc = 1073741824 : i32} : !tt.memdesc<128x64xf8E5M2, #shared> * !tt.memdesc<64x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma> + %39 = tt.dot %37, %38, %arg4 {maxNumImpreciseAcc = 1073741824 : i32} : !tt.memdesc<128x64xf8E5M2, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x256xf8E5M2, #shared1, #triton_gpu.shared_memory> -> tensor<128x256xf32, #mma> %40 = tt.addptr %arg5, %cst_6 : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> %41 = tt.addptr %arg6, %cst_5 : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> scf.yield %39, %40, %41 : tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr, #blocked>, tensor<64x256x!tt.ptr, #blocked1> @@ -656,8 +656,8 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> - %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared> - %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> + %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> // CHECK: %[[LOOP:[^ :]+]]{{.*}} scf.for {{.*}} iter_args(%[[PREV_DOT2:[^ ]+]] // CHECK-NOT: triton_nvidia_gpu.dot_wait // CHECK: %[[DOT0:.+]] = triton_nvidia_gpu.dot_async @@ -674,17 +674,17 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // CHECK: triton_nvidia_gpu.dot_wait %[[LOOP]]#3, %[[LOOP]]#0 {pendings = 0 : i32} %17:4 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%prev_dot2 = %cst_3, %arg5 = %16, %prev_dot1 = %cst_2, %prev_dot0 = %cst_2) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1>) : i32 { // This one can be async. 
- %dot0 = tt.dot %19, %20, %prev_dot1 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> + %dot0 = tt.dot %19, %20, %prev_dot1 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> // This can't be async because its result is modified before it's yielded. - %dot1 = tt.dot %19, %20, %prev_dot1 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1> + %dot1 = tt.dot %19, %20, %prev_dot1 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> %dot1.1 = arith.addf %dot1, %dot1 : tensor<128x16xf32, #mma1> %l = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> - %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1> - %23 = tt.trans %c {order=array} : !tt.memdesc<64x16xf16, #shared1> -> !tt.memdesc<16x64xf16, #shared> + %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> + %23 = tt.trans %c {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> // This dot can be async even though %prev_dot2 is not used directly by an // async dot, because that use follows the synchronous dot above. %prev_dot2.1 = arith.addf %prev_dot2, %prev_dot2 : tensor<128x64xf32, #mma> - %dot2 = tt.dot %cst_4, %23, %prev_dot2.1 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared> -> tensor<128x64xf32, #mma> + %dot2 = tt.dot %cst_4, %23, %prev_dot2.1 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %dot2, %26, %dot1.1, %dot0 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1> } diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 20d8093b03ae..060566767859 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -626,9 +626,9 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> %23 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared> - %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared> -> !tt.memdesc<16x64xf16, #shared1> - %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> + %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared1, 
#triton_gpu.shared_memory> + %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> } @@ -680,9 +680,9 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %21 = tt.dot %19, %20, %cst_1 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> %23 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared> - %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared> -> !tt.memdesc<16x64xf16, #shared1> - %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> + %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> + %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> } @@ -861,9 +861,9 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %63 = scf.for %arg6 = %c0_i32 to %c64_i32 step %c32_i32 iter_args(%arg7 = %cst) -> (tensor<64x32xf32, #mma>) : i32 { %70 = tt.load %59 : tensor<32x64x!tt.ptr, #blocked1> %71 = triton_gpu.convert_layout %62 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> - %72 = triton_gpu.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !tt.memdesc<32x64xf32, #shared> - %73 = tt.trans %72 {order=array} : !tt.memdesc<32x64xf32, #shared> -> !tt.memdesc<64x32xf32, #shared1> - %74 = triton_gpu.local_load %73 : !tt.memdesc<64x32xf32, #shared1> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> + %72 = triton_gpu.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory> + %73 = tt.trans %72 {order=array} : !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory> + %74 = triton_gpu.local_load %73 : !tt.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> %75 = tt.dot %71, %74, %cst : tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, 
parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma>
       %76 = tt.load %61 : tensor<32x32x!tt.ptr, #blocked1>
       %77 = triton_gpu.convert_layout %75 : tensor<64x32xf32, #mma> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
@@ -888,7 +888,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 :
 // CHECK: scf.for
 // CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}}
 // CHECK: triton_gpu.async_copy_global_to_local %[[NEXT_BUFFER_1]]
-// CHECK: %[[IND_BUFFER_0:.*]] = triton_gpu.memdesc_subview {{.*}} : !tt.memdesc<1x16xi64, #[[$SHARED_LAYOUT]], mutable> -> !tt.memdesc<16xi64, #[[$SHARED_LAYOUT]], mutable>
+// CHECK: %[[IND_BUFFER_0:.*]] = triton_gpu.memdesc_subview {{.*}} : !tt.memdesc<1x16xi64, #[[$SHARED_LAYOUT]], #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16xi64, #[[$SHARED_LAYOUT]], #triton_gpu.shared_memory, mutable>
 // CHECK: %[[IND_BUFFER_1:.*]] = triton_gpu.local_load %[[IND_BUFFER_0]]
 // CHECK: %[[IND_BUFFER_2:.*]] = tt.expand_dims %[[IND_BUFFER_1]] {axis = 1 : i32}
 // CHECK: %[[IND_BUFFER_3:.*]] = tt.broadcast %[[IND_BUFFER_2]]
@@ -1090,9 +1090,9 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 :
     %9 = tt.addptr %7, %8 : tensor<16x16x!tt.ptr, #blocked>, tensor<16x16xi32, #blocked>
     scf.for %arg1 = %c0_i32 to %c2_i32 step %c1_i32 : i32 {
       %10 = tt.load %9 : tensor<16x16x!tt.ptr, #blocked>
-      %11 = triton_gpu.local_alloc %10 : (tensor<16x16xf32, #blocked>) -> !tt.memdesc<16x16xf32, #shared>
-      %12 = tt.trans %11 {order = array} : !tt.memdesc<16x16xf32, #shared> -> !tt.memdesc<16x16xf32, #shared1>
-      %13 = triton_gpu.local_load %12 : !tt.memdesc<16x16xf32, #shared1> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>
+      %11 = triton_gpu.local_alloc %10 : (tensor<16x16xf32, #blocked>) -> !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory>
+      %12 = tt.trans %11 {order = array} : !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x16xf32, #shared1, #triton_gpu.shared_memory>
+      %13 = triton_gpu.local_load %12 : !tt.memdesc<16x16xf32, #shared1, #triton_gpu.shared_memory> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>>
       scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 : i32 {
         %14 = tt.load %9 : tensor<16x16x!tt.ptr, #blocked>
         %15 = triton_gpu.convert_layout %14 : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
@@ -1435,9 +1435,9 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 :
         scf.yield %arg5 : tensor<64x16x!tt.ptr, #blocked>
       }
       %18 = tt.load %inc_ptr : tensor<64x16x!tt.ptr, #blocked>
-      %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared>
-      %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1>
-      %acc = tt.dot %19, %20, %arg4 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1>
+      %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>
+      %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>
+      %acc = tt.dot %19, %20, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1>
       %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) {
         %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1>
         scf.yield %acc_zero : tensor<128x16xf32, #mma1>
@@ -1500,9 +1500,9 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 :
     %17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %16, %arg6 = %8) -> (tensor<128x16xf32, #mma1>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x64x!tt.ptr, #blocked1>) : i32 {
       %9 = tt.load %arg6 : tensor<128x64x!tt.ptr, #blocked1>
       %18 = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked>
-      %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared>
-      %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1>
-      %acc = tt.dot %19, %20, %arg4 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1>
+      %19 = triton_gpu.local_alloc %9 : (tensor<128x64xf16, #blocked1>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>
+      %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>
+      %acc = tt.dot %19, %20, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1>
       %cnd = arith.cmpi slt, %arg3, %ext : i32
       %if_ret:2 = scf.if %cnd -> (tensor<128x16xf32, #mma1>, tensor<64x16xi32, #blocked>) {
         %acc_zero = arith.mulf %acc, %cst_2 : tensor<128x16xf32, #mma1>
@@ -1568,9 +1568,9 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
 #shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
 // CHECK-LABEL: @matmul_tma
-// CHECK-DAG: triton_gpu.local_alloc : () -> !tt.memdesc<3x128x64xf16, #{{.+}}, mutable>
-// CHECK-DAG: triton_gpu.local_alloc : () -> !tt.memdesc<3x64x256xf16, #{{.+}}, mutable>
-// CHECK-DAG: triton_gpu.local_alloc : () -> !tt.memdesc<3xi64, #{{.+}}, mutable>
+// CHECK-DAG: triton_gpu.local_alloc : () -> !tt.memdesc<3x128x64xf16, #{{.+}}, #triton_gpu.shared_memory, mutable>
+// CHECK-DAG: triton_gpu.local_alloc : () -> !tt.memdesc<3x64x256xf16, #{{.+}}, #triton_gpu.shared_memory, mutable>
+// CHECK-DAG: triton_gpu.local_alloc : () -> !tt.memdesc<3xi64, #{{.+}}, #triton_gpu.shared_memory, mutable>
 // CHECK-COUNT-3: triton_nvidia_gpu.init_barrier
 // CHECK-COUNT-4: triton_nvidia_gpu.async_tma_copy_global_to_local
 // CHECK: scf.for
@@ -1586,10 +1586,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
     %0:2 = scf.for %arg3 = %c0_i32 to %c256_i32 step %c1_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32) -> (tensor<128x256xf32, #mma>, i32) : i32 {
       %1 = tt.experimental_descriptor_load %arg0[%c0_i32, %arg5] : !tt.ptr -> tensor<128x64xf16, #blocked>
-      %2 = triton_gpu.local_alloc %1 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared>
+      %2 = triton_gpu.local_alloc %1 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>
       %3 = tt.experimental_descriptor_load %arg1[%arg5, %c0_i32] : !tt.ptr -> tensor<64x256xf16, #blocked1>
-      %4 = triton_gpu.local_alloc %3 : (tensor<64x256xf16, #blocked1>) -> !tt.memdesc<64x256xf16, #shared>
-      %5 = tt.dot %2, %4, %arg4, inputPrecision = tf32 : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x256xf16, #shared> -> tensor<128x256xf32, #mma>
+      %4 = triton_gpu.local_alloc %3 : (tensor<64x256xf16, #blocked1>) -> !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory>
+      %5 = tt.dot %2, %4, %arg4, inputPrecision = tf32 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x256xf32, #mma>
       %6 = arith.addi %arg5, %c64_i32 : i32
       scf.yield %5, %6 : tensor<128x256xf32, #mma>, i32
     }
diff --git a/test/TritonGPU/reduce-data-duplication.mlir b/test/TritonGPU/reduce-data-duplication.mlir
index 7dd91df04187..e98a6108d758 100644
--- a/test/TritonGPU/reduce-data-duplication.mlir
+++ b/test/TritonGPU/reduce-data-duplication.mlir
@@ -2,7 +2,7 @@
 // CHECK: #[[SHARED:.*]] = #triton_gpu.shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1], hasLeadingOffset = false}
 // CHECK: apply_swizzle
-// CHECK: %{{.*}} = triton_gpu.local_alloc %{{.*}} : (tensor<16x256xf16, #{{.*}}>) -> !tt.memdesc<16x256xf16, #[[SHARED]]>
+// CHECK: %{{.*}} = triton_gpu.local_alloc %{{.*}} : (tensor<16x256xf16, #{{.*}}>) -> !tt.memdesc<16x256xf16, #[[SHARED]], #triton_gpu.shared_memory>
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}>
 #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}>
diff --git a/test/TritonNvidiaGPU/membar.mlir b/test/TritonNvidiaGPU/membar.mlir
index 95202f12ee7e..358f53fd7cd6 100644
--- a/test/TritonNvidiaGPU/membar.mlir
+++ b/test/TritonNvidiaGPU/membar.mlir
@@ -9,8 +9,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 // CHECK-NEXT: init_barrier
   tt.func @init_barrier() {
     %cst = arith.constant dense<0> : tensor<1xi64, #blocked0>
-    %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, mutable>
-    triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, mutable>
+    %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable>
+    triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable>
     tt.return
   }
 }
@@ -28,9 +28,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 // CHECK-NEXT: inval_barrier
   tt.func @inval_barrier() {
     %cst = arith.constant dense<0> : tensor<1xi64, #blocked0>
-    %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, mutable>
-    triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, mutable>
-    triton_nvidia_gpu.inval_barrier %alloc : !tt.memdesc<1xi64, #shared0, mutable>
+    %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable>
+    triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable>
+    triton_nvidia_gpu.inval_barrier %alloc : !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable>
     tt.return
   }
 }
@@ -48,9 +48,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 // CHECK-NEXT: barrier_expect
   tt.func @barrier_expect(%pred : i1) {
     %cst = arith.constant dense<0> : tensor<1xi64, #blocked0>
-    %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, mutable>
-    triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, mutable>
-    triton_nvidia_gpu.barrier_expect %alloc, 16384, %pred : <1xi64, #shared0, mutable>
+    %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable>
+    triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable>
+    triton_nvidia_gpu.barrier_expect %alloc, 16384, %pred : <1xi64, #shared0, #triton_gpu.shared_memory, mutable>
     tt.return
   }
 }
@@ -68,9 +68,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 // CHECK-NEXT: wait_barrier
   tt.func @wait_barrier(%phase : i32) {
     %cst = arith.constant dense<0> : tensor<1xi64, #blocked0>
-    %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, mutable>
-    triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, mutable>
-    triton_nvidia_gpu.wait_barrier %alloc, %phase : <1xi64, #shared0, mutable>
+    %alloc = triton_gpu.local_alloc %cst : (tensor<1xi64, #blocked0>) -> !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable>
+    triton_nvidia_gpu.init_barrier %alloc, 1 : !tt.memdesc<1xi64, #shared0, #triton_gpu.shared_memory, mutable>
+    triton_nvidia_gpu.wait_barrier %alloc, %phase : <1xi64, #shared0, #triton_gpu.shared_memory, mutable>
     tt.return
  }
 }
@@ -89,8 +89,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 // CHECK-NEXT: gpu.barrier
 // CHECK-NEXT: init_barrier
     %cst = arith.constant dense<0> : tensor<128x64xi64, #blocked0>
-    %alloc = triton_gpu.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !tt.memdesc<128x64xi64, #shared0, mutable>
-    triton_gpu.local_dealloc %alloc : !tt.memdesc<128x64xi64, #shared0, mutable>
+    %alloc = triton_gpu.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !tt.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable>
+    triton_gpu.local_dealloc %alloc : !tt.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable>
     %l = tt.experimental_descriptor_load %arg0[%arg1, %arg1] : !tt.ptr -> tensor<128x64xf16, #blocked0>
     tt.return %l : tensor<128x64xf16, #blocked0>
   }
@@ -108,8 +108,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 // CHECK-NEXT: triton_gpu.local_alloc
   tt.func public @tma_store(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: tensor<128x256xf32, #blocked0>) {
     %cst = arith.constant dense<0> : tensor<128x64xi64, #blocked0>
-    %alloc = triton_gpu.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !tt.memdesc<128x64xi64, #shared0, mutable>
-    triton_gpu.local_dealloc %alloc : !tt.memdesc<128x64xi64, #shared0, mutable>
+    %alloc = triton_gpu.local_alloc %cst : (tensor<128x64xi64, #blocked0>) -> !tt.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable>
+    triton_gpu.local_dealloc %alloc : !tt.memdesc<128x64xi64, #shared0, #triton_gpu.shared_memory, mutable>
     tt.experimental_descriptor_store %arg0[%arg1, %arg1], %arg2 : !tt.ptr, tensor<128x256xf32, #blocked0>
     tt.return
   }
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
index 6f9ed6a23b4e..24d7aad85b9a 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
@@ -403,8 +403,9 @@ void LoopPipeliner::createBufferTypes() {
     auto sharedEnc = ttg::SharedEncodingAttr::get(
         ty.getContext(), dotOpEnc, ty.getShape(),
         ttg::getOrder(ty.getEncoding()), CTALayout, eType);
-    loadsBufferType[loadOp] =
-        triton::MemDescType::get(bufferShape, eType, sharedEnc);
+    loadsBufferType[loadOp] = triton::MemDescType::get(
+        bufferShape, eType, sharedEnc,
+        triton::gpu::SharedMemorySpaceAttr::get(ty.getContext()));
   }
 }
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
index 68cab4bdb3f0..d1086c189d33 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
@@ -738,7 +738,8 @@ MemDescType getExpandedDesc(MemDescType descTy) {
   expandedShape[2] = shape[1];
   auto encoding = descTy.getEncoding();
   auto expandedEncoding = getExpandedEncoding(encoding);
-  auto expandedDesc = MemDescType::get(expandedShape, elTy, expandedEncoding);
+  auto expandedDesc = MemDescType::get(expandedShape, elTy, expandedEncoding,
+                                       descTy.getMemorySpace());
   return expandedDesc;
 }
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp
index 4407a50bd0d9..4b8c7c1b37d1 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp
@@ -34,8 +34,8 @@ class DecomposeLocalLoadToDotOperand
     auto dstDotOp = dyn_cast(
         op.getType().getEncoding());
-    auto sharedEncoding =
-        cast(op.getSrc().getType().getEncoding());
+    MemDescType srcType = op.getSrc().getType();
+    auto sharedEncoding = cast(srcType.getEncoding());
     if (!dstDotOp || !sharedEncoding.getHasLeadingOffset())
       return failure();
     RankedTensorType type = op.getType();
@@ -55,7 +55,8 @@ class DecomposeLocalLoadToDotOperand
         triton::gpu::SharedEncodingAttr::get(
             op.getContext(), dstDotOp, type.getShape(),
             triton::gpu::getOrder(parentEnc),
-            triton::gpu::getCTALayout(parentEnc), type.getElementType()));
+            triton::gpu::getCTALayout(parentEnc), type.getElementType()),
+        srcType.getMemorySpace());
     auto tmp = rewriter.create(
         op.getLoc(), newSharedDescTy, load);
     auto newConvert =