Skip to content

Commit

Permalink
Undo revert of llvm/llvm-project#104668 (#18451)
Browse files Browse the repository at this point in the history
Signed-off-by: MaheshRavishankar <[email protected]>
Co-authored-by: Matthias Springer <[email protected]>
  • Loading branch information
MaheshRavishankar and matthias-springer authored Sep 6, 2024
1 parent d55785d commit 767e288
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,6 @@ func.func @reverse_unsigned(%arg0: tensor<3x5xui32>) -> tensor<3x5xui32> {
}
// CHECK-LABEL: func.func @reverse_unsigned
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
// CHECK: %[[BITCAST:.+]] = builtin.unrealized_conversion_cast %[[IN]] : tensor<3x5xui32> to tensor<3x5xi32>
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<3x5xui32>
// CHECK: %[[GEN:.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<3x5xui32>)
// CHECK: %[[SAME_DIM:.+]] = linalg.index 0 : index
Expand Down Expand Up @@ -654,7 +653,6 @@ func.func @prefix(%arg0: tensor<7x5xi32>, %arg1: tensor<i32>) -> tensor<7x5xi32>
}) {base_dilations = array<i64: 1, 1>, padding = dense<[[0, 0], [4, 0]]> : tensor<2x2xi64>, window_dilations = array<i64: 1, 1>, window_dimensions = array<i64: 1, 5>, window_strides = array<i64: 1, 1>} : (tensor<7x5xi32>, tensor<i32>) -> tensor<7x5xi32>
return %reduce : tensor<7x5xi32>
}
// CHECK: %extracted = tensor.extract %[[ARG1]][] : tensor<i32>
// CHECK: %[[OUT0:.+]] = tensor.empty() : tensor<7x5xi32>
// CHECK: %[[OUT1:.+]] = tensor.empty() : tensor<7xi32>
// CHECK: %[[FILL:.+]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel"]} ins(%[[ARG1]] : tensor<i32>) outs(%[[OUT1]] : tensor<7xi32>)
Expand Down
66 changes: 9 additions & 57 deletions compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,29 +229,13 @@ struct GenericOpTypePropagation
signatureConverter.addInputs(index, legalizedArgType.value());
}
rewriter.applySignatureConversion(&modifiedOpRegion.front(),
signatureConverter);
signatureConverter, getTypeConverter());

// 6. Introduce scalar conversion operations to convert back to the
// original scalar type.
{
OpBuilder::InsertionGuard g(rewriter);
Block *entryBlock = modifiedOp.getBlock();
for (auto modifiedOperandIndex : modifiedOperandIndex) {
OpOperand *modifiedOpOperand =
&modifiedOp->getOpOperand(modifiedOperandIndex);
BlockArgument source =
modifiedOp.getMatchingBlockArgument(modifiedOpOperand);
Type destType = getElementTypeOrSelf(
genericOp.getOperand(modifiedOperandIndex).getType());

// 6a. If the value of the argument is used the argument is in the
// legalized type. Convert it to a value that is in the original
// element type for replacement of all uses in the block.
rewriter.setInsertionPointToStart(entryBlock);
Value replacement =
convertElementType(rewriter, source.getLoc(), destType, source);
rewriter.replaceUsesOfBlockArgument(source, replacement);
}

// 6b. If any of the operands modified were outputs, the yield values
// need to be modified as well.
Expand Down Expand Up @@ -372,27 +356,13 @@ struct IREELinalgExtScatterTypePropagation
signatureConverter.addInputs(0, legalizedArgType.value());
signatureConverter.addInputs(1, legalizedArgType.value());
rewriter.applySignatureConversion(&modifiedOpRegion.front(),
signatureConverter);
signatureConverter, getTypeConverter());

{
// Introduce scalar conversion operations to convert back to the original
// scalar type.
OpBuilder::InsertionGuard g(rewriter);
Block *entryBlock = &modifiedOp->getRegion(0).getBlocks().front();
BlockArgument inputArg = entryBlock->getArgument(0);
BlockArgument outputArg = entryBlock->getArgument(1);

auto destType = getElementTypeOrSelf(inputType);
rewriter.setInsertionPointToStart(entryBlock);

Value replacementInput =
convertElementType(rewriter, inputArg.getLoc(), destType, inputArg);
rewriter.replaceUsesOfBlockArgument(entryBlock->getArgument(0),
replacementInput);
Value replacementOutput =
convertElementType(rewriter, outputArg.getLoc(), destType, outputArg);
rewriter.replaceUsesOfBlockArgument(entryBlock->getArgument(1),
replacementOutput);

// If the output is of an illegal type, the yield value needs to be
// modified
Expand Down Expand Up @@ -449,31 +419,7 @@ struct IREELinalgExtSortTypePropagation
signatureConverter.addInputs(index, legalizedArgType.value());
}
rewriter.applySignatureConversion(&modifiedOpRegion.front(),
signatureConverter);

{
// Introduce scalar conversion operations to convert back to the original
// scalar type.
OpBuilder::InsertionGuard g(rewriter);
Block *entryBlock = &modifiedOp->getRegion(0).getBlocks().front();
for (auto [index, operand] : llvm::enumerate(sortOp->getOpOperands())) {
BlockArgument firstInputArg = entryBlock->getArgument(index * 2);
BlockArgument secondInputArg = entryBlock->getArgument(index * 2 + 1);

auto destType = getElementTypeOrSelf(operand.get().getType());
rewriter.setInsertionPointToStart(entryBlock);
if (destType != getElementTypeOrSelf(legalizedResultTypes[index])) {
Value replacementFirstInput = convertElementType(
rewriter, firstInputArg.getLoc(), destType, firstInputArg);
rewriter.replaceUsesOfBlockArgument(firstInputArg,
replacementFirstInput);
Value replacementSecondInput = convertElementType(
rewriter, secondInputArg.getLoc(), destType, secondInputArg);
rewriter.replaceUsesOfBlockArgument(secondInputArg,
replacementSecondInput);
}
}
}
signatureConverter, getTypeConverter());
rewriter.replaceOp(sortOp, modifiedOp->getResults());
return success();
}
Expand Down Expand Up @@ -580,6 +526,12 @@ struct TypePropagationPass final
RewritePatternSet patterns(context);

TypePropagationTypeConverter typeConverter;
typeConverter.addArgumentMaterialization(
[&](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
assert(inputs.size() == 1 && "expected exactly one input");
return convertElementType(builder, loc, type, inputs[0]);
});

patterns.insert<
ConstantOpTypeConversion, ForwardSourceType<arith::ExtUIOp>,
ForwardSourceType<arith::TruncIOp>, GenericOpTypePropagation,
Expand Down

0 comments on commit 767e288

Please sign in to comment.