Skip to content

Commit

Permalink
where-not optimization
Browse files Browse the repository at this point in the history
Signed-off-by: Ananya <[email protected]>
  • Loading branch information
ananyamukh6 committed Sep 3, 2023
1 parent 0e49375 commit 0880593
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 0 deletions.
2 changes: 2 additions & 0 deletions onnxoptimizer/pass_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
#include "onnxoptimizer/passes/fuse_consecutive_unsqueezes.h"
#include "onnxoptimizer/passes/eliminate_nop_with_unit.h"
#include "onnxoptimizer/passes/rewrite_input_dtype.h"
#include "onnxoptimizer/passes/rewrite_where.h"

namespace ONNX_NAMESPACE {
namespace optimization {
Expand Down Expand Up @@ -118,6 +119,7 @@ struct GlobalPassRegistry {
registerPass<EliminateDuplicateInitializer>();
registerPass<AdjustSliceAndMatmul>();
registerPass<RewriteInputDtype>();
registerPass<RewriteWhere>();
}

~GlobalPassRegistry() {
Expand Down
56 changes: 56 additions & 0 deletions onnxoptimizer/passes/rewrite_where.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.

#pragma once

#include "onnxoptimizer/pass.h"
#include "onnxoptimizer/passes/pass_util.h"

namespace ONNX_NAMESPACE {
namespace optimization {

// where(not(b), x, y) -> where(b, y, x)
// https://github.com/microsoft/onnxruntime/blob/v1.15.1/onnxruntime/core/optimizer/not_where_fusion.h
struct RewriteWhere final : public PredicateBasedPass {
explicit RewriteWhere()
: PredicateBasedPass(PassType::Nop, PassEfficiency::Partial,
PassOptimizationType::Compute) {}

std::string getPassName() const override {
return "rewrite_where";
}

bool patternMatchPredicate(Node* node) override {
bool isWhere = CheckKind(node, Symbol("Where"));
if (isWhere) {
return CheckKind(node->inputs()[0]->node(), Symbol("Not"));
}
return false;
}
bool runTransform(Node* node, Graph& graph,
NodeDestroyType& destroy_current) override {
destroy_current = NodeDestroyType::DestroyZero;
Node* previous_node = node->input(0)->node();
if (previous_node->output()->uses().size() == 1) {
const bool replacing_success =
tryReplacingAllUsesWith(node->input(0), previous_node->input(0));
if (!replacing_success) {
return false;
}
auto x = node->inputs()[1];
auto y = node->inputs()[2];
node->replaceInput(1, y);
node->replaceInput(2, x);
previous_node->destroy();
return true;
}
return false;
}
};

} // namespace optimization
} // namespace ONNX_NAMESPACE
26 changes: 26 additions & 0 deletions onnxoptimizer/test/optimizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4597,6 +4597,32 @@ def test_eliminate_consecutive_idempotent_op(self):
assert optimized_model.graph.node[0].op_type == "Constant"
assert optimized_model.graph.node[1].op_type == "Reshape"

def test_rewrite_where(self):
model = parser.parse_model("""
<
ir_version: 7,
opset_import:["": 11]
>
agraph (bool[4] A, float[4] X, float[4] Y) => (float[4] F, float[4] H)
{
B = Not(A)
Z = Where(B, X, Y)
F = Sign(Z)
M = And(A,A)
G = Where(M, X, Y)
H = Sign(G)
}
""")

optimized_model = self._optimized(
model,["rewrite_where"], True)

assert len(optimized_model.graph.node) == 5
assert set([i.op_type for i in optimized_model.graph.node]) == {'Where', 'And', 'Sign'}
assert optimized_model.graph.node[0].input == ['A', 'Y', 'X']
assert optimized_model.graph.node[3].input == ['M', 'X', 'Y']



if __name__ == "__main__":
unittest.main()

0 comments on commit 0880593

Please sign in to comment.