diff --git a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp index 1b801f5b3fc9..37c2eccbe09d 100644 --- a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp +++ b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp @@ -365,9 +365,10 @@ buildNextBody(AccumulateOp op, OpBuilder &builder, Value initialState, [&](OpBuilder &builder, Location loc) { ImplicitLocOpBuilder b(loc, builder); - // Don't modify state; return undef element. - Value nextElement = b.create(elementType); - b.create(ValueRange{initialUpstreamState, nextElement}); + // Don't modify state; return init element. + FuncOp initFunc = op.getInitFunc(); + Value initValue = b.create(initFunc)->getResult(0); + b.create(ValueRange{initialUpstreamState, initValue}); }, /*elseBuilder=*/ [&](OpBuilder &builder, Location loc) { diff --git a/experimental/iterators/test/Conversion/IteratorsToLLVM/accumulate.mlir b/experimental/iterators/test/Conversion/IteratorsToLLVM/accumulate.mlir index 2a511deb186a..5ad3c2bf8916 100644 --- a/experimental/iterators/test/Conversion/IteratorsToLLVM/accumulate.mlir +++ b/experimental/iterators/test/Conversion/IteratorsToLLVM/accumulate.mlir @@ -22,7 +22,7 @@ func.func private @sum_struct(%lhs : !element_type, %rhs : !element_type) -> !el // CHECK-NEXT: %[[V1:.*]] = iterators.extractvalue %[[arg0:.*]][0] : !iterators.state, i1> // CHECK-NEXT: %[[V2:.*]] = iterators.extractvalue %[[arg0]][1] : !iterators.state, i1> // CHECK-NEXT: %[[V3:.*]]:2 = scf.if %[[V2]] -> (!iterators.state, !llvm.struct<(i32)>) { -// CHECK-NEXT: %[[V4:.*]] = llvm.mlir.undef : !llvm.struct<(i32)> +// CHECK-NEXT: %[[V4:.*]] = func.call @zero_struct() : () -> !llvm.struct<(i32)> // CHECK-NEXT: scf.yield %[[V1]], %[[V4]] : !iterators.state, !llvm.struct<(i32)> // CHECK-NEXT: } else { // CHECK-NEXT: %[[V4:.*]] = func.call @zero_struct() : () -> !llvm.struct<(i32)> diff --git a/experimental/iterators/test/Integration/Dialect/Iterators/CPU/accumulate.mlir b/experimental/iterators/test/Integration/Dialect/Iterators/CPU/accumulate.mlir index 539a8fe3d692..975e36352c4a 100644 --- a/experimental/iterators/test/Integration/Dialect/Iterators/CPU/accumulate.mlir +++ b/experimental/iterators/test/Integration/Dialect/Iterators/CPU/accumulate.mlir @@ -1,13 +1,16 @@ // RUN: mlir-proto-opt %s \ // RUN: -convert-iterators-to-llvm \ // RUN: -convert-states-to-llvm \ -// RUN: -convert-func-to-llvm \ -// RUN: -convert-scf-to-cf -convert-cf-to-llvm \ +// RUN: -convert-scf-to-cf \ +// RUN: -arith-bufferize -func-bufferize -tensor-bufferize \ +// RUN: -convert-func-to-llvm -convert-memref-to-llvm \ +// RUN: -cse -reconcile-unrealized-casts \ // RUN: | mlir-cpu-runner -e main -entry-point-result=void \ // RUN: | FileCheck %s !struct_i32 = !llvm.struct<(i32)> !struct_i32i32 = !llvm.struct<(i32, i32)> +!struct_i32i32i32i32 = !llvm.struct<(i32, i32, i32, i32)> !struct_f32 = !llvm.struct<(f32)> func.func private @init_sum_struct() -> !struct_i32 { @@ -84,8 +87,62 @@ func.func @test_accumulate_avg_struct() { return } +func.func private @unpack_i32(%input : !struct_i32) -> i32 { + %i = llvm.extractvalue %input[0 : index] : !struct_i32 + return %i : i32 +} + +func.func private @init_histogram() -> tensor<4xi32> { + %init = arith.constant dense<[0, 0, 0, 0]> : tensor<4xi32> + return %init : tensor<4xi32> +} + +func.func private @accumulate_histogram( + %hist : tensor<4xi32>, %val : i32) -> tensor<4xi32> { + %idx = arith.index_cast %val : i32 to index + %oldCount = tensor.extract %hist[%idx] : tensor<4xi32> + %one = arith.constant 1 : i32 + %newCount = arith.addi %oldCount, %one : i32 + %newHist = tensor.insert %newCount into %hist[%idx] : tensor<4xi32> + return %newHist : tensor<4xi32> +} + +func.func private @tensor_to_struct(%input : tensor<4xi32>) -> !struct_i32i32i32i32 { + %idx0 = arith.constant 0 : index + %idx1 = arith.constant 1 : index + %idx2 = arith.constant 2 : index + %idx3 = arith.constant 3 : index + %i0 = tensor.extract %input[%idx0] : tensor<4xi32> + %i1 = tensor.extract %input[%idx1] : tensor<4xi32> + %i2 = tensor.extract %input[%idx2] : tensor<4xi32> + %i3 = tensor.extract %input[%idx3] : tensor<4xi32> + %structu = llvm.mlir.undef : !struct_i32i32i32i32 + %struct0 = llvm.insertvalue %i0, %structu[0 : index] : !struct_i32i32i32i32 + %struct1 = llvm.insertvalue %i1, %struct0[1 : index] : !struct_i32i32i32i32 + %struct2 = llvm.insertvalue %i2, %struct1[2 : index] : !struct_i32i32i32i32 + %struct3 = llvm.insertvalue %i3, %struct2[3 : index] : !struct_i32i32i32i32 + return %struct3 : !struct_i32i32i32i32 +} + +func.func @test_accumulate_histogram() { + %input = "iterators.constantstream"() + { value = [[0 : i32], [1 : i32], [1 : i32], [2 : i32]] } + : () -> (!iterators.stream) + %unpacked = "iterators.map"(%input) {mapFuncRef = @unpack_i32} + : (!iterators.stream) -> (!iterators.stream) + %accumulated = iterators.accumulate(%unpacked, @init_histogram, + @accumulate_histogram) + : (!iterators.stream) -> !iterators.stream> + %transposed = "iterators.map"(%accumulated) {mapFuncRef = @tensor_to_struct} + : (!iterators.stream>) -> (!iterators.stream) + "iterators.sink"(%transposed) : (!iterators.stream) -> () + // CHECK: (1, 2, 1, 0) + return +} + func.func @main() { call @test_accumulate_sum_struct() : () -> () call @test_accumulate_avg_struct() : () -> () + call @test_accumulate_histogram() : () -> () return }