From 215df464171eb0bb1c65507fa3eb378cda864d3e Mon Sep 17 00:00:00 2001 From: MRGSRT <57044553+MRGSRT@users.noreply.github.com> Date: Tue, 19 Dec 2023 16:45:18 +0100 Subject: [PATCH] optimization for traverse() --- .../hops/estim/EstimatorLayeredGraph.java | 27 +++++++------------ 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java b/src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java index 4792aa9ae8f..f997db6503a 100644 --- a/src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java +++ b/src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java @@ -60,33 +60,26 @@ public DataCharacteristics estim(MMNode root) { List leafs = getMatrices(root, new ArrayList<>()); List ops = getOps(root, new ArrayList<>()); List LGs = new ArrayList<>(); - LayeredGraph ret; - if(ops.stream().allMatch(op -> op.equals(OpCode.MM))) { - ret = new LayeredGraph(leafs, _rounds); - } - else { - traverse(root, LGs); - ret = LGs.get(LGs.size() - 1); - } + LayeredGraph ret = traverse(root); long nnz = ret.estimateNnz(); return root.setDataCharacteristics(new MatrixCharacteristics( ret._nodes.get(0).length, ret._nodes.get(ret._nodes.size() - 1).length, nnz)); } - public void traverse(MMNode node, List LGs) { - if(node.getLeft() == null || node.getRight() == null) return; - traverse(node.getLeft(), LGs); - traverse(node.getRight(), LGs); + public LayeredGraph traverse(MMNode node) { + if(node.getLeft() == null || node.getRight() == null) return null; + LayeredGraph retL = traverse(node.getLeft()); + LayeredGraph retR = traverse(node.getRight()); LayeredGraph ret, left, right; - left = (node.getLeft().getData() == null && !LGs.isEmpty()) - ? LGs.get(LGs.size() - 1) : new LayeredGraph(node.getLeft().getData(), _rounds); - right = (node.getRight().getData() == null && !LGs.isEmpty()) - ? LGs.get(LGs.size() - 1) : new LayeredGraph(node.getRight().getData(), _rounds); + left = (node.getLeft().getData() == null) + ? retL : new LayeredGraph(node.getLeft().getData(), _rounds); + right = (node.getRight().getData() == null) + ? retR : new LayeredGraph(node.getRight().getData(), _rounds); ret = estimInternal(left, right, node.getOp()); - LGs.add(ret); + return ret; } @Override