Skip to content

Commit

Permalink
onnx.MaxPool add atenMaxPool1d lowering support (#3452)
Browse files Browse the repository at this point in the history
fixes #3422
  • Loading branch information
PhaneeshB authored Jun 13, 2024
1 parent 39d882f commit 919b599
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
14 changes: 10 additions & 4 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -565,15 +565,15 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
Value cstCeilMode =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode);

if (rank == 3)
return rewriter.notifyMatchFailure(binder.op,
"Unimplemented: AtenMaxPool1dOp");

if (binder.op->getNumResults() == 2) {
Torch::ValueTensorType resultTypeIndices;
if (binder.tensorResultTypeAtIndex(resultTypeIndices, 1))
return failure();

if (rank == 3)
return rewriter.notifyMatchFailure(
binder.op, "Unimplemented: AtenMaxPool1dWithIndicesOp");

if (rank == 4) {
rewriter.replaceOpWithNewOp<Torch::AtenMaxPool2dWithIndicesOp>(
binder.op, resultTypeOut, resultTypeIndices, operand,
Expand All @@ -589,6 +589,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
return success();
}
} else {
if (rank == 3) {
rewriter.replaceOpWithNewOp<Torch::AtenMaxPool1dOp>(
binder.op, resultTypeOut, operand, kernelSizeList, stridesList,
paddingList, dilationsList, cstCeilMode);
return success();
}
if (rank == 4) {
rewriter.replaceOpWithNewOp<Torch::AtenMaxPool2dOp>(
binder.op, resultTypeOut, operand, kernelSizeList, stridesList,
Expand Down
3 changes: 0 additions & 3 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2418,10 +2418,7 @@
"LogSoftmaxBackwardModule_basic",
"MaskedScatterStaticBasic_basic",
"MaxPool1dCeilModeTrueModule_basic",
"MaxPool1dEmptyStrideStaticModule_basic",
"MaxPool1dModule_basic",
"MaxPool1dStaticCeilModeTrueModule_basic",
"MaxPool1dStaticModule_basic",
"MaxPool2dCeilModeTrueModule_basic",
"MaxPool2dModule_basic",
"MaxPool2dWithIndicesAllOnesModule_basic",
Expand Down

0 comments on commit 919b599

Please sign in to comment.