Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Benoit Jacob <[email protected]>
  • Loading branch information
bjacob committed Nov 15, 2024
1 parent 4440c91 commit 33739c2
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 19 deletions.
24 changes: 12 additions & 12 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,34 +299,34 @@ OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
MMASingleSubgroupLayout getASingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
return mmaAttr.getASingleSubgroupLayout();
} else if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
}
if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
return vmmaAttr.getASingleSubgroupLayout();
} else {
assert(false && "unhandled MMA Interface type.");
return {};
}
assert(false && "unhandled MMA Interface type.");
return {};
}

MMASingleSubgroupLayout getBSingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
return mmaAttr.getBSingleSubgroupLayout();
} else if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
}
if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
return vmmaAttr.getBSingleSubgroupLayout();
} else {
assert(false && "unhandled MMA Interface type.");
return {};
}
assert(false && "unhandled MMA Interface type.");
return {};
}

MMASingleSubgroupLayout getCSingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
return mmaAttr.getCSingleSubgroupLayout();
} else if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
}
if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
return vmmaAttr.getCSingleSubgroupLayout();
} else {
assert(false && "unhandled MMA Interface type.");
return {};
}
assert(false && "unhandled MMA Interface type.");
return {};
}

//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@

namespace mlir::iree_compiler::IREE::GPU {

// Struct describing the detailed subgroup-level layout of a MMA operation.
// Struct describing the detailed subgroup-level layout of a MMA intrinsic.
// Together with element type information and subgroup size, it completes the
// full description of the semantics of a MMA operation.
// full description of the semantics of a MMA intrinsic.
//
// Note: It is not possible to infer subgroup size from the information in this
// struct. The product of the `thread` sizes here is often, but not always equal
Expand All @@ -33,7 +33,7 @@ namespace mlir::iree_compiler::IREE::GPU {
// semantics in that case are that threads within the subgroup whose thread-ids
// differ by a multiple of `P`, are accessing the same elements.
//
// Example observed in RDNA3 WMMA Wave64 ops:
// Example observed in RDNA3 WMMA Wave64 intrinsics:
// If the subgroup size is 64 but the product `P` of `thread` sizes is 32, that
// means that each element is being accessed by 2 threads (2 = 64/32), and the
// threads accessing the same element are those whose tids are exactly 32 apart.
Expand All @@ -44,7 +44,7 @@ struct MMASingleSubgroupLayout {
// are NOT contiguous.
// This is not used by every MMA op; ops which don't use that simply have 1's.
SmallVector<int64_t, 2> outer;
// Cross-thread dimensions (as in TileSwizzle::Dim::Kind::CrossThread).
// Cross-thread dimensions (as in TileSwizzle::Dim::Kind::CrossThread).
// This is the kind of dimension that is present in all GPU MMA ops, by
// definition of "SIMT". It is still possible for one of the `thread` dims to
// be 1, but not both.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,7 @@ LogicalResult materializeOperandConcreteShape(
// Inner tile must have sizes matching the opaque layout.
auto operandType = llvm::cast<RankedTensorType>(operand.getType());
ArrayRef<int64_t> operandShape = operandType.getShape();
SmallVector<int64_t, 2> innerShape(operandShape.end() - opaqueSizes.size(),
operandShape.end());
if (!llvm::equal(opaqueSizes, innerShape)) {
if (opaqueSizes != operandShape.take_back(opaqueSizes.size())) {
return failure();
}

Expand Down

0 comments on commit 33739c2

Please sign in to comment.