diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java index 86a71103f58..f313663b341 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java @@ -224,7 +224,7 @@ else if(shouldPreAggregateLeft(lhs)) {// left preAgg DictLibMatrixMult.MMDicts(lDict, lhsPA, leftIdx, rightIdx, result); } else {// right preAgg - final IDictionary rhsPA = preAggregateThatIndexStructure(lhs); + final IDictionary rhsPA = this.preAggregateThatIndexStructure(lhs); if(rhsPA != null) DictLibMatrixMult.MMDicts(rhsPA, rDict, leftIdx, rightIdx, result); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java index e804bf855b6..5ee4c889462 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java @@ -687,7 +687,7 @@ private void leftMMIdentityPreAggregateDenseSingleRow(double[] values, int pos, for(int rc = cl; rc < cu; rc++, pos++) { final int idx = _data.getIndex(rc); if(idx != nVal) - values2[_colIndexes.get(_data.getIndex(rc))] += values[pos]; + values2[_colIndexes.get(idx)] += values[pos]; } } else { diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java index b2642dd0a53..924e04cf578 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java @@ -651,11 +651,13 @@ public void MMDictScaling(IDictionary right, IColIndex rowsLeft, IColIndex colsR @Override public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { // similar to fused transpose left into right locations. + final int leftSide = rowsLeft.size(); + final int colsOut = result.getNumColumns(); final int commonDim = Math.min(left.length / leftSide, nRowCol); final double[] resV = result.getDenseBlockValues(); for(int i = 0; i < leftSide; i++) { // rows in left side - final int offOut = rowsLeft.get(i) * commonDim; + final int offOut = rowsLeft.get(i) * colsOut; final int leftOff = i; for(int j = 0; j < commonDim; j++) { // cols in left side skipping empty from identity resV[offOut + colsRight.get(j)] += left[leftOff + j * leftSide]; diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java index 8812c6d4cc5..a124bad3100 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java @@ -261,6 +261,7 @@ private static MatrixBlock leftMultByCompressedTransposedMatrixSingleThread(Comp // Force dense output ret.setNonZeros((long) ret.getNumRows() * ret.getNumColumns()); ret.allocateDenseBlock(); + for(int j = 0; j < fLeft.size(); j++) for(int i = 0; i < fRight.size(); i++) fRight.get(i).leftMultByAColGroup(fLeft.get(j), ret, sd); @@ -314,8 +315,6 @@ private static MatrixBlock LMM(List colGroups, MatrixBlock that, Matr else rowSums = that.rowSum(k).getDenseBlockValues(); - // final double multTime = t.stop(); - // add the correction layer for the subtracted common values. if(rowSums != null) { if(ret.isEmpty()) @@ -324,11 +323,6 @@ private static MatrixBlock LMM(List colGroups, MatrixBlock that, Matr ret.sparseToDense(); outerProduct(rowSums, constV, ret, k); } - - // final double outerProd = t.stop(); - // if(LOG.isDebugEnabled()) { - // LOG.debug(String.format("LLM: filter: %10f Mult: %10f outer: %10f", filterGroupsTime, multTime, outerProd)); - // } } else { CLALibUtils.splitPreAgg(colGroups, noPreAggGroups, preAggGroups); diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/CustomColGroupTest.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/CustomColGroupTest.java index 5135461f891..89cb31c964c 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/CustomColGroupTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/CustomColGroupTest.java @@ -22,14 +22,18 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC; import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCSingleZeros; +import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.PlaceHolderDict; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; import org.junit.Test; @@ -65,4 +69,15 @@ public void appendEmptyToSDCZero2() { assertEquals(((ColGroupSDCSingleZeros) r).getNumRows(), 7 * 20); } + + @Test(expected = NotImplementedException.class) + public void preAggSparseError() { + + AColGroup g = ColGroupDDC.create(ColIndexFactory.create(3), + Dictionary.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9}), + MapToFactory.create(new int[] {0, 0, 0, 1, 1, 1, 2, 2, 2}, 3), null); + + ((ColGroupDDC)g).preAggregateSparse(null, null, 0, 3, 1, 2); + + } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibLMMTest.java b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibLMMTest.java index 905ef59de25..383d31cab95 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibLMMTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibLMMTest.java @@ -154,6 +154,32 @@ public static Collection data() { mb = CompressedMatrixBlock.getUncompressed(cmb); genTests(tests, mb, cmb, "Identity"); + + d = MappingTestUtil.createRandomMap(100, 10, new Random(23)); + idg = ColGroupDDC.create(ColIndexFactory.createI(0,1,2,3,4,6,7,8,9,10), id, d, null); + cmb = new CompressedMatrixBlock(100, 11); + cmb.allocateColGroup(idg); + mb = CompressedMatrixBlock.getUncompressed(cmb); + genTests(tests, mb, cmb, "Identity2"); + + id = new IdentityDictionary(10, true); + + // continuous index range + d = MappingTestUtil.createRandomMap(100, 11, new Random(33)); + idg = ColGroupDDC.create(ColIndexFactory.create(10), id, d, null); + cmb = new CompressedMatrixBlock(100, 10); + cmb.allocateColGroup(idg); + mb = CompressedMatrixBlock.getUncompressed(cmb); + genTests(tests, mb, cmb, "Identity_empty"); + + // not continuous. + d = MappingTestUtil.createRandomMap(100, 11, new Random(33)); + idg = ColGroupDDC.create(ColIndexFactory.createI(0,1,2,3,4,6,7,8,9,10), id, d, null); + cmb = new CompressedMatrixBlock(100, 11); + cmb.allocateColGroup(idg); + mb = CompressedMatrixBlock.getUncompressed(cmb); + genTests(tests, mb, cmb, "Identity_empty2"); + AColGroup empty = new ColGroupEmpty(ColIndexFactory.create(10)); cmb = new CompressedMatrixBlock(100, 10); cmb.allocateColGroup(empty); @@ -415,6 +441,7 @@ private static void execTlR(MatrixBlock mb1, CompressedMatrixBlock cmb1, MatrixB tmb2 = cmb1; MatrixBlock cRet = CLALibMatrixMult.matrixMultiply(tmb2, cmb1, ret, k, true, false); MatrixBlock uRet = LibMatrixMult.matrixMult(mb2, mb1, k); + compare(cRet, uRet); } diff --git a/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibSelectionMultTest.java b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibSelectionMultTest.java index 980a1558506..33a4ecbd8da 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibSelectionMultTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibSelectionMultTest.java @@ -154,6 +154,22 @@ public static Collection data() { mb = CompressedMatrixBlock.getUncompressed(cmb); genTests(tests, mb, cmb, "Identity"); + d = MappingTestUtil.createRandomMap(100, 10, new Random(23)); + idg = ColGroupDDC.create(ColIndexFactory.createI(0,1,2,3,4,6,7,8,9,10), id, d, null); + cmb = new CompressedMatrixBlock(100, 11); + cmb.allocateColGroup(idg); + mb = CompressedMatrixBlock.getUncompressed(cmb); + genTests(tests, mb, cmb, "Identity2"); + + id = new IdentityDictionary(10, true); + + d = MappingTestUtil.createRandomMap(100, 11, new Random(33)); + idg = ColGroupDDC.create(ColIndexFactory.createI(0,1,2,3,4,6,7,8,9,10), id, d, null); + cmb = new CompressedMatrixBlock(100, 11); + cmb.allocateColGroup(idg); + mb = CompressedMatrixBlock.getUncompressed(cmb); + genTests(tests, mb, cmb, "Identity_empty"); + AColGroup empty = new ColGroupEmpty(ColIndexFactory.create(10)); cmb = new CompressedMatrixBlock(100, 10); cmb.allocateColGroup(empty);