Skip to content

Commit

Permalink
[SYSTEMDS-3785] Fix rewrite test for simplify bushy binary ops
Browse files Browse the repository at this point in the history
This patch resolves a remaining FIXME after improved rewrite code
coverage by fixing the expressions and other rewrite configs so the
test actually triggers the existing rewrite.
  • Loading branch information
mboehm7 committed Oct 24, 2024
1 parent 63b99e5 commit 7de3657
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 21 deletions.
1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/hops/OptimizerUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ public enum MemoryManager {
* all sum-product related rewrites.
*/
public static boolean ALLOW_SUM_PRODUCT_REWRITES = true;
public static boolean ALLOW_SUM_PRODUCT_REWRITES2 = true;

/**
* Enables additional mmchain optimizations. in the future, this might be merged with
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
}
if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES) {
_dagRuleSet.add( new RewriteMatrixMultChainOptimization() ); //dependency: cse
_dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse
if( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES2 )
_dagRuleSet.add( new RewriteElementwiseMultChainOptimization()); //dependency: cse
}
if(OptimizerUtils.ALLOW_ADVANCED_MMCHAIN_REWRITES){
_dagRuleSet.add( new RewriteMatrixMultChainOptimizationTranspose() ); //dependency: cse
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -855,8 +855,8 @@ private static Hop simplifyDistributiveBinaryOperation( Hop parent, Hop hi, int
}

/**
* (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
* (X+(Y+(Z%*%v))) -> (X+Y)+(Z%*%v)
* t(Z)%*%(X*(Y*(Z%*%v))) -> t(Z)%*%(X*Y)*(Z%*%v)
* t(Z)%*%(X+(Y+(Z%*%v))) -> t(Z)%*%((X+Y)+(Z%*%v))
*
* Note: Restriction ba() at leaf and root instead of data at leaf to not reorganize too
* eagerly, which would loose additional rewrite potential. This rewrite has two goals
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
import org.junit.Test;

import java.util.HashMap;
Expand All @@ -37,7 +38,7 @@ public class RewriteSimplifyBushyBinaryOperationTest extends AutomatedTestBase {
TEST_DIR + RewriteSimplifyBushyBinaryOperationTest.class.getSimpleName() + "/";

private static final int rows = 500;
private static final int cols = 500;
private static final int cols = 100;
private static final double eps = Math.pow(10, -10);

@Override
Expand All @@ -46,28 +47,28 @@ public void setUp() {
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
}

//pattern: t(Z)%*%(X*(Y*(Z%*%v))) -> t(Z)%*%((X*Y)*(Z%*%v))
@Test
public void testBushyBinaryOperationMultNoRewrite() {
testSimplifyBushyBinaryOperation(1, false);
}

@Test
public void testBushyBinaryOperationMultRewrite() { //pattern: (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
public void testBushyBinaryOperationMultRewrite() {
testSimplifyBushyBinaryOperation(1, true);
}

//pattern: t(Z)%*%(X+(Y+(Z%*%v))) -> t(Z)%*%((X+Y)+(Z%*%v))
@Test
public void testBushyBinaryOperationAddNoRewrite() {
testSimplifyBushyBinaryOperation(2, false);
}

@Test
public void testBushyBinaryOperationAddtRewrite() { //pattern: (X+(Y+(Z%*%v))) -> (X+Y)+(Z%*%v)
public void testBushyBinaryOperationAddtRewrite() {
testSimplifyBushyBinaryOperation(2, true);
}



private void testSimplifyBushyBinaryOperation(int ID, boolean rewrites) {
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
try {
Expand All @@ -76,19 +77,21 @@ private void testSimplifyBushyBinaryOperation(int ID, boolean rewrites) {

String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-stats", "-args", input("X"), input("Y"), input("Z"), input("v"), String.valueOf(ID), output("R")};
programArgs = new String[] {"-stats", "-explain", "-args",
input("X"), input("Y"), input("Z"), input("v"), String.valueOf(ID), output("R")};
fullRScriptName = HOME + TEST_NAME + ".R";
rCmd = getRCmd(inputDir(), String.valueOf(ID), expectedDir());

OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
//OptimizerUtils.ALLOW_OPERATOR_FUSION = rewrites;
//OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;

OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES2 = false; //disable nary mult
OptimizerUtils.ALLOW_OPERATOR_FUSION = false; //disable emult reordering
//TODO improved phase ordering

//create matrices
double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.60d, 3);
double[][] Y = getRandomMatrix(rows, cols, -1, 1, 0.60d, 5);
double[][] X = getRandomMatrix(rows, 1, -1, 1, 0.60d, 3);
double[][] Y = getRandomMatrix(rows, 1, -1, 1, 0.60d, 5);
double[][] Z = getRandomMatrix(rows, cols, -1, 1, 0.60d, 6);
double[][] v = getRandomMatrix(rows, cols, -1, 1, 0.60d, 8);
double[][] v = getRandomMatrix(cols, 1, -1, 1, 0.60d, 8);
writeInputMatrixWithMTD("X", X, true);
writeInputMatrixWithMTD("Y", Y, true);
writeInputMatrixWithMTD("Z", Z, true);
Expand All @@ -101,15 +104,14 @@ private void testSimplifyBushyBinaryOperation(int ID, boolean rewrites) {
HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R");
HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R");
TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");

/**
* The rewrite in RewriteAlgebraicSimplificationStatic is not entered. Hence, we fail
* the assertions for this rewrite so that we can revisit this issue later.
*/
//FIXME

if( ID == 1 && rewrites ) //check mmchain, enabled by bushy join
Assert.assertTrue(heavyHittersContainsString("mmchain"));
}
finally {
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
OptimizerUtils.ALLOW_OPERATOR_FUSION = true;
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES2 = true;
Recompiler.reinitRecompiler();
}
}
Expand Down

0 comments on commit 7de3657

Please sign in to comment.