[BACKEND] Add memory space to memdesc type. (#4027)
Currently only shared memory is supported, but this will allow supporting
different kinds of local memory (such as private memory) or other memory spaces.
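
As a rough illustration of what the change enables, a pass could now build a shared-memory descriptor like this (a minimal sketch, not code from this commit; `ctx`, `f16Ty`, and `encoding` are assumed to be in scope):

  // Sketch: constructing a memdesc in the new, explicit shared memory space.
  Attribute memorySpace = triton::gpu::SharedMemorySpaceAttr::get(ctx);
  MemDescType ty = MemDescType::get(/*shape=*/{64, 32}, f16Ty, encoding,
                                    memorySpace, /*mutableMemory=*/true);
  // Printed, such a type reads roughly:
  //   !tt.memdesc<64x32xf16, #shared, #triton_gpu.shared_memory, mutable>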
ThomasRaoux authored May 29, 2024
1 parent 445d5ed commit d527c3f
Showing 30 changed files with 594 additions and 526 deletions.
11 changes: 7 additions & 4 deletions include/triton/Dialect/Triton/IR/TritonTypes.td
@@ -106,12 +106,13 @@ def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]>
     ArrayRefParameter<"int64_t">:$shape,
     "Type":$elementType,
     "Attribute":$encoding,
+    "Attribute":$memorySpace,
     "bool":$mutable_memory
   );
   let extraClassDeclaration = [{
     MemDescType cloneWith(std::optional<ArrayRef<int64_t>> shape,
                           Type elementType) const {
-      return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding());
+      return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding(), getMemorySpace(), getMutableMemory());
     }

     bool hasRank() const { return true; }
@@ -120,17 +121,19 @@ def TT_MemDescType : TritonTypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]>
     TypeBuilderWithInferredContext<(ins
       "llvm::ArrayRef<int64_t>":$shape,
       "Type":$elementType,
-      "Attribute":$encoding
+      "Attribute":$encoding,
+      "Attribute":$memorySpace
     ), [{
-      return $_get(elementType.getContext(), shape, elementType, encoding, /*mutableMemory=*/false);
+      return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, /*mutableMemory=*/false);
     }]>,
     TypeBuilderWithInferredContext<(ins
       "llvm::ArrayRef<int64_t>":$shape,
       "Type":$elementType,
       "Attribute":$encoding,
+      "Attribute":$memorySpace,
       "bool":$mutableMemory
     ), [{
-      return $_get(elementType.getContext(), shape, elementType, encoding, mutableMemory);
+      return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory);
     }]>
   ];
   let hasCustomAssemblyFormat = 1;
6 changes: 6 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -1298,4 +1298,10 @@ elements along the K dim, or they use all elements of the tensor along the K dim
   }];
 }
 
+def TTG_SharedMemorySpace : AttrDef<TritonGPU_Dialect, "SharedMemorySpace"> {
+  let mnemonic = "shared_memory";
+  let description = [{
+    Attribute to indicate that the memory descriptor points to shared memory.
+  }];
+}
 #endif
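
Client code can test this attribute on a memdesc's memory space; a sketch mirroring the guard added to lib/Analysis/Alias.cpp below (`memDescTy` is an assumed triton::MemDescType):

  // Sketch: a null memory space means unspecified, so only an explicit
  // shared_memory attribute counts as shared.
  bool isShared = isa_and_nonnull<triton::gpu::SharedMemorySpaceAttr>(
      memDescTy.getMemorySpace());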
6 changes: 6 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -144,6 +144,12 @@ def TTG_LocalAllocOp : TTG_Op<"local_alloc", [DeclareOpInterfaceMethods<MemoryEf
   }];
   let arguments = (ins Optional<TT_Tensor>:$src);
 
+  let extraClassDeclaration = [{
+    bool isSharedMemoryAlloc() {
+      return getType().getMemorySpace() &&
+             isa<SharedMemorySpaceAttr>(getType().getMemorySpace());
+    }
+  }];
   let assemblyFormat = [{$src attr-dict `:` functional-type(operands, results)}];
 
   let results = (outs TT_MemDescType:$result);
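
The helper is meant to let patterns that only understand shared memory bail out early, as the MemoryOpToLLVM.cpp and Allocation.cpp changes below do. A hedged sketch of that shape (not code from this commit):

  // Sketch only: restrict a lowering to shared-memory allocations, leaving
  // future spaces (e.g. private) to their own patterns.
  LogicalResult lowerSharedAlloc(triton::gpu::LocalAllocOp op) {
    if (!op.isSharedMemoryAlloc())
      return failure();
    // ... shared-memory-specific lowering would go here ...
    return success();
  }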
7 changes: 6 additions & 1 deletion lib/Analysis/Alias.cpp
@@ -26,8 +26,13 @@ void SharedMemoryAliasAnalysis::visitOperation(
     ArrayRef<dataflow::Lattice<AliasInfo> *> results) {
   AliasInfo aliasInfo;
   bool pessimistic = true;
-  // These ops may allocate a new shared memory buffer.
   auto result = op->getResult(0);
+  // skip ops that return memdesc in a different memory space.
+  if (auto memdescTy = dyn_cast<triton::MemDescType>(result.getType())) {
+    if (!isa_and_nonnull<triton::gpu::SharedMemorySpaceAttr>(
+            memdescTy.getMemorySpace()))
+      return;
+  }
 
   // Only LocalAllocOp creates a new buffer.
   if (isa<triton::gpu::LocalAllocOp>(op)) {
3 changes: 2 additions & 1 deletion lib/Analysis/Allocation.cpp
@@ -210,7 +210,8 @@ class AllocationAnalysis {
     // XXX(Keren): Why this hard-coded alignment?
     size_t kAlignment = 8;
     for (Value result : op->getResults()) {
-      if (auto alloc = result.getDefiningOp<triton::gpu::LocalAllocOp>()) {
+      auto alloc = result.getDefiningOp<triton::gpu::LocalAllocOp>();
+      if (alloc && alloc.isSharedMemoryAlloc()) {
         // Bytes could be a different value once we support padding or other
         // allocation policies.
         auto allocType = alloc.getType();
@@ -95,12 +95,15 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
     auto dstDotOp =
         dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
     if (srcBlocked && dstDotOp) {
+      Attribute sharedMemorySpace =
+          triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext());
       auto tmpType = MemDescType::get(
           dstType.getShape(), dstType.getElementType(),
           triton::gpu::SharedEncodingAttr::get(
               module.getContext(), dstDotOp, srcType.getShape(),
               srcBlocked.getOrder(), srcBlocked.getCTALayout(),
-              srcType.getElementType()));
+              srcType.getElementType()),
+          sharedMemorySpace);
       auto tmp = builder.create<triton::gpu::LocalAllocOp>(
           cvtOp.getLoc(), tmpType, cvtOp.getSrc());
       addAttrs(tmp, cvtOp->getAttrs());
2 changes: 2 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -51,6 +51,8 @@ struct LocalAllocOpConversion
   LogicalResult
   matchAndRewrite(triton::gpu::LocalAllocOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (!op.isSharedMemoryAlloc())
+      return failure();
     Location loc = op->getLoc();
     Value smemBase =
         LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation());
6 changes: 3 additions & 3 deletions lib/Dialect/Triton/IR/Ops.cpp
@@ -224,9 +224,9 @@ LogicalResult TransOp::inferReturnTypes(
       return failure();
     }
   }
-  if (isa<MemDescType>(argTy)) {
-    inferredReturnTypes.push_back(
-        MemDescType::get(retShape, retEltTy, retEncoding));
+  if (auto memDescTy = dyn_cast<MemDescType>(argTy)) {
+    inferredReturnTypes.push_back(MemDescType::get(
+        retShape, retEltTy, retEncoding, memDescTy.getMemorySpace()));
   } else {
     inferredReturnTypes.push_back(
         RankedTensorType::get(retShape, retEltTy, retEncoding));
14 changes: 12 additions & 2 deletions lib/Dialect/Triton/IR/Types.cpp
@@ -70,16 +70,24 @@ Type MemDescType::parse(AsmParser &parser) {
       return Type();
   }
   bool mutableMemory = false;
+  Attribute memorySpace;
   if (succeeded(parser.parseOptionalComma())) {
-    if (parser.parseOptionalKeyword(kMutableMemory))
+    if (failed(parser.parseOptionalKeyword(kMutableMemory))) {
+      if (parser.parseAttribute(memorySpace))
+        return Type();
+    } else {
+      mutableMemory = true;
+    }
+  }
+  if (mutableMemory == false && succeeded(parser.parseOptionalComma())) {
+    if (parser.parseOptionalKeyword(kMutableMemory))
       return Type();
     mutableMemory = true;
   }
   if (parser.parseGreater())
     return Type();
 
   return MemDescType::get(parser.getContext(), dimensions, elementType,
-                          encoding, mutableMemory);
+                          encoding, memorySpace, mutableMemory);
 }
 
@@ -89,6 +97,8 @@ void MemDescType::print(AsmPrinter &printer) const {
   printer << getElementType();
   if (getEncoding())
     printer << ", " << getEncoding();
+  if (getMemorySpace())
+    printer << ", " << getMemorySpace();
   if (getMutableMemory())
     printer << ", " << kMutableMemory;
   printer << ">";
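
With the parser and printer changes combined, the memory space is an optional attribute between the encoding and the mutable keyword; all of the following forms should round-trip (an illustrative list assuming a #shared encoding alias, not a test from this commit):

  //   !tt.memdesc<32x32xf16, #shared>                                      // no space, immutable
  //   !tt.memdesc<32x32xf16, #shared, mutable>                             // no space, mutable
  //   !tt.memdesc<32x32xf16, #shared, #triton_gpu.shared_memory>           // shared, immutable
  //   !tt.memdesc<32x32xf16, #shared, #triton_gpu.shared_memory, mutable>  // shared, mutable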
7 changes: 5 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -208,12 +208,15 @@ class BlockedToMMA : public mlir::RewritePattern {
       }
     }
 
+    Attribute SharedMemorySpace =
+        SharedMemorySpaceAttr::get(argType.getContext());
     auto CTALayout = getCTALayout(argType.getEncoding());
     auto newLayout =
         SharedEncodingAttr::get(argType.getContext(), argType.getShape(),
                                 newOrder, CTALayout, argType.getElementType());
-    auto newType = MemDescType::get(argType.getShape(),
-                                    argType.getElementType(), newLayout);
+    auto newType =
+        MemDescType::get(argType.getShape(), argType.getElementType(),
+                         newLayout, SharedMemorySpace);
     rewriter.setInsertionPointAfterValue(arg);
     return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
   }
7 changes: 4 additions & 3 deletions lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -59,12 +59,12 @@ class SwizzleShmemConvert : public OpRewritePattern<ConvertLayoutOp> {
                              srcTy.getElementType(), /*needTrans=*/true);
     if (newInnerCvtEnc == cvtEncoding)
       return failure();
-
     rewriter.setInsertionPoint(trans);
+    auto sharedMemorySpace = SharedMemorySpaceAttr::get(getContext());
     auto alloc = rewriter.create<LocalAllocOp>(
         trans.getLoc(),
         MemDescType::get(srcTy.getShape(), srcTy.getElementType(),
-                         newInnerCvtEnc),
+                         newInnerCvtEnc, sharedMemorySpace),
         trans.getSrc());
     auto newTrans = rewriter.create<TransOp>(trans.getLoc(), alloc,
                                              ArrayRef<int32_t>({1, 0}));
@@ -254,7 +254,8 @@ class FuseTransHopper : public OpRewritePattern<LocalAllocOp> {
                             allocEncoding.getCTALayout(), srcTy.getElementType());
 
     MemDescType innerTy =
-        MemDescType::get(srcTy.getShape(), srcTy.getElementType(), newInnerEnc);
+        MemDescType::get(srcTy.getShape(), srcTy.getElementType(), newInnerEnc,
+                         allocType.getMemorySpace());
     auto newAlloc = rewriter.create<LocalAllocOp>(allocOp.getLoc(), innerTy,
                                                   trans.getSrc());
     rewriter.replaceOpWithNewOp<TransOp>(allocOp, newAlloc,
32 changes: 24 additions & 8 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
@@ -245,9 +245,11 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
   tt::MemDescType allocTy = cast<tt::MemDescType>(alloc.getType());
   SmallVector<Value> copyOffsets(allocTy.getRank(), zero);
   copyOffsets[0] = insertIdx;
+  Attribute sharedMemorySpace =
+      triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext());
   tt::MemDescType subviewTy = tt::MemDescType::get(
       allocTy.getShape().drop_front(), allocTy.getElementType(),
-      allocTy.getEncoding(), /*mutableMemory=*/true);
+      allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true);
   auto view =
       builder.create<ttg::MemDescSubviewOp>(loc, subviewTy, alloc, copyOffsets);
   Operation *copy = builder.create<ttg::AsyncCopyGlobalToLocalOp>(
@@ -316,6 +318,8 @@ static void createTMAAsyncCopy(
     llvm::MapVector<Operation *, LoadInfo> &loadToInfo, int numStages) {
   assert(phase && "Phase value is required for TMA async copy.");
   OpBuilder builder(forOp);
+  Attribute sharedMemorySpace =
+      triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext());
   Value zero = builder.create<arith::ConstantIntOp>(forOp.getLoc(), 0, 32);
   builder.setInsertionPoint(loadOp);
   Location loc = loadOp.getLoc();
@@ -324,7 +328,7 @@ static void createTMAAsyncCopy(
   copyOffsets[0] = insertIdx;
   tt::MemDescType subviewTy = tt::MemDescType::get(
       allocTy.getShape().drop_front(), allocTy.getElementType(),
-      allocTy.getEncoding(), /*mutableMemory=*/true);
+      allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true);
   auto view =
       builder.create<ttg::MemDescSubviewOp>(loc, subviewTy, alloc, copyOffsets);
 
@@ -906,11 +910,14 @@ static void scheduleRemainingToLastStage(scf::ForOp forOp,
 static Value createAlloc(scf::ForOp &forOp, Operation *loadOp,
                          ttg::SharedEncodingAttr sharedEnc, unsigned distance) {
   OpBuilder builder(forOp);
+  Attribute sharedMemorySpace =
+      triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext());
   auto ty = cast<RankedTensorType>(loadOp->getResultTypes()[0]);
   SmallVector<int64_t> bufferShape(ty.getShape().begin(), ty.getShape().end());
   bufferShape.insert(bufferShape.begin(), distance);
   Type memdescType = mlir::triton::MemDescType::get(
-      bufferShape, ty.getElementType(), sharedEnc, /*mutableMemory*/ true);
+      bufferShape, ty.getElementType(), sharedEnc, sharedMemorySpace,
+      /*mutableMemory*/ true);
   Value alloc = builder.create<mlir::triton::gpu::LocalAllocOp>(
       loadOp->getLoc(), memdescType, Value());
   return alloc;
@@ -919,18 +926,21 @@ static Value createAlloc(scf::ForOp &forOp, Operation *loadOp,
 // Create an allocation to hold the mbarriers.
 static Value createBarrierAlloc(scf::ForOp &forOp, unsigned distance) {
   OpBuilder builder(forOp);
+  Attribute sharedMemorySpace =
+      triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext());
   Location loc = forOp.getLoc();
   auto context = forOp.getContext();
   auto barrierCTALayout =
       ttg::CTALayoutAttr::get(context, /*CTAsPerCGA=*/{1},
                               /*CTASplitNum=*/{1}, /*CTAOrder=*/{0});
   auto barrierEncoding =
       ttg::SharedEncodingAttr::get(context, 1, 1, 1, {0}, barrierCTALayout);
-  Type barrierMemDescType =
-      tt::MemDescType::get({distance}, builder.getI64Type(), barrierEncoding,
-                           /*mutableMemory=*/true);
-  Type singleBarrierMemDescType = tt::MemDescType::get(
-      {1}, builder.getI64Type(), barrierEncoding, /*mutableMemory=*/true);
+  Type barrierMemDescType = tt::MemDescType::get(
+      {distance}, builder.getI64Type(), barrierEncoding, sharedMemorySpace,
+      /*mutableMemory=*/true);
+  Type singleBarrierMemDescType =
+      tt::MemDescType::get({1}, builder.getI64Type(), barrierEncoding,
+                           sharedMemorySpace, /*mutableMemory=*/true);
   Value barrierAlloc = builder.create<mlir::triton::gpu::LocalAllocOp>(
       loc, barrierMemDescType, Value());
   for (unsigned i = 0; i < distance; i++) {
@@ -1026,9 +1036,12 @@ static void createTMABarrierAndWait(
     barriers.push_back(barrierAlloc);
     Location loc = forOp.getLoc();
     OpBuilder builder(forOp);
+    Attribute sharedMemorySpace =
+        triton::gpu::SharedMemorySpaceAttr::get(builder.getContext());
    tt::MemDescType barrierTy = tt::MemDescType::get(
        {1}, builder.getI64Type(),
        cast<tt::MemDescType>(barrierAlloc.getType()).getEncoding(),
+        sharedMemorySpace,
        /*mutableMemory=*/true);
    builder.setInsertionPoint(group[0]->loadOp);
    Value barrier = builder.create<ttg::MemDescSubviewOp>(
@@ -1167,13 +1180,16 @@ createAsyncOps(scf::ForOp &forOp, CoarseSchedule &schedule,
 
 static void invalidateBarriers(OpBuilder &builder,
                                SmallVector<Value> &barriers) {
+  Attribute sharedMemorySpace =
+      triton::gpu::SharedMemorySpaceAttr::get(builder.getContext());
   for (Value barrier : barriers) {
     int numBarriers = cast<tt::MemDescType>(barrier.getType()).getShape()[0];
     for (int i = 0; i < numBarriers; i++) {
       Value idx = builder.create<arith::ConstantIntOp>(barrier.getLoc(), i, 32);
       tt::MemDescType barrierTy = tt::MemDescType::get(
           {1}, builder.getI64Type(),
           cast<tt::MemDescType>(barrier.getType()).getEncoding(),
+          sharedMemorySpace,
           /*mutableMemory=*/true);
       Value barrierView = builder.create<ttg::MemDescSubviewOp>(
           barrier.getLoc(), barrierTy, barrier, idx);
@@ -39,9 +39,11 @@ static Value createAlloc(scf::ForOp &forOp,
     encoding = ttg::SharedEncodingAttr::get(
         ty.getContext(), ty.getShape(), order, ctaLayout, ty.getElementType());
   }
-
-  Type memdescType = tt::MemDescType::get(ty.getShape(), ty.getElementType(),
-                                          encoding, /*mutableMemory*/ true);
+  Attribute sharedMemorySpace =
+      triton::gpu::SharedMemorySpaceAttr::get(ty.getContext());
+  Type memdescType =
+      tt::MemDescType::get(ty.getShape(), ty.getElementType(), encoding,
+                           sharedMemorySpace, /*mutableMemory*/ true);
   Value alloc = builder.create<ttg::LocalAllocOp>(storeOp->getLoc(),
                                                   memdescType, Value());
   return alloc;
5 changes: 3 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/Prefetch.cpp
@@ -136,8 +136,9 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
         builder.create<arith::ConstantIntOp>(v.getLoc(), off, 32));
   Value newSmem = builder.create<triton::gpu::MemDescSubviewOp>(
       v.getLoc(),
-      triton::MemDescType::get(shape, elementType, type.getEncoding()), v,
-      offsetsVal);
+      triton::MemDescType::get(shape, elementType, type.getEncoding(),
+                               type.getMemorySpace()),
+      v, offsetsVal);
 
   auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
       builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
6 changes: 4 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
@@ -70,12 +70,14 @@ class TritonGPUReduceDataDuplicationPass
       } else {
         sharedOrder = srcOrder;
       }
+      auto sharedMemorySpace =
+          triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext());
       auto tmpType = triton::MemDescType::get(
           dstType.getShape(), dstType.getElementType(),
           triton::gpu::SharedEncodingAttr::get(
               mod.getContext(), dstDotOp, srcType.getShape(), sharedOrder,
-              triton::gpu::getCTALayout(srcEncoding),
-              srcType.getElementType()));
+              triton::gpu::getCTALayout(srcEncoding), srcType.getElementType()),
+          sharedMemorySpace);
       auto tmp = builder.create<triton::gpu::LocalAllocOp>(
           cvtOp.getLoc(), tmpType, cvtOp.getSrc());
       auto newConvert = builder.create<triton::gpu::LocalLoadOp>(cvtOp.getLoc(),
13 changes: 9 additions & 4 deletions lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp
@@ -23,6 +23,8 @@ class TMALoadLowering : public OpRewritePattern<ExperimentalDescriptorLoadOp> {
 
   LogicalResult matchAndRewrite(ExperimentalDescriptorLoadOp op,
                                 PatternRewriter &rewriter) const override {
+    MLIRContext *ctx = op.getContext();
+    Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(ctx);
     auto loc = op.getLoc();
     auto tensorType = op.getResult().getType();
     auto order = getOrder(tensorType.getEncoding());
@@ -36,15 +38,16 @@ class TMALoadLowering : public OpRewritePattern<ExperimentalDescriptorLoadOp> {
     }
     MemDescType memDescType =
         MemDescType::get(tensorType.getShape(), tensorType.getElementType(),
-                         encoding, /*mutableMemory=*/true);
+                         encoding, sharedMemorySpace, /*mutableMemory=*/true);
     Value alloc = rewriter.create<LocalAllocOp>(loc, memDescType, Value());
     auto barrierCTALayout = CTALayoutAttr::get(
         /*context=*/tensorType.getContext(), /*CTAsPerCGA=*/{1},
         /*CTASplitNum=*/{1}, /*CTAOrder=*/{0});
     auto barrierEncoding = SharedEncodingAttr::get(tensorType.getContext(), 1,
                                                    1, 1, {0}, barrierCTALayout);
-    MemDescType barrierMemDescType = MemDescType::get(
-        {1}, rewriter.getI64Type(), barrierEncoding, /*mutableMemory=*/true);
+    MemDescType barrierMemDescType =
+        MemDescType::get({1}, rewriter.getI64Type(), barrierEncoding,
+                         sharedMemorySpace, /*mutableMemory=*/true);
     Value barrierAlloc =
         rewriter.create<LocalAllocOp>(loc, barrierMemDescType, Value());
     rewriter.create<InitBarrierOp>(loc, barrierAlloc, 1);
@@ -70,6 +73,8 @@ class TMAStoreLowering
 
   LogicalResult matchAndRewrite(ExperimentalDescriptorStoreOp op,
                                 PatternRewriter &rewriter) const override {
+    MLIRContext *ctx = op.getContext();
+    Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(ctx);
     auto loc = op.getLoc();
     auto tensorType = op.getSrc().getType();
     auto order = getOrder(tensorType.getEncoding());
@@ -83,7 +88,7 @@ class TMAStoreLowering
     }
     MemDescType memDescType =
         MemDescType::get(tensorType.getShape(), tensorType.getElementType(),
-                         encoding, /*mutableMemory=*/true);
+                         encoding, sharedMemorySpace, /*mutableMemory=*/true);
     Value alloc = rewriter.create<LocalAllocOp>(loc, memDescType, op.getSrc());
     rewriter.create<triton::nvidia_gpu::FenceAsyncSharedOp>(loc, false);
     rewriter.create<triton::nvidia_gpu::AsyncTMACopyLocalToGlobalOp>(