Skip to content

Commit

Permalink
Remove redundant add_mul_fusion
Browse files Browse the repository at this point in the history
The same job is done by canonicalize before this rewrite is ever called.
  • Loading branch information
Ricardo Vieira committed Oct 7, 2022
1 parent 9a4c2c3 commit 9b61381
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 826 deletions.
61 changes: 0 additions & 61 deletions aesara/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@
register_uncanonicalize,
register_useless,
)
from aesara.tensor.rewriting.elemwise import FusionOptimizer, fuse_seqopt
from aesara.tensor.shape import Shape, Shape_i
from aesara.tensor.subtensor import Subtensor
from aesara.tensor.type import (
Expand Down Expand Up @@ -2843,66 +2842,6 @@ def check_input(inputs):
return [ret]


def local_add_mul_fusion(fgraph, node):
"""Fuse consecutive add or mul in one such node with more inputs.
It is better to fuse add/mul that way then in a Composite node as
this make the inner graph of the Composite smaller. This allow to
put more computation in a Composite before hitting the max
recursion limit when pickling Composite.
"""
if not isinstance(node.op, Elemwise) or not isinstance(
node.op.scalar_op, (aes.Add, aes.Mul)
):
return False

s_op = node.op.scalar_op.__class__
new_inp = []
fused = False
nb_inputs = len(node.inputs)
max_inputs = float("inf")
if hasattr(node.op, "max_inputs"):
max_inputs = node.op.max_inputs(node)
for inp in node.inputs:
if (
inp.owner
and isinstance(inp.owner.op, Elemwise)
and isinstance(inp.owner.op.scalar_op, s_op)
and
# Do not duplicate the operation.
len(fgraph.clients[inp]) == 1
and (nb_inputs + len(inp.owner.inputs) - 1) <= max_inputs
):
new_inp.extend(inp.owner.inputs)
fused = True
else:
new_inp.append(inp)

# We can not compare the number of inputs as Mul and Add could have
# 0 or 1 inputs in some corner cases.
if fused:
output = node.op(*new_inp)
copy_stack_trace(node.outputs[0], output)

# Do the recursion here to help lower the number of
# FusionOptimizer iteration.
if output.owner:
output2 = local_add_mul_fusion(fgraph, output.owner)
if output2:
return output2
return [output]


fuse_seqopt.register(
"local_add_mul_fusion",
FusionOptimizer(local_add_mul_fusion),
"fast_run",
"fusion",
position=0,
)


def _skip_mul_1(r):
if r.owner and r.owner.op == mul:
not_is_1 = [i for i in r.owner.inputs if not _is_1(i)]
Expand Down
42 changes: 27 additions & 15 deletions tests/tensor/rewriting/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,23 +998,33 @@ def test_big_fusion(self):
for node in dlogp.maker.fgraph.toposort()
)

def test_add_mul_fusion_inplace(self):

rewrites = RewriteDatabaseQuery(
include=[
"local_elemwise_fusion",
"composite_elemwise_fusion",
"canonicalize",
"inplace",
],
exclude=["cxx_only", "BlasOpt"],
)

mode = Mode(self.mode.linker, rewrites)
def test_add_mul_fusion_precedence(self):
"""Test that additions and multiplications are "fused together" before
a `Composite` `Op` is introduced. This fusion is done by canonicalization
"""
x, y, z = vectors("x", "y", "z")
out = log((x + y + z) * (x * y * z))
aesara.config.optimizer_verbose = True
f = aesara.function([x, y, z], out, mode=self.mode)
# There should be a single Composite Op
nodes = f.maker.fgraph.apply_nodes
assert len(nodes) == 1
(node,) = nodes
assert isinstance(node.op, Elemwise)
scalar_op = node.op.scalar_op
assert isinstance(scalar_op, Composite)
assert {node.op for node in scalar_op.fgraph.apply_nodes} == {
aes.mul,
aes.add,
aes.log,
}

def test_add_mul_fusion_inplace(self):
# TODO: This has nothing to do with the FusionOptimizer, as the "fusion"
# is done by canonicalize
x, y, z = dmatrices("xyz")
out = dot(x, y) + x + y + z
f = function([x, y, z], out, mode=mode)
f = function([x, y, z], out, mode=self.mode)
topo = [n for n in f.maker.fgraph.toposort()]
assert len(topo) == 2
assert topo[-1].op.inplace_pattern
Expand All @@ -1026,7 +1036,9 @@ def test_add_mul_fusion_inplace(self):

# TODO: Do we really need to do this?
_ = f(
np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5))
np.random.random((5, 5)),
np.random.random((5, 5)),
np.random.random((5, 5)),
)

@pytest.mark.skipif(not config.cxx, reason="No cxx compiler")
Expand Down
Loading

0 comments on commit 9b61381

Please sign in to comment.