[LLVMGPU] Add Virtual MFMA layout that maximizes load through adjusted K-width (#18930)

The main use case for the virtual intrinsics is to change the layout of
intrinsics in the K-dimension, so that we can coalesce reads from shared
memory into registers.

Currently, the "native" intrinsics enforce the "native" layout (i.e. read
4 elements per thread for MFMA_F32_16x16x16). However, since the K-dim is
a reduction dimension and therefore associative, we can read the data in a
non-"native" but faster way (more elements per read): as long as we match
the K-dim on both lhs and rhs, we still get correct results (e.g. read 8
contiguous elements per thread from shared memory along dimension K, then
slice them into two MFMA_F32_16x16x16 operations).

As an IR example, if we want to do a 16x16x32 (MxNxK) matmul with
MFMA_F32_16x16x16_F16 intrinsics, on lane 0 we used to have something
like:

```
lhs_0 = read(lhs_shared_mem[0:4])
rhs_0 = read(rhs_shared_mem[0:4])
mma_0 = vector.contract(lhs_0, rhs_0)

(offset of 16 since MFMA_F32_16x16x16_F16 has an intrinsic K size of 16)
lhs_1 = read(lhs_shared_mem[16 + 0: 16 + 4])
rhs_1 = read(rhs_shared_mem[16 + 0 : 16 + 4])
mma_1 = vector.contract(lhs_1, rhs_1, mma_0)
```

With this optimization, this turns into something like:

```
lhs_reg = read(lhs_shared_mem[0:8])
rhs_reg = read(rhs_shared_mem[0:8])

lhs_0 = slice(lhs_reg, [0 : 4])
rhs_0 = slice(rhs_reg, [0 : 4])
mma_0 = vector.contract(lhs_0, rhs_0)

lhs_1 = slice(lhs_reg, [4 : 8])
rhs_1 = slice(rhs_reg, [4 : 8])
mma_1 = vector.contract(lhs_1, rhs_1, mma_0)
```
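
After vector distribution, each such pair of contractions maps onto two amdgpu.mfma ops that share a single wide register read per operand. The sketch below is illustrative only (the SSA names are invented) and shows roughly what the VMFMA_F32_16x16x32_F16 case lowers to per lane; the 32x32x16 lit test in this patch checks the analogous lowering with k = 8, m = n = 32:

```
// %lhs, %rhs : vector<8xf16> and %acc : vector<4xf32> are the per-lane
// register tiles produced by the single wide shared-memory read.
%lhs_0 = vector.extract_strided_slice %lhs {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%rhs_0 = vector.extract_strided_slice %rhs {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%mma_0 = amdgpu.mfma %lhs_0 * %rhs_0 + %acc {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
%lhs_1 = vector.extract_strided_slice %lhs {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%rhs_1 = vector.extract_strided_slice %rhs {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
%mma_1 = amdgpu.mfma %lhs_1 * %rhs_1 + %mma_0 {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
```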

Currently, we are plumbing this in as MMA intrinsic enums for two F16
variants with an unrolled K factor of 2 (per discussion with @qedawkins
and @Groverkss), as that is the easiest and least tangled way to
integrate/plumb it through, although in the future we can expose this
attribute as a k-width for maximum generality.
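
For reference, a rough sketch of how the new enum is attached to a contraction, mirroring the 32x32x16 lit test in this patch (the function name is made up and the to_layout anchors from the test are omitted for brevity):

```
#map1 = affine_map<(m, n, k) -> (m, k)>
#map2 = affine_map<(m, n, k) -> (k, n)>
#map3 = affine_map<(m, n, k) -> (m, n)>

// Illustrative only: a 16x16x32 contraction annotated with the virtual
// intrinsic, which is emitted as two MFMA_F32_16x16x16_F16 ops along K.
func.func @contract_to_vmfma_16x16x32_mm(%a : vector<16x32xf16>, %b : vector<32x16xf16>, %c : vector<16x16xf32>) -> vector<16x16xf32> {
  %out = vector.contract {
    indexing_maps = [#map1, #map2, #map3],
    iterator_types = ["parallel", "parallel", "reduction"],
    kind = #vector.kind<add>,
    iree.amdgpu.mma = #iree_gpu.mma_layout<VMFMA_F32_16x16x32_F16>
  } %a, %b, %c : vector<16x32xf16>, vector<32x16xf16> into vector<16x16xf32>
  return %out : vector<16x16xf32>
}
```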

---------

Signed-off-by: Stanley Winata <[email protected]>
raikonenfnu authored Oct 31, 2024
1 parent 20c8347 commit bb542ee
Showing 7 changed files with 305 additions and 21 deletions.
@@ -590,3 +590,96 @@ builtin.module attributes { transform.with_named_sequence } {
// CHECK: %[[B_OUT:.*]] = vector.broadcast %[[R_CAST]] : vector<8x1x1x1xf32> to vector<1x1x8x1x1x1xf32>
// CHECK: %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[B_OUT]] : vector<1x1x8x1x1x1xf32> -> vector<16x16xf32>
// CHECK: return {{.*}} %[[R_SIMD]]

// -----

// Non-native MFMA_F32_32x32x16_F16, i.e. CDNA3 V_MFMA_F32_32x32x8_F16 with unrolled_k = 2.
// This non-native layout maximizes reads from shared memory to register.

#map1 = affine_map<(m, n, k) -> (m, k)>
#map2 = affine_map<(m, n, k) -> (k, n)>
#map3 = affine_map<(m, n, k) -> (m, n)>

// A: shape = 32x16, layout = layoutA
#layout_a = #iree_vector_ext.nested_layout<
subgroup_tile = [1, 1],
batch_tile = [1, 1],
outer_tile = [1, 1],
thread_tile = [32, 2],
element_tile = [1, 8],

subgroup_strides = [1, 1],
thread_strides = [1, 32]
>

// B: shape = 16x32, layout = layoutB
#layout_b = #iree_vector_ext.nested_layout<
subgroup_tile = [1, 1],
batch_tile = [1, 1],
outer_tile = [1, 1],
thread_tile = [2, 32],
element_tile = [8, 1],

subgroup_strides = [1, 1],
thread_strides = [32, 1]
>

// C: shape = 32x32, layout = layoutC
#layout_c = #iree_vector_ext.nested_layout<
subgroup_tile = [1, 1],
batch_tile = [1, 1],
outer_tile = [4, 1],
thread_tile = [2, 32],
element_tile = [4, 1],

subgroup_strides = [1, 1],
thread_strides = [32, 1]
>

func.func @contract_to_vmfma_32x32x16_mm(%a : vector<32x16xf16>, %b : vector<16x32xf16>, %c : vector<32x32xf32>) -> vector<32x32xf32> {
%A = iree_vector_ext.to_layout %a to layout(#layout_a) : vector<32x16xf16>
%B = iree_vector_ext.to_layout %b to layout(#layout_b) : vector<16x32xf16>
%C = iree_vector_ext.to_layout %c to layout(#layout_c) : vector<32x32xf32>

%output = vector.contract {
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
iree.amdgpu.mma = #iree_gpu.mma_layout<VMFMA_F32_32x32x16_F16>
} %A, %B, %C : vector<32x16xf16>, vector<16x32xf16> into vector<32x32xf32>

%O = iree_vector_ext.to_layout %output to layout(#layout_c) : vector<32x32xf32>
return %O : vector<32x32xf32>
}

builtin.module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
transform.yield
}
}

// Notable things to look out for:
// 1. We are reading 8xf16 instead of 4xf16 for lhs,rhs operands.
// 2. We slice the 8xf16 to 2 different 4xf16 per operand for use on 2 MMAs.
// 3. Result of first mma becomes the second mma's accumulator.

// CHECK-LABEL: func @contract_to_vmfma_32x32x16_mm
// CHECK: %[[A_CAST:.+]] = vector.shape_cast %{{.+}} : vector<1x1x1x8xf16> to vector<8xf16>
// CHECK: %[[B_CAST:.+]] = vector.shape_cast %{{.+}} : vector<1x1x8x1xf16> to vector<8xf16>
// CHECK: %[[C_CAST:.+]] = vector.shape_cast %{{.+}} : vector<4x1x4x1xf32> to vector<16xf32>
// CHECK: %[[A_SLICE_0:.+]] = vector.extract_strided_slice %[[A_CAST]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[B_SLICE_0:.+]] = vector.extract_strided_slice %[[B_CAST]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[MFMA_0:.*]] = amdgpu.mfma %[[A_SLICE_0]] * %[[B_SLICE_0]] + %[[C_CAST]]
// CHECK-SAME: {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none
// CHECK-SAME: : vector<4xf16>, vector<4xf16>, vector<16xf32>
// CHECK: %[[A_SLICE_1:.+]] = vector.extract_strided_slice %[[A_CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[B_SLICE_1:.+]] = vector.extract_strided_slice %[[B_CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
// CHECK: %[[MFMA_1:.+]] = amdgpu.mfma %[[A_SLICE_1]] * %[[B_SLICE_1]] + %[[MFMA_0]]
// CHECK-SAME: {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none
// CHECK-SAME: : vector<4xf16>, vector<4xf16>, vector<16xf32>
// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[MFMA_1]] : vector<16xf32> to vector<4x1x4x1xf32>
// CHECK: %[[B_OUT:.*]] = vector.broadcast %[[R_CAST]] : vector<4x1x4x1xf32> to vector<1x1x4x1x4x1xf32>
// CHECK: %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[B_OUT]] : vector<1x1x4x1x4x1xf32> -> vector<32x32xf32>
// CHECK: return {{.*}} %[[R_SIMD]]
60 changes: 60 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -263,6 +263,14 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
return OpaqueMmaLayout{16, 16, 16, i8, i8, i32};
}
// V(Virtual)MFMA instructions which have 2 mfma instructions interleaved
// along the k dimension.
case MMAIntrinsic::VMFMA_F32_16x16x32_F16: {
return OpaqueMmaLayout{16, 16, 32, f16, f16, f32};
}
case MMAIntrinsic::VMFMA_F32_32x32x16_F16: {
return OpaqueMmaLayout{32, 32, 16, f16, f16, f32};
}
}
llvm_unreachable("unhandled mfma layout type");
return OpaqueMmaLayout{};
@@ -412,12 +420,14 @@ MMAAttr::getABCVectorTypes() const {
}
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
auto aType = VectorType::get({8}, getAType());
auto bType = VectorType::get({8}, getBType());
auto cType = VectorType::get({4}, getCType());
return std::make_tuple(aType, bType, cType);
}
case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
auto aType = VectorType::get({8}, getAType());
auto bType = VectorType::get({8}, getBType());
@@ -461,7 +471,9 @@ int64_t MMAAttr::getBlockSize() const {
case MMAIntrinsic::MFMA_I32_32x32x8_I8:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
case MMAIntrinsic::MFMA_I32_32x32x16_I8:
case MMAIntrinsic::WMMA_F16_16x16x16_F16:
case MMAIntrinsic::WMMA_F32_16x16x16_F16:
@@ -484,7 +496,9 @@ static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) {
case MMAIntrinsic::MFMA_I32_32x32x8_I8:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
return 64;
}
@@ -549,6 +563,7 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*tstrides=*/{32, 1},
/*element=*/{4, 1}};
}
case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
@@ -563,6 +578,7 @@ return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*tstrides=*/{16, 1},
/*element=*/{4, 1}};
}
case MMAIntrinsic::VMFMA_F32_32x32x16_F16:
case MMAIntrinsic::MFMA_I32_32x32x16_I8:
switch (fragment) {
case MMAFragment::Lhs:
@@ -616,6 +632,19 @@ MMASingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Acc);
}

// Get virtual intrinsics that are composed/based on the queried op.
SmallVector<MMAIntrinsic> MMAAttr::getVirtualIntrinsics() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x16_F16:
return {MMAIntrinsic::VMFMA_F32_16x16x32_F16};
case MMAIntrinsic::MFMA_F32_32x32x8_F16:
return {MMAIntrinsic::VMFMA_F32_32x32x16_F16};
default:
return {};
}
return {};
}

// Generates amdgpu.mfma/wmma operation on the given inputs for this attribute
// type.
FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
@@ -643,6 +672,37 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
rhs, acc)
.getResult();
}
case MMAIntrinsic::VMFMA_F32_16x16x32_F16:
case MMAIntrinsic::VMFMA_F32_32x32x16_F16: {
// Generate mfma's for K with unrolled kernels.
const int64_t unrollKFactor = 2;
auto [m, n, k] = getMNKShape();
// Compute actual/native intrinsic's K size.
int64_t nativeKSize = k / unrollKFactor;

auto [aType, bType, cType] = getABCVectorTypes();
if (aType.getShape()[0] != bType.getShape()[0]) {
// Currently we only support the case where lhs and rhs
// have the same vectorWidth.
return failure();
}
int64_t vectorWidth = aType.getShape()[0] / unrollKFactor;
for (int i = 0; i < unrollKFactor; i++) {
int64_t offset = vectorWidth * i;
Value sliced_lhs = builder.create<vector::ExtractStridedSliceOp>(
loc, lhs, ArrayRef<int64_t>{offset}, ArrayRef<int64_t>{vectorWidth},
ArrayRef<int64_t>{1});
Value sliced_rhs = builder.create<vector::ExtractStridedSliceOp>(
loc, rhs, ArrayRef<int64_t>{offset}, ArrayRef<int64_t>{vectorWidth},
ArrayRef<int64_t>{1});
acc = builder
.create<amdgpu::MFMAOp>(loc, resultType, m, n, nativeKSize,
getBlockSize(), sliced_lhs, sliced_rhs,
acc)
.getResult();
}
return acc;
}
case MMAIntrinsic::MFMA_I32_16x16x16_I8:
case MMAIntrinsic::MFMA_F32_16x16x16_F16:
case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
@@ -216,6 +216,8 @@ def IREEGPU_MMAAttr : IREEGPU_MmaVectorLayoutAttr<"MMA", "MMAIntrinsicAttr"> {
MMASingleSubgroupLayout getASingleSubgroupLayout() const;
MMASingleSubgroupLayout getBSingleSubgroupLayout() const;
MMASingleSubgroupLayout getCSingleSubgroupLayout() const;

SmallVector<MMAIntrinsic> getVirtualIntrinsics() const;
}];
}

@@ -98,7 +98,13 @@ class IREEGPU_I32MmaEnumAttr<string name, string summary, list<I32EnumAttrCase>
let genSpecializedAttr = 0;
}

// Format: <kind>_<output-type>_<M>x<N>x<K>_<input-type>
// Format: <virtual><kind>_<output-type>_<M>x<N>x<K>_<input-type>
//
// "virtual": Prefixes intrinsic with "V" to represent Non native-MFMA
// emulating a larger MMA with smaller ones. This is useful
// to interleave reads in K-dim, S.T we can have wider reads
// or align layouts between matmuls.
//
// Values: 0xABCD where:
// * A = vendor:
// * 0 = AMD
@@ -121,6 +127,8 @@ class IREEGPU_I32MmaEnumAttr<string name, string summary, list<I32EnumAttrCase>
def MFMA_F32_16x16x4_F32 : I32EnumAttrCase<"MFMA_F32_16x16x4_F32", 0x0900>;
def MFMA_F32_16x16x16_F16 : I32EnumAttrCase<"MFMA_F32_16x16x16_F16", 0x0910>;
def MFMA_F32_32x32x8_F16 : I32EnumAttrCase<"MFMA_F32_32x32x8_F16", 0x0911>;
def VMFMA_F32_16x16x32_F16 : I32EnumAttrCase<"VMFMA_F32_16x16x32_F16", 0x0912>;
def VMFMA_F32_32x32x16_F16 : I32EnumAttrCase<"VMFMA_F32_32x32x16_F16", 0x0913>;
def MFMA_F32_16x16x16_BF16 : I32EnumAttrCase<"MFMA_F32_16x16x16_BF16", 0x0920>;
def MFMA_F32_32x32x8_BF16 : I32EnumAttrCase<"MFMA_F32_32x32x8_BF16", 0x0921>;
def MFMA_F32_16x16x32_F8E5M2FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E5M2FNUZ", 0x0930>;
@@ -145,6 +153,8 @@ def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
MFMA_F32_16x16x4_F32,
MFMA_F32_16x16x16_F16,
MFMA_F32_32x32x8_F16,
VMFMA_F32_16x16x32_F16,
VMFMA_F32_32x32x16_F16,
MFMA_F32_16x16x16_BF16,
MFMA_F32_32x32x8_BF16,
MFMA_F32_16x16x32_F8E4M3FNUZ,