From 6bd85f4dab1dc34aa507f65925f7e8fe3fcc050f Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Sat, 31 Aug 2024 13:27:30 +0200 Subject: [PATCH] [SYSTEMDS-3771] Compressed Identity Dictionary and Selection Multiply This commit contains the implementation details of the LMM (left matrix multiplication) refinements for supporting the new Identity dictionaries, which remove the need for many of the matrix multiplications. Furthermore, it contains the implementation details and optimizations for selection matrix multiplications, where the left-side matrix contains at most a single 1 in each row. The implementation there simply decompresses the rows associated with the 1, making the overall compressed operation very efficient. The overall implementation further improves the code coverage of the project by 0.24% --- .../compress/CompressedMatrixBlock.java | 5 + .../runtime/compress/colgroup/AColGroup.java | 39 + .../runtime/compress/colgroup/APreAgg.java | 25 +- .../compress/colgroup/ColGroupConst.java | 11 + .../compress/colgroup/ColGroupDDC.java | 92 ++- .../compress/colgroup/ColGroupDDCFOR.java | 33 +- .../compress/colgroup/ColGroupEmpty.java | 17 +- .../colgroup/ColGroupLinearFunctional.java | 15 +- .../compress/colgroup/ColGroupOLE.java | 21 +- .../compress/colgroup/ColGroupRLE.java | 21 +- .../compress/colgroup/ColGroupSDC.java | 64 +- .../compress/colgroup/ColGroupSDCFOR.java | 12 + .../compress/colgroup/ColGroupSDCSingle.java | 13 + .../colgroup/ColGroupSDCSingleZeros.java | 29 +- .../compress/colgroup/ColGroupSDCZeros.java | 60 +- .../colgroup/ColGroupUncompressed.java | 108 +++ .../compress/colgroup/ColGroupUtils.java | 54 ++ .../colgroup/dictionary/ADictionary.java | 7 + .../dictionary/DictLibMatrixMult.java | 2 +- .../colgroup/dictionary/Dictionary.java | 5 +- .../dictionary/DictionaryFactory.java | 216 ++--- .../colgroup/dictionary/IDictionary.java | 18 +- .../dictionary/IdentityDictionary.java | 86 +- .../dictionary/IdentityDictionarySlice.java | 67 +- 
.../dictionary/MatrixBlockDictionary.java | 3 + .../colgroup/dictionary/PlaceHolderDict.java | 6 +- .../compress/lib/CLALibDecompress.java | 5 + .../compress/lib/CLALibLeftMultBy.java | 766 ++++++++++++------ .../compress/lib/CLALibMatrixMult.java | 4 +- .../compress/lib/CLALibSelectionMult.java | 249 ++++++ .../readers/ReaderColumnSelection.java | 63 +- .../ReaderColumnSelectionDenseMultiBlock.java | 1 - .../runtime/matrix/data/MatrixBlock.java | 20 + .../java/org/apache/sysds/test/TestUtils.java | 15 +- .../colgroup/ColGroupFactoryTest.java | 4 +- .../colgroup/ColGroupNegativeTests.java | 41 +- .../compress/colgroup/ColGroupTest.java | 157 ++-- .../colgroup/scheme/SchemeTestBase.java | 25 +- .../compress/dictionary/CombineTest.java | 229 ++++-- .../dictionary/CustomDictionaryTest.java | 319 ++++++++ .../compress/dictionary/DictionaryTests.java | 96 ++- .../dictionary/PlaceHolderDictTest.java | 520 ++++++++++++ .../functional/LinearRegNegativeTest.java | 48 ++ .../functional/LinearRegressionTests.java | 1 - .../component/compress/lib/CLALibLMMTest.java | 435 ++++++++++ .../lib/CLALibSelectionMultCustomTest.java | 128 +++ .../compress/lib/CLALibSelectionMultTest.java | 285 +++++++ .../compress/mapping/MappingTestUtil.java | 2 +- 48 files changed, 3854 insertions(+), 588 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSelectionMult.java create mode 100644 src/test/java/org/apache/sysds/test/component/compress/dictionary/PlaceHolderDictTest.java create mode 100644 src/test/java/org/apache/sysds/test/component/compress/functional/LinearRegNegativeTest.java create mode 100644 src/test/java/org/apache/sysds/test/component/compress/lib/CLALibLMMTest.java create mode 100644 src/test/java/org/apache/sysds/test/component/compress/lib/CLALibSelectionMultCustomTest.java create mode 100644 src/test/java/org/apache/sysds/test/component/compress/lib/CLALibSelectionMultTest.java diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java index 816bb57f154..68cfc6f9830 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java @@ -1245,6 +1245,11 @@ public void allocateAndResetSparseBlock(boolean clearNNZ, SparseBlock.Type stype throw new DMLCompressionException("Invalid to allocate block on a compressed MatrixBlock"); } + @Override + public MatrixBlock transpose(int k) { + return getUncompressed().transpose(k); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java index 1184cc0aec4..f23981c4151 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java @@ -27,6 +27,7 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex.SliceResult; import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; @@ -728,6 +729,44 @@ public AColGroup sortColumnIndexes() { */ public abstract AColGroup reduceCols(); + /** + * Selection (left matrix multiply) + * + * @param selection A sparse matrix with "max" a single one in each row all other values are zero. + * @param points The coordinates in the selection matrix to extract. 
+ * @param ret The MatrixBlock to decompress the selected rows into + * @param rl The row to start at in the selection matrix + * @param ru the row to end at in the selection matrix (not inclusive) + */ + public final void selectionMultiply(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + if(ret.isInSparseFormat()) + sparseSelection(selection, points, ret, rl, ru); + else + denseSelection(selection, points, ret, rl, ru); + } + + /** + * Sparse selection (left matrix multiply) + * + * @param selection A sparse matrix with "max" a single one in each row all other values are zero. + * @param points The coordinates in the selection matrix to extract. + * @param ret The Sparse MatrixBlock to decompress the selected rows into + * @param rl The row to start at in the selection matrix + * @param ru the row to end at in the selection matrix (not inclusive) + */ + protected abstract void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru); + + /** + * Dense selection (left matrix multiply) + * + * @param selection A sparse matrix with "max" a single one in each row all other values are zero. + * @param points The coordinates in the selection matrix to extract. 
+ * @param ret The Dense MatrixBlock to decompress the selected rows into + * @param rl The row to start at in the selection matrix + * @param ru the row to end at in the selection matrix (not inclusive) + */ + protected abstract void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru); + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java index 86cb58cf54f..86a71103f58 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java @@ -85,14 +85,14 @@ else if(lhs instanceof ColGroupUncompressed) * @return A aggregate dictionary */ public final IDictionary preAggregateThatIndexStructure(APreAgg that) { - final long outputLength = (long)that._colIndexes.size() * this.getNumValues(); + final long outputLength = (long) that._colIndexes.size() * this.getNumValues(); if(outputLength > Integer.MAX_VALUE) throw new NotImplementedException("Not supported pre aggregate of above integer length"); if(outputLength <= 0) // if the pre aggregate output is empty or nothing, return null return null; - + // create empty Dictionary that we slowly fill, hence the dictionary is empty and no check - final Dictionary ret = Dictionary.createNoCheck(new double[(int)outputLength]); + final Dictionary ret = Dictionary.createNoCheck(new double[(int) outputLength]); if(that instanceof ColGroupDDC) preAggregateThatDDCStructure((ColGroupDDC) that, ret); @@ -119,7 +119,7 @@ else if(that instanceof ColGroupRLE) */ public final void preAggregate(MatrixBlock m, double[] preAgg, int rl, int ru) { if(m.isInSparseFormat()) - preAggregateSparse(m.getSparseBlock(), preAgg, rl, ru); + preAggregateSparse(m.getSparseBlock(), preAgg, rl, ru, 0, m.getNumColumns()); else preAggregateDense(m, preAgg, rl, ru, 
0, m.getNumColumns()); } @@ -136,7 +136,7 @@ public final void preAggregate(MatrixBlock m, double[] preAgg, int rl, int ru) { */ public abstract void preAggregateDense(MatrixBlock m, double[] preAgg, int rl, int ru, int cl, int cu); - public abstract void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru); + public abstract void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru, int cl, int cu); protected abstract void preAggregateThatDDCStructure(ColGroupDDC that, Dictionary ret); @@ -160,11 +160,13 @@ private void tsmmAPreAgg(APreAgg lg, MatrixBlock result) { final boolean left = shouldPreAggregateLeft(lg); if(!loggedWarningForDirect && shouldDirectMultiply(lg, leftIdx.size(), rightIdx.size(), left)) { loggedWarningForDirect = true; - LOG.warn("Not implemented direct tsmm colgroup: " + lg.getClass().getSimpleName() + " %*% " + this.getClass().getSimpleName() ); + LOG.warn("Not implemented direct tsmm colgroup: " + lg.getClass().getSimpleName() + " %*% " + + this.getClass().getSimpleName()); } if(left) { final IDictionary lpa = this.preAggregateThatIndexStructure(lg); + if(lpa != null) DictLibMatrixMult.TSMMToUpperTriangle(lpa, _dict, leftIdx, rightIdx, result); } @@ -311,17 +313,20 @@ public void mmWithDictionary(MatrixBlock preAgg, MatrixBlock tmpRes, MatrixBlock // Shallow copy the preAgg to allow sparse PreAgg multiplication but do not remove the original dense allocation // since the dense allocation is reused. 
final MatrixBlock preAggCopy = new MatrixBlock(); - preAggCopy.copy(preAgg); + preAggCopy.copyShallow(preAgg); final MatrixBlock tmpResCopy = new MatrixBlock(); - tmpResCopy.copy(tmpRes); + tmpResCopy.copyShallow(tmpRes); // Get dictionary matrixBlock final MatrixBlock dict = getDictionary().getMBDict(_colIndexes.size()).getMatrixBlock(); if(dict != null) { // Multiply - LibMatrixMult.matrixMult(preAggCopy, dict, tmpResCopy, k); - ColGroupUtils.addMatrixToResult(tmpResCopy, ret, _colIndexes, rl, ru); + LibMatrixMult.matrixMult(preAggCopy, dict, tmpRes, k); + ColGroupUtils.addMatrixToResult(tmpRes, ret, _colIndexes, rl, ru); } } protected abstract int numRowsToMultiply(); + + public abstract void leftMMIdentityPreAggregateDense(MatrixBlock that, MatrixBlock ret, int rl, int ru, int cl, + int cu); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java index 0ef7a423503..380fc29b26f 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java @@ -24,6 +24,7 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.DMLCompressionException; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; @@ -647,4 +648,14 @@ public AMapToData getMapToData() { return MapToFactory.create(0, 0); } + @Override + protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + throw new NotImplementedException(); + } + + @Override + protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + throw new 
NotImplementedException(); + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java index 79ffbe2bbe6..e804bf855b6 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java @@ -28,12 +28,15 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.DMLCompressionException; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.colgroup.indexes.RangeIndex; import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToByte; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToChar; @@ -398,7 +401,10 @@ public void preAggregateDense(MatrixBlock m, double[] preAgg, int rl, int ru, in } @Override - public void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru) { + public void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru, int cl, int cu) { + if(cl != 0 || cu != _data.size()) { + throw new NotImplementedException(); + } _data.preAggregateSparse(sb, preAgg, rl, ru); } @@ -628,6 +634,90 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] 
reordering) { return ColGroupDDC.create(newColIndex, _dict.reorder(reordering), _data, getCachedCounts()); } + @Override + public void sparseSelection(MatrixBlock selection,P[] points, MatrixBlock ret, int rl, int ru) { + // morph(CompressionType.UNCOMPRESSED, _data.size()).sparseSelection(selection, ret, rl, ru);; + final SparseBlock sb = selection.getSparseBlock(); + final SparseBlock retB = ret.getSparseBlock(); + for(int r = rl; r < ru; r++) { + if(sb.isEmpty(r)) + continue; + final int sPos = sb.pos(r); + final int rowCompressed = sb.indexes(r)[sPos]; // column index with 1 + decompressToSparseBlock(retB, rowCompressed, rowCompressed + 1, r - rowCompressed, 0); + } + } + + + @Override + protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + // morph(CompressionType.UNCOMPRESSED, _data.size()).sparseSelection(selection, ret, rl, ru);; + final SparseBlock sb = selection.getSparseBlock(); + final DenseBlock retB = ret.getDenseBlock(); + for(int r = rl; r < ru; r++) { + if(sb.isEmpty(r)) + continue; + final int sPos = sb.pos(r); + final int rowCompressed = sb.indexes(r)[sPos]; // column index with 1 + decompressToDenseBlock(retB, rowCompressed, rowCompressed + 1, r - rowCompressed, 0); + } + } + + @Override + public void leftMMIdentityPreAggregateDense(MatrixBlock that, MatrixBlock ret, int rl, int ru, int cl, int cu) { + DenseBlock db = that.getDenseBlock(); + DenseBlock retDB = ret.getDenseBlock(); + if(rl == ru - 1) + leftMMIdentityPreAggregateDenseSingleRow(db.values(rl), db.pos(rl), retDB.values(rl), retDB.pos(rl), cl, cu); + else + throw new NotImplementedException(); + } + + + private void leftMMIdentityPreAggregateDenseSingleRow(double[] values, int pos, double[] values2, int pos2, int cl, + int cu) { + IdentityDictionary a = (IdentityDictionary) _dict; + if(_colIndexes instanceof RangeIndex) + leftMMIdentityPreAggregateDenseSingleRowRangeIndex(values, pos, values2, pos2, cl, cu); + else { + + pos += cl; // 
left side matrix position offset. + if(a.withEmpty()) { + final int nVal = _dict.getNumberOfValues(_colIndexes.size()) - 1; + for(int rc = cl; rc < cu; rc++, pos++) { + final int idx = _data.getIndex(rc); + if(idx != nVal) + values2[_colIndexes.get(_data.getIndex(rc))] += values[pos]; + } + } + else { + for(int rc = cl; rc < cu; rc++, pos++) + values2[_colIndexes.get(_data.getIndex(rc))] += values[pos]; + } + } + } + + + private void leftMMIdentityPreAggregateDenseSingleRowRangeIndex(double[] values, int pos, double[] values2, int pos2, + int cl, int cu) { + IdentityDictionary a = (IdentityDictionary) _dict; + + final int firstCol = _colIndexes.get(0); + pos += cl; // left side matrix position offset. + if(a.withEmpty()) { + final int nVal = _dict.getNumberOfValues(_colIndexes.size()) - 1; + for(int rc = cl; rc < cu; rc++, pos++) { + final int idx = _data.getIndex(rc); + if(idx != nVal) + values2[firstCol + idx] += values[pos]; + } + } + else { + for(int rc = cl; rc < cu; rc++, pos++) + values2[firstCol + _data.getIndex(rc)] += values[pos]; + } + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java index d09ba4e624d..14dcf3c635c 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java @@ -26,6 +26,7 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; @@ -40,6 +41,8 @@ import 
org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; import org.apache.sysds.runtime.compress.utils.Util; +import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.Divide; import org.apache.sysds.runtime.functionobjects.Minus; @@ -252,7 +255,7 @@ public AColGroup replace(double pattern, double replace) { if(patternInReference) { double[] nRef = new double[_reference.length]; for(int i = 0; i < _reference.length; i++) - if(Util.eq(pattern ,_reference[i])) + if(Util.eq(pattern, _reference[i])) nRef[i] = replace; else nRef[i] = _reference[i]; @@ -489,6 +492,34 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) { throw new NotImplementedException(); } + @Override + protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + final SparseBlock sb = selection.getSparseBlock(); + final SparseBlock retB = ret.getSparseBlock(); + for(int r = rl; r < ru; r++) { + if(sb.isEmpty(r)) + continue; + + final int sPos = sb.pos(r); + final int rowCompressed = sb.indexes(r)[sPos]; + decompressToSparseBlock(retB, rowCompressed, rowCompressed + 1, r - rowCompressed, 0); + } + } + + @Override + protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + final SparseBlock sb = selection.getSparseBlock(); + final DenseBlock retB = ret.getDenseBlock(); + for(int r = rl; r < ru; r++) { + if(sb.isEmpty(r)) + continue; + + final int sPos = sb.pos(r); + final int rowCompressed = sb.indexes(r)[sPos]; + decompressToDenseBlock(retB, rowCompressed, rowCompressed + 1, r - rowCompressed, 0); + } + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java index ce7954c7a1e..e90dd7a2545 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java @@ -23,8 +23,10 @@ import java.io.IOException; import java.util.Arrays; +import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.DMLCompressionException; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; @@ -53,7 +55,7 @@ import org.apache.sysds.runtime.matrix.operators.UnaryOperator; public class ColGroupEmpty extends AColGroupCompressed - implements IContainADictionary, IContainDefaultTuple, AOffsetsGroup ,IMapToDataGroup{ + implements IContainADictionary, IContainDefaultTuple, AOffsetsGroup, IMapToDataGroup { private static final long serialVersionUID = -2307677253622099958L; /** @@ -403,9 +405,18 @@ public AMapToData getMapToData() { return MapToFactory.create(0, 0); } - @Override - public AColGroup reduceCols(){ + @Override + public AColGroup reduceCols() { return null; } + @Override + protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + throw new NotImplementedException(); + } + + @Override + protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + throw new NotImplementedException(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java index f083a4dfd9a..2983534c249 100644 --- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java @@ -26,6 +26,7 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.DMLCompressionException; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; @@ -703,8 +704,18 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) { throw new NotImplementedException(); } - @Override - public AColGroup reduceCols(){ + @Override + public AColGroup reduceCols() { + throw new NotImplementedException(); + } + + @Override + public void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + throw new NotImplementedException(); + } + + @Override + protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { throw new NotImplementedException(); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java index 8af0f959e0c..9f78fe51f4a 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java @@ -27,6 +27,7 @@ import org.apache.sysds.runtime.compress.CompressionSettings; import org.apache.sysds.runtime.compress.bitmap.ABitmap; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; import 
org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; @@ -601,7 +602,15 @@ public void preAggregateDense(MatrixBlock m, double[] preAgg, int rl, int ru, in } @Override - public void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru) { + public void leftMMIdentityPreAggregateDense(MatrixBlock that, MatrixBlock ret, int rl, int ru, int cl, int cu) { + throw new NotImplementedException(); + } + + @Override + public void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru, int cl, int cu) { + if(cl != 0 || cu != _numRows) { + throw new NotImplementedException(); + } throw new NotImplementedException(); } @@ -689,4 +698,14 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) { throw new NotImplementedException(); } + @Override + protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + throw new NotImplementedException(); + } + + @Override + protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + throw new NotImplementedException(); + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java index 23596c1e190..834178162e1 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java @@ -28,6 +28,7 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.bitmap.ABitmap; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; @@ -795,7 +796,15 @@ 
public void preAggregateDense(MatrixBlock m, double[] preAgg, final int rl, fina } @Override - public void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru) { + public void leftMMIdentityPreAggregateDense(MatrixBlock that, MatrixBlock ret, int rl, int ru, int cl, int cu) { + throw new NotImplementedException(); + } + + @Override + public void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru, int cl, int cu) { + if(cl != 0 || cu != _numRows) { + throw new NotImplementedException(); + } final int nv = getNumValues(); for(int r = rl; r < ru; r++) { // for each row @@ -1146,4 +1155,14 @@ public static char[] genRLEBitmap(int[] offsets, int len) { return ret; } + + @Override + protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + throw new NotImplementedException(); + } + + @Override + protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + throw new NotImplementedException(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java index a905e401e42..ea4f2fb5811 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java @@ -24,8 +24,10 @@ import java.io.IOException; import java.util.Arrays; +import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.DMLCompressionException; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; @@ -43,6 +45,7 @@ import 
org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; import org.apache.sysds.runtime.compress.utils.Util; +import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -459,7 +462,7 @@ public AColGroup replace(double pattern, double replace) { IDictionary replaced = _dict.replace(pattern, replace, _colIndexes.size()); double[] newDefaultTuple = new double[_defaultTuple.length]; for(int i = 0; i < _defaultTuple.length; i++) - newDefaultTuple[i] = Util.eq(_defaultTuple[i],pattern) ? replace : _defaultTuple[i]; + newDefaultTuple[i] = Util.eq(_defaultTuple[i], pattern) ? replace : _defaultTuple[i]; return create(_colIndexes, _numRows, replaced, newDefaultTuple, _indexes, _data, getCachedCounts()); } @@ -662,6 +665,65 @@ public int getNumberOffsets() { return _data.size(); } + @Override + protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + final SparseBlock sr = ret.getSparseBlock(); + final int nCol = _colIndexes.size(); + final AIterator it = _indexes.getIterator(); + final int last = _indexes.getOffsetToLast(); + + int c = 0; + + int of = it.value(); + while(of < last && c < points.length) { + // final int of = it.value(); + if(points[c].o < of) { + putDefault(points[c].r, sr, nCol); + c++; + } + else { + while(c < points.length && points[c].o == of) { + _dict.put(sr, _data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes); + c++; + } + of = it.next(); + } + // c++; + } + + for(; c < points.length && points[c].o < last; c++) { + putDefault(points[c].r, sr, nCol); + } + + while(of == last && c < points.length && points[c].o == of) { + _dict.put(sr, _data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes); + c++; + } + + // set default in tail. 
+ for(; c < points.length; c++) { + putDefault(points[c].r, sr, nCol); + } + + } + + private void putDefault(final int r, final SparseBlock sr, final int nCol) { + if(sr.isAllocated(r)) + for(int i = 0; i < nCol; i++) + sr.add(r, _colIndexes.get(i), _defaultTuple[i]); + else { + sr.allocate(r, _colIndexes.size()); + for(int i = 0; i < nCol; i++) + sr.append(r, _colIndexes.get(i), _defaultTuple[i]); + } + + } + + @Override + protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + throw new NotImplementedException(); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java index dfb9a605118..a3f39a0eb5a 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java @@ -27,6 +27,7 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; @@ -48,6 +49,7 @@ import org.apache.sysds.runtime.functionobjects.Multiply; import org.apache.sysds.runtime.functionobjects.Plus; import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.matrix.operators.CMOperator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; @@ -520,6 +522,16 @@ 
public ICLAScheme getCompressionScheme() { throw new NotImplementedException(); } + @Override + public void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + throw new NotImplementedException(); + } + + @Override + protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + throw new NotImplementedException(); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java index 3e9b1deba22..7278b7467bb 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java @@ -24,8 +24,10 @@ import java.io.IOException; import java.util.Arrays; +import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.DMLCompressionException; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; @@ -43,6 +45,7 @@ import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.matrix.operators.CMOperator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; @@ -619,6 +622,16 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) { ColGroupUtils.reorderDefault(_defaultTuple, reordering), 
_indexes, getCachedCounts()); } + @Override + protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + throw new NotImplementedException(); + } + + @Override + protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + throw new NotImplementedException(); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java index cc25e54b41c..06b4474534e 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java @@ -27,6 +27,7 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.DMLCompressionException; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; @@ -394,7 +395,15 @@ else if(cu < _indexes.getOffsetToLast() + 1) { } @Override - public void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru) { + public void leftMMIdentityPreAggregateDense(MatrixBlock that, MatrixBlock ret, int rl, int ru, int cl, int cu) { + throw new NotImplementedException(); + } + + @Override + public void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru, int cl, int cu) { + if(cl != 0 || cu < _indexes.getOffsetToLast()) { + throw new NotImplementedException(); + } final AOffsetIterator it = _indexes.getOffsetIterator(); if(rl == ru - 1) preAggregateSparseSingleRow(sb, preAgg, rl, 
_indexes.getOffsetToLast(), it); @@ -826,8 +835,8 @@ public AColGroup sliceRows(int rl, int ru) { OffsetSliceInfo off = _indexes.slice(rl, ru); if(off.lIndex == -1) return null; - if(CompressedMatrixBlock.debug){ - if(off.offsetSlice.getOffsetToFirst() < 0 || off.offsetSlice.getOffsetToLast() > ru-rl) + if(CompressedMatrixBlock.debug) { + if(off.offsetSlice.getOffsetToFirst() < 0 || off.offsetSlice.getOffsetToLast() > ru - rl) throw new DMLCompressionException("Failed to slice : " + rl + " " + ru + " in: " + this); } return create(_colIndexes, ru - rl, _dict, off.offsetSlice, null); @@ -853,12 +862,12 @@ public AColGroup appendNInternal(AColGroup[] g, int blen, int rlen) { return null; } - if(!(gs instanceof AOffsetsGroup )) { + if(!(gs instanceof AOffsetsGroup)) { LOG.warn("Not SDCFOR but " + gs.getClass().getSimpleName()); return null; } - if( gs instanceof ColGroupSDCSingleZeros){ + if(gs instanceof ColGroupSDCSingleZeros) { final ColGroupSDCSingleZeros gc = (ColGroupSDCSingleZeros) gs; if(!gc._dict.equals(_dict)) { LOG.warn("Not same Dictionaries therefore not appending \n" + _dict + "\n\n" + gc._dict); @@ -885,6 +894,16 @@ public int getNumberOffsets() { return getCounts()[0]; } + @Override + protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + throw new NotImplementedException(); + } + + @Override + protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + throw new NotImplementedException(); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java index 38d88ef106b..c1e081f2533 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java @@ -24,11 +24,13 @@ 
import java.io.IOException; import java.util.Arrays; +import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.DMLCompressionException; -import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; @@ -489,7 +491,15 @@ public void preAggregateDense(MatrixBlock m, double[] preAgg, int rl, int ru, in } @Override - public void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru) { + public void leftMMIdentityPreAggregateDense(MatrixBlock that, MatrixBlock ret, int rl, int ru, int cl, int cu) { + throw new NotImplementedException(); + } + + @Override + public void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru, int cl, int cu) { + if(cl != 0 || cu < _indexes.getOffsetToLast()) { + throw new NotImplementedException(); + } _data.preAggregateSparse(sb, preAgg, rl, ru, _indexes); } @@ -767,7 +777,7 @@ public AColGroup append(AColGroup g) { @Override public AColGroup appendNInternal(AColGroup[] g, int blen, int rlen) { - + for(int i = 1; i < g.length; i++) { final AColGroup gs = g[i]; if(!_colIndexes.equals(gs._colIndexes)) { @@ -775,12 +785,12 @@ public AColGroup appendNInternal(AColGroup[] g, int blen, int rlen) { return null; } - if(!(gs instanceof AOffsetsGroup )) { + if(!(gs instanceof AOffsetsGroup)) { LOG.warn("Not valid OffsetGroup but " + gs.getClass().getSimpleName()); return null; } - if( gs instanceof 
ColGroupSDCZeros){ + if(gs instanceof ColGroupSDCZeros) { final ColGroupSDCZeros gc = (ColGroupSDCZeros) gs; if(!gc._dict.equals(_dict)) { LOG.warn("Not same Dictionaries therefore not appending \n" + _dict + "\n\n" + gc._dict); @@ -815,6 +825,46 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) { getCachedCounts()); } + @Override + public void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + final SparseBlock sr = ret.getSparseBlock(); + final int nCol = _colIndexes.size(); + final AIterator it = _indexes.getIterator(); + final int last = _indexes.getOffsetToLast(); + int c = 0; + int of = it.value(); + + while(of < last && c < points.length) { + if(points[c].o == of) { + c = processRow(points, sr, nCol, c, of, _data.getIndex(it.getDataIndex())); + of = it.next(); + } + else if(points[c].o < of) + c++; + else + of = it.next(); + } + // increment the c pointer until it is pointing at least to last point or is done. + while(c < points.length && points[c].o < last) + c++; + + c = processRow(points, sr, nCol, c, of, _data.getIndex(it.getDataIndex())); + + } + + @Override + protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + throw new NotImplementedException(); + } + + private int processRow(P[] points, final SparseBlock sr, final int nCol, int c, int of, final int did) { + while(c < points.length && points[c].o == of) { + _dict.put(sr, did, points[c].r, nCol, _colIndexes); + c++; + } + return c; + } + public String toString() { StringBuilder sb = new StringBuilder(); sb.append(super.toString()); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java index f7eec6bff46..df9ccee552d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java +++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java @@ -32,6 +32,7 @@ import org.apache.sysds.runtime.compress.CompressionSettings; import org.apache.sysds.runtime.compress.CompressionSettingsBuilder; import org.apache.sysds.runtime.compress.DMLCompressionException; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictLibMatrixMult; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; @@ -907,6 +908,113 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) { return create(newColIndex, ret, false); } + @Override + public void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + if(_data.isInSparseFormat()) + sparseSelectionSparseColumnGroup(selection, ret, rl, ru); + else + sparseSelectionDenseColumnGroup(selection, ret, rl, ru); + } + + @Override + protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + if(_data.isInSparseFormat()) + denseSelectionSparseColumnGroup(selection, ret, rl, ru); + else + denseSelectionDenseColumnGroup(selection, ret, rl, ru); + } + + private void sparseSelectionSparseColumnGroup(MatrixBlock selection, MatrixBlock ret, int rl, int ru) { + + final SparseBlock sb = selection.getSparseBlock(); + final SparseBlock retB = ret.getSparseBlock(); + final SparseBlock tb = _data.getSparseBlock(); + for(int r = rl; r < ru; r++) { + if(sb.isEmpty(r)) + continue; + + final int sPos = sb.pos(r); + final int rowCompressed = sb.indexes(r)[sPos]; + if(tb.isEmpty(rowCompressed)) + continue; + final int tPos = tb.pos(rowCompressed); + final int tEnd = tb.size(rowCompressed) + tPos; + final int[] tIx = tb.indexes(rowCompressed); + final double[] tVal = tb.values(rowCompressed); + for(int j = tPos; j < tEnd; j++) + retB.append(r, 
_colIndexes.get(tIx[j]), tVal[j]); + } + + } + + private void sparseSelectionDenseColumnGroup(MatrixBlock selection, MatrixBlock ret, int rl, int ru) { + final SparseBlock sb = selection.getSparseBlock(); + final SparseBlock retB = ret.getSparseBlock(); + final DenseBlock tb = _data.getDenseBlock(); + final int nCol = _colIndexes.size(); + for(int r = rl; r < ru; r++) { + if(sb.isEmpty(r)) + continue; + + final int sPos = sb.pos(r); + final int rowCompressed = sb.indexes(r)[sPos]; + + double[] tVal = tb.values(rowCompressed); + int tPos = tb.pos(rowCompressed); + for(int j = 0; j < nCol; j++) + retB.append(r, _colIndexes.get(j), tVal[tPos + j]); + } + } + + private void denseSelectionSparseColumnGroup(MatrixBlock selection, MatrixBlock ret, int rl, int ru) { + + final SparseBlock sb = selection.getSparseBlock(); + final DenseBlock retB = ret.getDenseBlock(); + final SparseBlock tb = _data.getSparseBlock(); + for(int r = rl; r < ru; r++) { + if(sb.isEmpty(r)) + continue; + + final int sPos = sb.pos(r); + final int rowCompressed = sb.indexes(r)[sPos]; + if(tb.isEmpty(rowCompressed)) + continue; + final int tPos = tb.pos(rowCompressed); + final int tEnd = tb.size(rowCompressed) + tPos; + final int[] tIx = tb.indexes(rowCompressed); + final double[] tVal = tb.values(rowCompressed); + + final double[] rVal = retB.values(r); + final int pos = retB.pos(r); + for(int j = tPos; j < tEnd; j++) + rVal[pos + _colIndexes.get(tIx[j])] += tVal[j]; + } + + } + + private void denseSelectionDenseColumnGroup(MatrixBlock selection, MatrixBlock ret, int rl, int ru) { + final SparseBlock sb = selection.getSparseBlock(); + final DenseBlock retB = ret.getDenseBlock(); + final DenseBlock tb = _data.getDenseBlock(); + final int nCol = _colIndexes.size(); + for(int r = rl; r < ru; r++) { + if(sb.isEmpty(r)) + continue; + + final int sPos = sb.pos(r); + final int rowCompressed = sb.indexes(r)[sPos]; + + double[] tVal = tb.values(rowCompressed); + int tPos = tb.pos(rowCompressed); + + final 
double[] rVal = retB.values(r); + final int pos = retB.pos(r); + + for(int j = 0; j < nCol; j++) + rVal[pos + _colIndexes.get(j)] += tVal[tPos + j]; + } + } + @Override public AColGroup reduceCols() { MatrixBlock mb = _data.rowSum(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUtils.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUtils.java index c67a40b34c1..5f0312c21f6 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUtils.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUtils.java @@ -19,6 +19,8 @@ package org.apache.sysds.runtime.compress.colgroup; +import java.util.Arrays; + import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.utils.DoubleCountHashMap; import org.apache.sysds.runtime.data.SparseBlock; @@ -312,4 +314,56 @@ public static double[] reorderDefault(double[] vals, int[] reordering){ return ret; } + + /** + * Get a list of points locations from the SparseBlock. + * + * This is used to find 1 indexes in a sparse selection matrix. + * + * We assume the input only have one non zero per row, and that non zero is a 1. + * + * @param sb Sparse block to extract points from + * @param rl row to start from + * @param ru row to end at + * @return The coordinates that contain values. + */ + public static P[] getSortedSelection(SparseBlock sb, int rl, int ru) { + + int c = 0; + // count loop + for(int i = rl; i < ru; i++) { + if(!sb.isEmpty(i)) + c++; + } + + P[] points = new P[c]; + c = 0; // count from start again + for(int i = rl; i < ru; i++) { + if(sb.isEmpty(i)) + continue; + final int sPos = sb.pos(i); + points[c++] = new P(i, sb.indexes(i)[sPos]); + } + + Arrays.sort(points, (a, b) -> { + return a.o < b.o ? -1 : a.o == b.o ? 
0 : 1; + }); + return points; + } + + public static class P { + public final int r; + public final int o; + + private P(int r, int o) { + this.r = r; + this.o = o; + } + + @Override + public String toString() { + return "P(" + r + "," + o + ")"; + } + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java index 67f546c6ac5..d41e2675f57 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java @@ -22,6 +22,7 @@ import java.io.Serializable; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.ValueFunction; import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; @@ -84,4 +85,10 @@ public static void correctNan(double[] res, IColIndex colIndexes) { res[cix] = Double.isNaN(res[cix]) ? 
0 : res[cix]; } } + + @Override + public void put(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) { + for(int i = 0; i < nCol; i++) + sb.append(rowOut, columns.get(i), getValue(idx, i, nCol)); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java index 9aba711a30e..aad60def17a 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java @@ -305,7 +305,7 @@ protected static void MMDictsDenseSparse(double[] left, SparseBlock right, IColI } } - protected static void MMDictsScalingDenseSparse(double[] left, SparseBlock right, IColIndex rowsLeft, IColIndex colsRight, + protected static void MMDictsScalingDenseSparse(double[] left, SparseBlock right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, int[] scaling) { final double[] resV = result.getDenseBlockValues(); final int leftSize = rowsLeft.size(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java index 9d274e0ddb4..6175daea1e0 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java @@ -101,7 +101,7 @@ public long getInMemorySize() { return getInMemorySize(size()); } - protected static long getInMemorySize(int valuesCount) { + public static long getInMemorySize(int valuesCount) { // object + values array return 16 + (long) MemoryEstimates.doubleArrayCost(valuesCount); } @@ -1131,6 +1131,9 @@ else if(o instanceof MatrixBlockDictionary) { final double[] dv = mb.getDenseBlockValues(); return Arrays.equals(_values, dv); } + else if(o instanceof 
IdentityDictionary) { + return o.equals(this); + } return false; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java index 3456cbd5934..f7c21b63688 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java @@ -26,7 +26,6 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.bitmap.ABitmap; import org.apache.sysds.runtime.compress.bitmap.Bitmap; import org.apache.sysds.runtime.compress.bitmap.MultiColBitmap; @@ -52,18 +51,21 @@ public enum Type { } public static IDictionary read(DataInput in) throws IOException { - Type type = Type.values()[in.readByte()]; + final Type type = Type.values()[in.readByte()]; switch(type) { case FP64_DICT: return Dictionary.read(in); - case MATRIX_BLOCK_DICT: - return MatrixBlockDictionary.read(in); case INT8_DICT: return QDictionary.read(in); case PLACE_HOLDER: return PlaceHolderDict.read(in); + case IDENTITY: + return IdentityDictionary.read(in); + case IDENTITY_SLICE: + return IdentityDictionarySlice.read(in); + case MATRIX_BLOCK_DICT: default: - throw new DMLCompressionException("Unsupported type of dictionary : " + type); + return MatrixBlockDictionary.read(in); } } @@ -78,52 +80,45 @@ else if(nrColumns > 1 && tupleSparsity < 0.4) } public static IDictionary create(DblArrayCountHashMap map, int nCols, boolean addZeroTuple, double sparsity) { - try { - final ACount[] vals = map.extractValues(); - final int nVals = vals.length; - final int nTuplesOut = nVals + (addZeroTuple ? 
1 : 0); - if(sparsity < 0.4) { - final MatrixBlock retB = new MatrixBlock(nTuplesOut, nCols, true); - retB.allocateSparseRowsBlock(); - final SparseBlock sb = retB.getSparseBlock(); - for(int i = 0; i < nVals; i++) { - final ACount dac = vals[i]; - final double[] dv = dac.key().getData(); - for(int k = 0; k < dv.length; k++) - sb.append(dac.id, k, dv[k]); - } - retB.recomputeNonZeros(); - retB.examSparsity(true); - return MatrixBlockDictionary.create(retB); - } - else { - final double[] resValues = new double[(nTuplesOut) * nCols]; - for(int i = 0; i < nVals; i++) { - final ACount dac = vals[i]; - System.arraycopy(dac.key().getData(), 0, resValues, dac.id * nCols, nCols); - } - return Dictionary.create(resValues); + final ACount[] vals = map.extractValues(); + final int nVals = vals.length; + final int nTuplesOut = nVals + (addZeroTuple ? 1 : 0); + if(sparsity < 0.4) { + final MatrixBlock retB = new MatrixBlock(nTuplesOut, nCols, true); + retB.allocateSparseRowsBlock(); + final SparseBlock sb = retB.getSparseBlock(); + for(int i = 0; i < nVals; i++) { + final ACount dac = vals[i]; + final double[] dv = dac.key().getData(); + for(int k = 0; k < dv.length; k++) + sb.append(dac.id, k, dv[k]); } + retB.recomputeNonZeros(); + retB.examSparsity(true); + return MatrixBlockDictionary.create(retB); } - catch(Exception e) { - throw new RuntimeException("Failed to create dictionary: " + map + " " + nCols, e); + else { + + final double[] resValues = new double[(nTuplesOut) * nCols]; + for(int i = 0; i < nVals; i++) { + final ACount dac = vals[i]; + System.arraycopy(dac.key().getData(), 0, resValues, dac.id * nCols, nCols); + } + return Dictionary.create(resValues); } + } public static IDictionary create(ABitmap ubm) { return create(ubm, 1.0); } - public static IDictionary create(ABitmap ubm, double sparsity, boolean withZeroTuple) { - return (withZeroTuple) ? 
createWithAppendedZeroTuple(ubm, sparsity) : create(ubm, sparsity); - } - public static IDictionary create(ABitmap ubm, double sparsity) { final int nCol = ubm.getNumColumns(); if(ubm instanceof Bitmap) return Dictionary.create(((Bitmap) ubm).getValues()); - else if(sparsity < 0.4 && nCol > 4 && ubm instanceof MultiColBitmap) { + else if(sparsity < 0.4 && nCol > 4) { // && ubm instanceof MultiColBitmap final MultiColBitmap mcbm = (MultiColBitmap) ubm; final MatrixBlock m = new MatrixBlock(ubm.getNumValues(), nCol, true); @@ -140,7 +135,7 @@ else if(sparsity < 0.4 && nCol > 4 && ubm instanceof MultiColBitmap) { m.examSparsity(true); return MatrixBlockDictionary.create(m); } - else if(ubm instanceof MultiColBitmap) { + else {// if(ubm instanceof MultiColBitmap) { MultiColBitmap mcbm = (MultiColBitmap) ubm; final int nVals = ubm.getNumValues(); double[] resValues = new double[nVals * nCol]; @@ -149,8 +144,6 @@ else if(ubm instanceof MultiColBitmap) { return Dictionary.create(resValues); } - throw new NotImplementedException( - "Not implemented creation of bitmap type : " + ubm.getClass().getSimpleName()); } public static IDictionary create(ABitmap ubm, int defaultIndex, double[] defaultTuple, double sparsity, @@ -185,7 +178,7 @@ public static IDictionary create(ABitmap ubm, int defaultIndex, double[] default defaultTuple[0] = bmv[defaultIndex]; System.arraycopy(bmv, defaultIndex + 1, dict, defaultIndex, bmv.length - defaultIndex - 1); } - else if(ubm instanceof MultiColBitmap) { + else { // if(ubm instanceof MultiColBitmap) { final MultiColBitmap mcbm = (MultiColBitmap) ubm; for(int i = 0; i < defaultIndex; i++) System.arraycopy(mcbm.getValues(i), 0, dict, i * nCol, nCol); @@ -193,47 +186,45 @@ else if(ubm instanceof MultiColBitmap) { for(int i = defaultIndex; i < ubm.getNumValues() - 1; i++) System.arraycopy(mcbm.getValues(i + 1), 0, dict, i * nCol, nCol); } - else - throw new NotImplementedException("not supported ABitmap of type:" + ubm.getClass().getSimpleName()); 
return Dictionary.create(dict); } } - public static IDictionary createWithAppendedZeroTuple(ABitmap ubm, double sparsity) { - final int nVals = ubm.getNumValues(); - final int nRows = nVals + 1; - final int nCols = ubm.getNumColumns(); - - if(ubm instanceof Bitmap) { - final double[] resValues = new double[nRows]; - final double[] from = ((Bitmap) ubm).getValues(); - System.arraycopy(from, 0, resValues, 0, from.length); - return Dictionary.create(resValues); - } - - final MultiColBitmap mcbm = (MultiColBitmap) ubm; - if(sparsity < 0.4 && nCols > 4) { - final MatrixBlock m = new MatrixBlock(nRows, nCols, true); - m.allocateSparseRowsBlock(); - final SparseBlock sb = m.getSparseBlock(); - - for(int i = 0; i < nVals; i++) { - final double[] tuple = mcbm.getValues(i); - for(int col = 0; col < nCols; col++) - sb.append(i, col, tuple[col]); - } - m.recomputeNonZeros(); - m.examSparsity(true); - return MatrixBlockDictionary.create(m); - } - - final double[] resValues = new double[nRows * nCols]; - for(int i = 0; i < nVals; i++) - System.arraycopy(mcbm.getValues(i), 0, resValues, i * nCols, nCols); - - return Dictionary.create(resValues); - } + // public static IDictionary createWithAppendedZeroTuple(ABitmap ubm, double sparsity) { + // final int nVals = ubm.getNumValues(); + // final int nRows = nVals + 1; + // final int nCols = ubm.getNumColumns(); + + // if(ubm instanceof Bitmap) { + // final double[] resValues = new double[nRows]; + // final double[] from = ((Bitmap) ubm).getValues(); + // System.arraycopy(from, 0, resValues, 0, from.length); + // return Dictionary.create(resValues); + // } + + // final MultiColBitmap mcbm = (MultiColBitmap) ubm; + // if(sparsity < 0.4 && nCols > 4) { + // final MatrixBlock m = new MatrixBlock(nRows, nCols, true); + // m.allocateSparseRowsBlock(); + // final SparseBlock sb = m.getSparseBlock(); + + // for(int i = 0; i < nVals; i++) { + // final double[] tuple = mcbm.getValues(i); + // for(int col = 0; col < nCols; col++) + // 
sb.append(i, col, tuple[col]); + // } + // m.recomputeNonZeros(); + // m.examSparsity(true); + // return MatrixBlockDictionary.create(m); + // } + + // final double[] resValues = new double[nRows * nCols]; + // for(int i = 0; i < nVals; i++) + // System.arraycopy(mcbm.getValues(i), 0, resValues, i * nCols, nCols); + + // return Dictionary.create(resValues); + // } public static IDictionary create(DoubleCountHashMap map) { final double[] resValues = map.getDictionary(); @@ -348,8 +339,8 @@ public static IDictionary combineFullDictionaries(IDictionary a, int nca, IDicti * @param nca Number of columns left dictionary * @param b Right side dictionary * @param ncb Number of columns right dictionary - * @param filter The mapping filter to not include all possible combinations in the output, this filter is allowed - * to be null, that means the output is defaulting back to a full combine + * @param filter The mapping filter to not include all possible combinations in the output, this filter is allowed to + * be null, that means the output is defaulting back to a full combine * @return A combined dictionary */ public static IDictionary combineFullDictionaries(IDictionary a, int nca, IDictionary b, int ncb, @@ -357,46 +348,59 @@ public static IDictionary combineFullDictionaries(IDictionary a, int nca, IDicti final int ra = a.getNumberOfValues(nca); final int rb = b.getNumberOfValues(ncb); - MatrixBlock ma = a.getMBDict(nca).getMatrixBlock(); - MatrixBlock mb = b.getMBDict(ncb).getMatrixBlock(); + final MatrixBlock ma = a.getMBDict(nca).getMatrixBlock(); + final MatrixBlock mb = b.getMBDict(ncb).getMatrixBlock(); - if(ra == 1 && rb == 1) - return new MatrixBlockDictionary(ma.append(mb)); + if(ra == 1 && rb == 1) { + + if(filter == null || filter.containsKey(0)) + return new MatrixBlockDictionary(ma.append(mb)); + else + return null; + } MatrixBlock out = new MatrixBlock(filter != null ? 
filter.size() : ra * rb, nca + ncb, false); out.allocateBlock(); - if(filter != null) { - for(int r : filter.keySet()) { - int o = filter.get(r); - int ia = r % ra; - int ib = r / ra; - for(int c = 0; c < nca; c++) - out.set(o, c, ma.get(ia, c)); + if(filter != null) + combineFullWithFilter(nca, ncb, filter, ra, ma, mb, out); + else + combineFullWithoutFilter(nca, ncb, ra, ma, mb, out); - for(int c = 0; c < ncb; c++) - out.set(o, c + nca, mb.get(ib, c)); + return new MatrixBlockDictionary(out); + } + + private static void combineFullWithoutFilter(int nca, int ncb, final int ra, MatrixBlock ma, MatrixBlock mb, + MatrixBlock out) { + for(int r = 0; r < out.getNumRows(); r++) { + int ia = r % ra; + int ib = r / ra; + for(int c = 0; c < nca; c++) + out.set(r, c, ma.get(ia, c)); + + for(int c = 0; c < ncb; c++) + out.set(r, c + nca, mb.get(ib, c)); - } } - else { + } - for(int r = 0; r < out.getNumRows(); r++) { - int ia = r % ra; - int ib = r / ra; - for(int c = 0; c < nca; c++) - out.set(r, c, ma.get(ia, c)); + private static void combineFullWithFilter(int nca, int ncb, Map filter, final int ra, + MatrixBlock ma, MatrixBlock mb, MatrixBlock out) { + for(int r : filter.keySet()) { + int o = filter.get(r); + int ia = r % ra; + int ib = r / ra; + for(int c = 0; c < nca; c++) + out.set(o, c, ma.get(ia, c)); - for(int c = 0; c < ncb; c++) - out.set(r, c + nca, mb.get(ib, c)); + for(int c = 0; c < ncb; c++) + out.set(o, c + nca, mb.get(ib, c)); - } } - return new MatrixBlockDictionary(out); } - public static IDictionary combineSDCRight(IDictionary a, int nca, IDictionary b, double[] tub) { + private static IDictionary combineSDCRight(IDictionary a, int nca, IDictionary b, double[] tub) { final int ncb = tub.length; final int ra = a.getNumberOfValues(nca); @@ -583,7 +587,7 @@ public static IDictionary combineSDC(IDictionary a, double[] tua, IDictionary b, } - public static IDictionary combineSparseConstSparseRet(IDictionary a, int nca, double[] tub) { + private static 
IDictionary combineSparseConstSparseRet(IDictionary a, int nca, double[] tub) { final int ncb = tub.length; final int ra = a.getNumberOfValues(nca); @@ -632,7 +636,7 @@ private static IDictionary combineSparseConstSparseRet(IDictionary a, int nca, d } - public static IDictionary combineConstSparseSparseRet(double[] tua, IDictionary b, int ncb) { + private static IDictionary combineConstSparseSparseRet(double[] tua, IDictionary b, int ncb) { final int nca = tua.length; final int rb = b.getNumberOfValues(ncb); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java index 1047692f509..a7a74775be7 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java @@ -953,12 +953,7 @@ public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLef */ public IDictionary cbind(IDictionary that, int nCol); - /** - * Indicate if this object is equal to another this takes into part sematic equivalence - * - * @param o The other object - * @return If they are equal - */ + @Override public boolean equals(Object o); /** @@ -985,4 +980,15 @@ public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLef */ public IDictionary reorder(int[] reorder); + /** + * Put the row specified into the sparse block, via append calls. + * + * @param sb The sparse block to put into + * @param idx The dictionary index to put in. + * @param rowOut The row in the sparse block to put it into + * @param nCol The number of columns in the dictionary + * @param columns The columns to output into. 
+ */ + public void put(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns); + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java index 74f5e5b0991..b2642dd0a53 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java @@ -85,13 +85,15 @@ public IdentityDictionary(int nRowCol, boolean withEmpty) { @Override public double[] getValues() { + if(nRowCol < 3) { + // lets live with it if we call it on 3 columns. + double[] ret = new double[nRowCol * nRowCol + (withEmpty ? nRowCol : 0)]; + for(int i = 0; i < nRowCol; i++) { + ret[(i * nRowCol) + i] = 1; + } + return ret; + } throw new DMLCompressionException("Invalid to materialize identity Matrix Please Implement alternative"); - // LOG.warn("Should not call getValues on Identity Dictionary"); - // double[] ret = new double[nRowCol * nRowCol]; - // for(int i = 0; i < nRowCol; i++) { - // ret[(i * nRowCol) + i] = 1; - // } - // return ret; } @Override @@ -104,6 +106,10 @@ public double getValue(int i) { return row == col ? 1 : 0; } + public boolean withEmpty() { + return withEmpty; + } + @Override public double getValue(int r, int c, int nCol) { return r == c ? 
1 : 0; @@ -218,21 +224,13 @@ public IDictionary binOpRight(BinaryOperator op, double[] v, IColIndex colIndexe boolean same = false; if(op.fn instanceof Plus || op.fn instanceof Minus) { same = true; - for(int i = 0; i < colIndexes.size(); i++) { - if(v[colIndexes.get(i)] != 0.0) { - same = false; - break; - } - } + for(int i = 0; i < colIndexes.size() && same; i++) + same = v[colIndexes.get(i)] == 0.0; } if(op.fn instanceof Divide) { same = true; - for(int i = 0; i < colIndexes.size(); i++) { - if(v[colIndexes.get(i)] != 1.0) { - same = false; - break; - } - } + for(int i = 0; i < colIndexes.size() && same; i++) + same = v[colIndexes.get(i)] == 1.0; } if(same) return this; @@ -341,11 +339,8 @@ public double[] productAllRowsToDoubleWithReference(double[] reference) { @Override public void colSum(double[] c, int[] counts, IColIndex colIndexes) { - for(int i = 0; i < colIndexes.size(); i++) { - // very nice... - final int idx = colIndexes.get(i); - c[idx] = counts[i]; - } + for(int i = 0; i < colIndexes.size(); i++) + c[colIndexes.get(i)] += counts[i]; } @Override @@ -421,19 +416,16 @@ public long getNumberNonZerosWithReference(int[] counts, double[] reference, int } @Override - public void addToEntry(final double[] v, final int fr, final int to, final int nCol) { - getMBDict().addToEntry(v, fr, to, nCol); + public final void addToEntry(final double[] v, final int fr, final int to, final int nCol) { + addToEntry(v, fr, to, nCol, 1); } @Override public void addToEntry(final double[] v, final int fr, final int to, final int nCol, int rep) { - if(withEmpty) { - if(fr < nRowCol) - v[to * nCol + fr] += rep; - } - else { + if(!withEmpty) + v[to * nCol + fr] += rep; + else if(fr < nRowCol) v[to * nCol + fr] += rep; - } } @Override @@ -499,6 +491,7 @@ public MatrixBlockDictionary getMBDict(int nCol) { } private MatrixBlockDictionary createMBDict() { + if(withEmpty) { final SparseBlock sb = SparseBlockFactory.createIdentityMatrixWithEmptyRow(nRowCol); final MatrixBlock 
identity = new MatrixBlock(nRowCol + 1, nRowCol, nRowCol, sb); @@ -657,17 +650,15 @@ public void MMDictScaling(IDictionary right, IColIndex rowsLeft, IColIndex colsR @Override public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { - // getMBDict().MMDictDense(left, rowsLeft, colsRight, result); - // should replace with add to right to output cells. + // similar to fused transpose left into right locations. final int leftSide = rowsLeft.size(); - final int resCols = result.getNumColumns(); final int commonDim = Math.min(left.length / leftSide, nRowCol); final double[] resV = result.getDenseBlockValues(); - for(int i = 0; i < leftSide; i++) {// rows in left side - final int offOut = rowsLeft.get(i) * resCols; - final int leftOff = i * leftSide; + for(int i = 0; i < leftSide; i++) { // rows in left side + final int offOut = rowsLeft.get(i) * commonDim; + final int leftOff = i; for(int j = 0; j < commonDim; j++) { // cols in left side skipping empty from identity - resV[offOut + colsRight.get(j)] += left[leftOff + j]; + resV[offOut + colsRight.get(j)] += left[leftOff + j * leftSide]; } } } @@ -736,20 +727,11 @@ public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLef @Override public boolean equals(IDictionary o) { - if(o instanceof IdentityDictionary) - return ((IdentityDictionary) o).nRowCol == nRowCol; - - MatrixBlock mb = getMBDict().getMatrixBlock(); - if(o instanceof MatrixBlockDictionary) - return mb.equals(((MatrixBlockDictionary) o).getMatrixBlock()); - else if(o instanceof Dictionary) { - if(mb.isInSparseFormat()) - return mb.getSparseBlock().equals(((Dictionary) o)._values, nRowCol); - final double[] dv = mb.getDenseBlockValues(); - return Arrays.equals(dv, ((Dictionary) o)._values); - } - - return false; + if(o instanceof IdentityDictionary && // + ((IdentityDictionary) o).nRowCol == nRowCol && // + ((IdentityDictionary) o).withEmpty == withEmpty) + return true; + return 
getMBDict().equals(o); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java index 343267f5afb..f6f2b86d284 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java @@ -27,15 +27,20 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.BinaryOperator; public class IdentityDictionarySlice extends IdentityDictionary { private static final long serialVersionUID = 2535887782150955098L; + /** Lower index for the slice */ private final int l; + /** Upper index for the slice (not inclusive) */ private final int u; /** @@ -59,9 +64,9 @@ public IdentityDictionarySlice(int nRowCol, boolean withEmpty, int l, int u) { public double[] getValues() { LOG.warn("Should not call getValues on Identity Dictionary"); int nCol = u - l; - double[] ret = new double[nCol * nRowCol]; + double[] ret = new double[nCol * (nRowCol + (withEmpty ? 1 : 0))]; for(int i = l; i < u; i++) { - ret[(i * nCol) + i] = 1; + ret[(i * nCol) + (i - l)] = 1; } return ret; } @@ -193,7 +198,14 @@ public double sumSq(int[] counts, int ncol) { @Override public IDictionary sliceOutColumnRange(int idxStart, int idxEnd, int previousNumberOfColumns) { - throw new NotImplementedException("Slice of identity slice ??? 
this is getting a bit ridiculous"); + return getMBDict().sliceOutColumnRange(idxStart, idxEnd, previousNumberOfColumns); + } + + @Override + public int getNumberOfValues(int ncol) { + if(ncol != u - l) + throw new DMLCompressionException("Invalid call to get Number of values assuming wrong number of columns"); + return nRowCol + (withEmpty ? 1 : 0); } @Override @@ -244,17 +256,24 @@ public IDictionary scaleTuples(int[] scaling, int nCol) { public void write(DataOutput out) throws IOException { out.writeByte(DictionaryFactory.Type.IDENTITY_SLICE.ordinal()); out.writeInt(nRowCol); + out.writeBoolean(withEmpty); out.writeInt(l); out.writeInt(u); } - public static IdentityDictionary read(DataInput in) throws IOException { - return new IdentityDictionary(in.readInt()); + public static IdentityDictionarySlice read(DataInput in) throws IOException { + + int nRowCol = in.readInt(); + boolean empty = in.readBoolean(); + int l = in.readInt(); + int u = in.readInt(); + + return new IdentityDictionarySlice(nRowCol, empty, l, u); } @Override public long getExactSizeOnDisk() { - return 1 + 4 * 3; + return 1 + 4 * 3 + 1; } @Override @@ -275,7 +294,12 @@ public IDictionary replaceWithReference(double pattern, double replace, double[] @Override public double getSparsity() { - return 1d / nRowCol; + return (double) (u - l) / ((u -l) * (nRowCol + (withEmpty ? 
1 : 0))); + } + + @Override + public IDictionary binOpRight(BinaryOperator op, double[] v, IColIndex colIndexes) { + return getMBDict().binOpRight(op, v); } @Override @@ -287,12 +311,13 @@ public IDictionary preaggValuesFromDense(final int numVals, final IColIndex colI @Override public void addToEntryVectorized(double[] v, int f1, int f2, int f3, int f4, int f5, int f6, int f7, int f8, int t1, int t2, int t3, int t4, int t5, int t6, int t7, int t8, int nCol) { - throw new NotImplementedException(); + getMBDict().addToEntryVectorized(v, f1, f2, f3, f4, f5, f6, f7, f8, t1, t2, t3, t4, t5, t6, t7, t8, nCol); } @Override public void addToEntry(final double[] v, final int fr, final int to, final int nCol, int rep) { - throw new NotImplementedException(); + if(fr >= l && fr < u) + v[to * nCol + fr - l] += rep; } @Override @@ -303,17 +328,23 @@ public boolean equals(IDictionary o) { } else if(o instanceof IdentityDictionary) return false; - MatrixBlock mb = getMBDict().getMatrixBlock(); - if(o instanceof MatrixBlockDictionary) - return mb.equals(((MatrixBlockDictionary) o).getMatrixBlock()); - else if(o instanceof Dictionary) { - if(mb.isInSparseFormat()) - return mb.getSparseBlock().equals(((Dictionary) o)._values, nRowCol); - final double[] dv = mb.getDenseBlockValues(); - return Arrays.equals(dv, ((Dictionary) o)._values); + else + return getMBDict().equals(o); + } + + @Override + public MatrixBlockDictionary getMBDict() { + final int nCol = u - l; + MatrixBlock mb = new MatrixBlock(nRowCol + (withEmpty ? 
1 : 0), nCol, true); + mb.allocateSparseRowsBlock(); + + SparseBlock sb = mb.getSparseBlock(); + for(int i = l; i < u; i++) { + sb.append(i, i - l, 1); } - return false; + mb.setNonZeros(nCol); + return new MatrixBlockDictionary(mb); } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java index 5bc5f0605ef..a3e4af9d8fa 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java @@ -2166,6 +2166,9 @@ else if(o instanceof Dictionary) { final double[] dv = _data.getDenseBlockValues(); return Arrays.equals(dv, ((Dictionary) o)._values); } + else if(o instanceof IdentityDictionary) { + return o.equals(this); + } return false; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java index 51c41ffeec6..88a7be26194 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java @@ -524,5 +524,9 @@ public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, IColIndex int[] scaling) { throw new RuntimeException(errMessage); } - + + @Override + public void put(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) { + throw new RuntimeException(errMessage); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java index 55cede1ae1f..e77db7cad7b 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java +++ 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java @@ -169,6 +169,11 @@ private static MatrixBlock decompressExecute(CompressedMatrixBlock cmb, int k) { LOG.warn("Decompressing into dense but reallocating after to sparse: overlapping - " + overlapping + ", filter - " + shouldFilter); } + else{ + MatrixBlock tmp = new MatrixBlock(); + tmp.copy( ret); + ret = tmp; + } final int blklen = Math.max(nRows / k, 512); diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java index 6c24e6ad97e..8812c6d4cc5 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java @@ -20,8 +20,7 @@ package org.apache.sysds.runtime.compress.lib; import java.util.ArrayList; -import java.util.Collections; -import java.util.Comparator; +import java.util.Arrays; import java.util.List; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; @@ -30,11 +29,11 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.APreAgg; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Plus; @@ -44,10 +43,14 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.util.CommonThreadPool; +import org.apache.sysds.utils.stats.Timing; public final class 
CLALibLeftMultBy { private static final Log LOG = LogFactory.getLog(CLALibLeftMultBy.class.getName()); + // /** Reusable cache intermediate double array for temporary lmm */ + // private static ThreadLocal> cacheIntermediate = null; + private CLALibLeftMultBy() { // private constructor } @@ -91,11 +94,16 @@ public static MatrixBlock leftMultByMatrixTransposed(CompressedMatrixBlock right */ public static MatrixBlock leftMultByMatrixTransposed(CompressedMatrixBlock right, CompressedMatrixBlock left, MatrixBlock ret, int k) { - if(left.isEmpty() || right.isEmpty()) - return prepareEmptyReturnMatrix(right, left, ret, true); - ret = prepareReturnMatrix(right, left, ret, true); - leftMultByCompressedTransposedMatrix(right, left, ret, k); - return ret; + try { + if(left.isEmpty() || right.isEmpty()) + return prepareEmptyReturnMatrix(right, left, ret, true); + ret = prepareReturnMatrix(right, left, ret, true); + leftMultByCompressedTransposedMatrix(right, left, ret, k); + return ret; + } + catch(Exception e) { + throw new DMLCompressionException("Failed CLA Compressed Transposed LMM", e); + } } /** @@ -111,11 +119,24 @@ public static MatrixBlock leftMultByMatrixTransposed(CompressedMatrixBlock right * @return The result of the matrix multiplication */ public static MatrixBlock leftMultByMatrix(CompressedMatrixBlock right, MatrixBlock left, MatrixBlock ret, int k) { - if(left.isEmpty() || right.isEmpty()) - return prepareEmptyReturnMatrix(right, left, ret, false); - ret = prepareReturnMatrix(right, left, ret, false); - ret = LMM(right.getColGroups(), left, ret, k, right.isOverlapping()); - return ret; + try { + // return LibMatrixMult.matrixMult(left, right.getUncompressed(), ret, k); // uncompressed example + + if(left.isEmpty() // + || right.isEmpty()) + return prepareEmptyReturnMatrix(right, left, ret, false); + + if(CLALibSelectionMult.isSelectionMatrix(left)) + return CLALibSelectionMult.leftSelection(right, left, ret, k); + + ret = prepareReturnMatrix(right, left, 
ret, false); + ret = LMM(right.getColGroups(), left, ret, k, right.isOverlapping()); + + return ret; + } + catch(Exception e) { + throw new DMLCompressionException("Failed CLA LMM", e); + } } private static MatrixBlock prepareEmptyReturnMatrix(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, @@ -142,15 +163,15 @@ else if(!(ret.getNumColumns() == numColumnsOutput && ret.getNumRows() == numRows } private static MatrixBlock leftMultByCompressedTransposedMatrix(CompressedMatrixBlock right, - CompressedMatrixBlock left, final MatrixBlock ret, int k) { - if(k > 1 && ret.getInMemorySize() < 1000000) + CompressedMatrixBlock left, final MatrixBlock ret, int k) throws Exception { + if(k > 1) return leftMultByCompressedTransposedMatrixParallel(right, left, ret, k); else return leftMultByCompressedTransposedMatrixSingleThread(right, left, ret); } private static MatrixBlock leftMultByCompressedTransposedMatrixParallel(CompressedMatrixBlock right, - CompressedMatrixBlock left, final MatrixBlock ret, int k) { + CompressedMatrixBlock left, final MatrixBlock ret, int k) throws Exception { final int sd = right.getNumRows(); // shared dim final int cr = right.getNumColumns(); @@ -175,7 +196,6 @@ private static MatrixBlock leftMultByCompressedTransposedMatrixParallel(Compress try { final List> t = new ArrayList<>(); - for(int j = 0; j < fLeft.size(); j++) { final int jj = j; t.add(pool.submit(() -> { @@ -189,31 +209,32 @@ private static MatrixBlock leftMultByCompressedTransposedMatrixParallel(Compress })); } - final double[] retV = ret.getDenseBlockValues(); if(containsLeft && containsRight) // if both -- multiply the left and right vectors scaling by number of shared dim - outerProductWithScaling(cL, cR, sd, retV); + outerProductWithScaling(cL, cR, sd, ret); if(containsLeft) // if left -- multiply left with right sum - outerProduct(cL, CLALibUtils.getColSum(fRight, cr, sd), retV); + for(Future f : outerProductParallelTasks(cL, CLALibUtils.getColSum(fRight, cr, sd), ret, pool)) + 
f.get(); + if(containsRight)// if right -- multiply right with left sum - outerProduct(CLALibUtils.getColSum(fLeft, rl, sd), cR, retV); + for(Future f : outerProductParallelTasks(CLALibUtils.getColSum(fLeft, rl, sd), cR, ret, pool)) + f.get(); for(Future f : t) { MatrixBlock mb = f.get(); if(!mb.isEmpty()) { if(mb.isInSparseFormat()) LibMatrixBincell.bincellOpInPlaceRight(ret, mb, new BinaryOperator(Plus.getPlusFnObject())); - else if(mb.getDenseBlock().isContiguous()) + else if(mb.getDenseBlock().isContiguous()) { + final double[] retV = ret.getDenseBlockValues(); LibMatrixMult.vectAdd(mb.getDenseBlockValues(), retV, 0, 0, retV.length); + } else LibMatrixBincell.bincellOpInPlaceRight(ret, mb, new BinaryOperator(Plus.getPlusFnObject())); } } ret.recomputeNonZeros(k); } - catch(Exception e) { - throw new DMLCompressionException("Failed parallel Left Compressed Mult", e); - } finally { pool.shutdown(); } @@ -243,171 +264,380 @@ private static MatrixBlock leftMultByCompressedTransposedMatrixSingleThread(Comp for(int j = 0; j < fLeft.size(); j++) for(int i = 0; i < fRight.size(); i++) fRight.get(i).leftMultByAColGroup(fLeft.get(j), ret, sd); - final double[] retV = ret.getDenseBlockValues(); + if(containsLeft && containsRight) // if both -- multiply the left and right vectors scaling by number of shared dim - outerProductWithScaling(cL, cR, sd, retV); + outerProductWithScaling(cL, cR, sd, ret); if(containsLeft) // if left -- multiply left with right sum - outerProduct(cL, CLALibUtils.getColSum(fRight, cr, sd), retV); + outerProductSingleThread(cL, CLALibUtils.getColSum(fRight, cr, sd), ret); if(containsRight)// if right -- multiply right with left sum - outerProduct(CLALibUtils.getColSum(fLeft, rl, sd), cR, retV); + outerProductSingleThread(CLALibUtils.getColSum(fLeft, rl, sd), cR, ret); + ret.recomputeNonZeros(); return ret; } private static MatrixBlock LMM(List colGroups, MatrixBlock that, MatrixBlock ret, int k, - boolean overlapping) { + boolean overlapping) throws 
Exception { final int numColumnsOut = ret.getNumColumns(); final int lr = that.getNumRows(); final boolean shouldFilter = CLALibUtils.shouldPreFilter(colGroups); final List noPreAggGroups = new ArrayList<>(); final List preAggGroups = new ArrayList<>(); + if(shouldFilter) { - final double[] constV = new double[numColumnsOut]; + // Timing t = new Timing(); + final double[] constV; + // if(CLALibUtils.alreadyPreFiltered(colGroups, ret.getNumColumns())) { + // constV = CLALibUtils.filterGroupsAndSplitPreAggOneConst(colGroups, noPreAggGroups, preAggGroups); + // } + // else { + constV = new double[numColumnsOut]; // millions of columns... CLALibUtils.filterGroupsAndSplitPreAgg(colGroups, constV, noPreAggGroups, preAggGroups); - // Sort so that the big expensive preAgg groups are first. - Collections.sort(preAggGroups, Comparator.comparing(AColGroup::getNumValues).reversed()); + // } - double[] rowSums; + // final double filterGroupsTime = t.stop(); + + // Sort so that the big expensive preAgg groups are first to balance threads + // if(k * 2 < colGroups.size()) + // Collections.sort(preAggGroups, Comparator.comparing(AColGroup::getNumValues).reversed()); + + final double[] rowSums; if(!noPreAggGroups.isEmpty() || !preAggGroups.isEmpty()) { final int sizeSum = preAggGroups.size() + noPreAggGroups.size(); rowSums = new double[lr]; if(k == 1 || sizeSum == 1) - LMMTaskExec(noPreAggGroups, preAggGroups, that, ret, 0, lr, rowSums, k); + LMMTaskExec(noPreAggGroups, preAggGroups, that, ret, 0, lr, rowSums); else LMMParallel(noPreAggGroups, preAggGroups, that, ret, rowSums, overlapping, k); } else rowSums = that.rowSum(k).getDenseBlockValues(); + // final double multTime = t.stop(); + // add the correction layer for the subtracted common values. 
if(rowSums != null) { if(ret.isEmpty()) ret.allocateDenseBlock(); else ret.sparseToDense(); - - outerProduct(rowSums, constV, ret.getDenseBlockValues()); + outerProduct(rowSums, constV, ret, k); } + + // final double outerProd = t.stop(); + // if(LOG.isDebugEnabled()) { + // LOG.debug(String.format("LLM: filter: %10f Mult: %10f outer: %10f", filterGroupsTime, multTime, outerProd)); + // } } else { CLALibUtils.splitPreAgg(colGroups, noPreAggGroups, preAggGroups); // Sort so that the big expensive preAgg groups are first. - Collections.sort(preAggGroups, Comparator.comparing(AColGroup::getNumValues).reversed()); + // Collections.sort(preAggGroups, Comparator.comparing(AColGroup::getNumValues).reversed()); if(k == 1 || colGroups.size() == 1) - LMMTaskExec(noPreAggGroups, preAggGroups, that, ret, 0, lr, null, k); + LMMTaskExec(noPreAggGroups, preAggGroups, that, ret, 0, lr, null); else LMMParallel(noPreAggGroups, preAggGroups, that, ret, null, overlapping, k); + } ret.recomputeNonZeros(k); - ret.examSparsity(); + ret.examSparsity(k); return ret; } private static void LMMParallel(List npa, List pa, MatrixBlock that, MatrixBlock ret, - double[] rowSums, boolean overlapping, int k) { + double[] rowSums, boolean overlapping, int k) throws Exception { final ExecutorService pool = CommonThreadPool.get(k); try { - final ArrayList> tasks = new ArrayList<>(); - - final int rl = that.getNumRows(); - final int rowBlockSize = Math.max(rl / k, 1); final int nG = npa.size() + pa.size(); + final boolean useTmp = (overlapping && nG > 1) // + || (nG * 2 < k && ret.getNumColumns() < 1000); - final boolean useTmp = overlapping && nG > 1; // skip value to parallelize the pa groups without allocating new arrays - final int s = Math.min(pa.size(), k); - if(!useTmp) { - // Put results directly into ret - for(int blo = 0; blo < rl; blo += rowBlockSize) { - final int end = Math.min(blo + rowBlockSize, rl); - - for(AColGroup g : npa) // all groups get their own task - tasks.add(new 
LMMNoPreAggTask(g, that, ret, blo, end)); - - for(int off = 0; off < s; off++) { // only allocate k tasks at max - if(off == s - 1) - tasks.add(new LMMPreAggTask(pa, that, ret, blo, end, off, s, rowSums, 1)); - else - tasks.add(new LMMPreAggTask(pa, that, ret, blo, end, off, s, null, 1)); - } + if(!useTmp) + LMMParallelNoTempOut(npa, pa, that, ret, rowSums, overlapping, k, pool); + else + LMMParallelTempOut(npa, pa, that, ret, rowSums, overlapping, k, pool); + } + finally { + pool.shutdown(); + } + } - if(pa.isEmpty() && rowSums != null) // row sums task - tasks.add(new LMMRowSums(that, blo, end, rowSums)); + private static void LMMParallelNoTempOut(List npa, List pa, MatrixBlock that, MatrixBlock ret, + double[] rowSums, boolean overlapping, int k, ExecutorService pool) throws Exception { - } + final int s = Math.min(pa.size(), k); + final int rt = that.getNumRows(); + final int ct = that.getNumColumns(); + final int rowBlockSize = Math.max(rt / k, 1); + + // skip value to parallelize the pa groups without allocating new arrays + + final ArrayList> tasks = new ArrayList<>(); + // Put results directly into ret + for(int blo = 0; blo < rt; blo += rowBlockSize) { + final int start = blo; + final int end = Math.min(blo + rowBlockSize, rt); + LLMNoTempOutRowBlockTasks(npa, pa, that, ret, rowSums, pool, s, ct, tasks, start, end, k); + } + + for(Future future : tasks) + future.get(); + + } - for(Future future : pool.invokeAll(tasks)) - future.get(); + private static void LLMNoTempOutRowBlockTasks(List npa, List pa, MatrixBlock that, + MatrixBlock ret, double[] rowSums, ExecutorService pool, final int s, final int ct, + final ArrayList> tasks, final int start, final int end, int k) { + for(AColGroup g : npa) // all non aggregate groups task + noTmpNoAggGroups(that, ret, pool, ct, tasks, start, end, g, k); + + for(int off = 0; off < s; off++) { + // all pre-aggregate group tasks + // s ensure that there is no more than k number of tasks. 
+ final int offT = off; + tasks.add(pool.submit(() -> LMMWithPreAgg(pa, that, ret, start, end, 0, ct, offT, s, null))); + } + + if(rowSums != null) // row sums task + tasks.add(pool.submit(() -> rowSum(that, rowSums, start, end, 0, ct))); + } + + private static void noTmpNoAggGroups(MatrixBlock that, MatrixBlock ret, ExecutorService pool, final int ct, + final ArrayList> tasks, final int start, final int end, AColGroup g, int k) { + final List> npaSubTask = new ArrayList<>(); + final int retNRow = ret.getNumRows(); + final int retNCol = ret.getNumColumns(); + if(retNCol < 1000000) { + + final int colBlockSize = Math.max(ct / Math.max(k, 2), 64000); + + for(int bloC = 0; bloC < ct; bloC += colBlockSize) { + final int startC = bloC; + final int endC = Math.min(bloC + colBlockSize, ct); + npaSubTask.add(pool.submit(() -> { + Timing t = new Timing(); + final double[] tmp = new double[retNRow * retNCol]; + final MatrixBlock tmpBlock = new MatrixBlock(retNRow, retNCol, tmp); + g.leftMultByMatrixNoPreAgg(that, tmpBlock, start, end, startC, endC); + LOG.debug("noPreAggTiming: " + t); + return tmpBlock; + })); } - else { - // allocate temp - final int nCol = ret.getNumColumns(); - final int nRow = ret.getNumRows(); - for(int blo = 0; blo < rl; blo += rowBlockSize) { - final int end = Math.min(blo + rowBlockSize, rl); - - for(AColGroup g : npa) // all groups get their own task - tasks.add(new LMMNoPreAggTask(g, that, nRow, nCol, blo, end)); - - for(int off = 0; off < s; off++) { // only allocate k tasks at max - if(off == s - 1) - tasks.add(new LMMPreAggTask(pa, that, nRow, nCol, blo, end, off, s, rowSums, 1)); - else - tasks.add(new LMMPreAggTask(pa, that, nRow, nCol, blo, end, off, s, null, 1)); - } - if(pa.isEmpty() && rowSums != null) // row sums task - tasks.add(new LMMRowSums(that, blo, end, rowSums)); + tasks.add(pool.submit(() -> addInPlaceFuture(ret, npaSubTask))); + } + else { + tasks.add(pool.submit(() -> g.leftMultByMatrixNoPreAgg(that, ret, start, end, 0, ct))); 
+ } + } - } + private static Object addInPlaceFuture(MatrixBlock ret, List> npaSubTask) throws Exception { + for(Future f : npaSubTask) + addInPlace(f.get(), ret); + return null; + } - BinaryOperator op = new BinaryOperator(Plus.getPlusFnObject()); - for(Future future : pool.invokeAll(tasks)) { - MatrixBlock mb = future.get(); - mb.examSparsity(); - ret.binaryOperationsInPlace(op, mb); + private static void LMMParallelTempOut(List npa, List pa, MatrixBlock that, MatrixBlock ret, + double[] rowSums, boolean overlapping, int k, ExecutorService pool) throws Exception { + + final int rt = that.getNumRows(); + final int ct = that.getNumColumns(); + // perfect parallel over rows left. + final int rowBlockSize = Math.max(rt / k, 1); + final int threadsUsedOnRows = (int) Math.ceil((double) rt / rowBlockSize); + k = Math.max(1, k / threadsUsedOnRows); + // parallel over column blocks ... should be bigger than largest distinct. + // final int colBlockSize = Math.max(ct, 1); + final int s = Math.min(npa.size() + pa.size(), k); + k = Math.max(1, k / s); // + + // We set it to minimum 4k + final int colBlockSize = Math.max(ct / k, 64000); + final int threadsUsedOnColBlocks = (int) Math.ceil((double) ct / colBlockSize); + k = k / threadsUsedOnColBlocks; + + final ArrayList> tasks = new ArrayList<>(); + // allocate temp + final int retCols = ret.getNumColumns(); + final int retRows = ret.getNumRows(); + for(int blo = 0; blo < rt; blo += rowBlockSize) { + final int start = blo; + final int end = Math.min(blo + rowBlockSize, rt); + + for(AColGroup g : npa) // all groups get their own task + tasks.add(pool.submit(new LMMNoPreAggTask(g, that, retRows, retCols, start, end))); + + for(int off = 0; off < s; off++) { // only allocate k tasks at max + final int offT = off; + + if(that.isInSparseFormat()) { + tasks.add(pool.submit(new LMMPreAggTask(pa, that, retRows, retCols, start, end, 0, ct, offT, s, null))); + } + else { + for(int bloC = 0; bloC < ct; bloC += colBlockSize) { + final 
int startC = bloC; + final int endC = Math.min(startC + colBlockSize, ct); + tasks.add(pool + .submit(new LMMPreAggTask(pa, that, retRows, retCols, start, end, startC, endC, offT, s, null))); + } } } + if(rowSums != null) // row sums task + tasks.add(pool.submit(new LMMRowSums(that, start, end, rowSums))); } - catch(InterruptedException | ExecutionException e) { - throw new DMLRuntimeException(e); - } - finally{ - pool.shutdown(); + + addInPlaceFuture(ret, tasks); + } + + private static Object addInPlace(MatrixBlock a, MatrixBlock out) throws Exception { + if(a != null) { + final DenseBlock dba = a.getDenseBlock(); + final DenseBlock dbb = out.getDenseBlock(); + final int blocks = dba.numBlocks(); + for(int b = 0; b < blocks; b++) { + final double[] av = dba.valuesAt(b); + final double[] bv = dbb.valuesAt(b); + final int len = av.length; + for(int i = 0; i < len; i++) { + bv[i] += av[i]; + } + } } + return null; } private static void LMMTaskExec(List npa, List pa, MatrixBlock that, MatrixBlock ret, int rl, - int ru, double[] rowSums, int k) { + int ru, double[] rowSums) throws Exception { + final int cu = that.getNumColumns(); if(npa.isEmpty() && pa.isEmpty()) { - rowSum(that, rowSums, rl, ru, 0, that.getNumColumns()); + rowSum(that, rowSums, rl, ru, 0, cu); return; } for(int r = rl; r < ru; r += 4) { final int re = Math.min(r + 4, ru); // Process MMs. 
- for(int i = 0; i < npa.size(); i++) - LMMNoPreAgg(npa.get(i), that, ret, r, re); - + for(int i = 0; i < npa.size(); i++) { + npa.get(i).leftMultByMatrixNoPreAgg(that, ret, r, re, 0, cu); + } if(pa.size() > 0) - LMMWithPreAgg(pa, that, ret, r, re, 0, 1, rowSums, k); + LMMWithPreAgg(pa, that, ret, r, re, 0, cu, 0, 1, rowSums); } } - private static void outerProduct(final double[] leftRowSum, final double[] rightColumnSum, final double[] result) { - for(int row = 0; row < leftRowSum.length; row++) { + private static void outerProduct(final double[] leftRowSum, final double[] rightColumnSum, final MatrixBlock result, + int k) throws InterruptedException, ExecutionException { + if(k > 1) + outerProductParallel(leftRowSum, rightColumnSum, result, k); + else + outerProductSingleThread(leftRowSum, rightColumnSum, result); + } + + private static void outerProductParallel(final double[] leftRowSum, final double[] rightColumnSum, + final MatrixBlock result, int k) throws InterruptedException, ExecutionException { + final ExecutorService pool = CommonThreadPool.get(k); + try { + for(Future t : outerProductParallelTasks(leftRowSum, rightColumnSum, result, pool)) + t.get(); + } + finally { + pool.shutdown(); + } + } + + private static void outerProductRange(final double[] leftRowSum, final double[] rightColumnSum, + final MatrixBlock result, int rl, int ru, int cl, int cu) { + if(result.getDenseBlock().isContiguous()) + outerProductRangeContiguous(leftRowSum, rightColumnSum, result.getDenseBlockValues(), rl, ru, cl, cu); + else + outerProductRangeGeneric(leftRowSum, rightColumnSum, result.getDenseBlock(), rl, ru, cl, cu); + } + + private static void outerProductRangeContiguous(final double[] leftRowSum, final double[] rightColumnSum, + final double[] result, int rl, int ru, int cl, int cu) { + for(int row = rl; row < ru; row++) { final int offOut = rightColumnSum.length * row; final double vLeft = leftRowSum[row]; - for(int col = 0; col < rightColumnSum.length; col++) { - 
result[offOut + col] += vLeft * rightColumnSum[col]; + if(vLeft != 0) { + for(int col = cl; col < cu; col++) { + result[offOut + col] += vLeft * rightColumnSum[col]; + } + } + } + } + + private static void outerProductRangeGeneric(final double[] leftRowSum, final double[] rightColumnSum, + final DenseBlock res, int rl, int ru, int cl, int cu) { + for(int row = rl; row < ru; row++) { + final int offOut = res.pos(row); + final double[] result = res.values(row); + final double vLeft = leftRowSum[row]; + if(vLeft != 0) { + for(int col = cl; col < cu; col++) { + result[offOut + col] += vLeft * rightColumnSum[col]; + } } } } + private static void outerProductSingleThread(final double[] leftRowSum, final double[] rightColumnSum, + MatrixBlock result) { + final int blkz = 1024; + for(int row = 0; row < leftRowSum.length; row += blkz) { + final int rl = row; + final int ru = Math.min(leftRowSum.length, row + blkz); + final int colBz = outerProdGetColBz(blkz, row, rl, ru); + + for(int col = 0; col < rightColumnSum.length; col += colBz) { + final int cl = col; + final int cu = Math.min(rightColumnSum.length, col + colBz); + outerProductRange(leftRowSum, rightColumnSum, result, rl, ru, cl, cu); + } + } + } + + private static List> outerProductParallelTasks(final double[] leftRowSum, final double[] rightColumnSum, + final MatrixBlock result, ExecutorService pool) { + // windows of 1024 each + final int blkz = 1024; + final List> tasks = new ArrayList<>(); + for(int row = 0; row < leftRowSum.length; row += blkz) { + final int rl = row; + final int ru = Math.min(leftRowSum.length, row + blkz); + final int colBz = outerProdGetColBz(blkz, row, rl, ru); + + for(int col = 0; col < rightColumnSum.length; col += colBz) { + final int cl = col; + final int cu = Math.min(rightColumnSum.length, col + colBz); + tasks.add(pool.submit(() -> { + outerProductRange(leftRowSum, rightColumnSum, result, rl, ru, cl, cu); + })); + } + } + return tasks; + } + + private static int 
outerProdGetColBz(final int blkz, int row, final int rl, final int ru) { + final int colBz; + if(ru < row + blkz) + colBz = 1024 * 1024 - ((ru - rl) * 1024) + 1024; + else + colBz = blkz; + return colBz; + } + private static void outerProductWithScaling(final double[] leftRowSum, final double[] rightColumnSum, + final int scaling, final MatrixBlock result) { + if(result.getDenseBlock().isContiguous()) + outerProductWithScalingContiguous(leftRowSum, rightColumnSum, scaling, result.getDenseBlockValues()); + else + outerProductWithScalingGeneric(leftRowSum, rightColumnSum, scaling, result.getDenseBlock()); + } + + private static void outerProductWithScalingContiguous(final double[] leftRowSum, final double[] rightColumnSum, final int scaling, final double[] result) { for(int row = 0; row < leftRowSum.length; row++) { final int offOut = rightColumnSum.length * row; @@ -418,103 +648,175 @@ private static void outerProductWithScaling(final double[] leftRowSum, final dou } } - private static void LMMNoPreAgg(AColGroup g, MatrixBlock that, MatrixBlock ret, int rl, int ru) { - g.leftMultByMatrixNoPreAgg(that, ret, rl, ru, 0, that.getNumColumns()); + private static void outerProductWithScalingGeneric(final double[] leftRowSum, final double[] rightColumnSum, + final int scaling, final DenseBlock res) { + for(int row = 0; row < leftRowSum.length; row++) { + final int offOut = res.pos(row); + final double[] result = res.values(row); + final double vLeft = leftRowSum[row] * scaling; + for(int col = 0; col < rightColumnSum.length; col++) { + result[offOut + col] += vLeft * rightColumnSum[col]; + } + } } - private static void LMMWithPreAgg(List preAggCGs, MatrixBlock that, MatrixBlock ret, int rl, int ru, - int off, int skip, double[] rowSums, int k) { - if(!that.isInSparseFormat()) - LMMWithPreAggDense(preAggCGs, that, ret, rl, ru, off, skip, rowSums); - else - LMMWithPreAggSparse(preAggCGs, that, ret, rl, ru, off, skip, rowSums); + private static void LMMWithPreAgg(List 
preAggCGs, MatrixBlock that, MatrixBlock ret, int rl, int ru, int cl, + int cu, int off, int skip, double[] rowSums) { + try { + if(!that.isInSparseFormat()) + LMMWithPreAggDense(preAggCGs, that, ret, rl, ru, cl, cu, off, skip, rowSums); + else + LMMWithPreAggSparse(preAggCGs, that, ret, rl, ru, cl, cu, off, skip, rowSums); + } + catch(Exception e) { + throw new RuntimeException("Failed LLM pre aggregate", e); + } } private static void LMMWithPreAggSparse(List preAggCGs, MatrixBlock that, MatrixBlock ret, int rl, int ru, - int off, int skip, double[] rowSum) { - // row multiplication - final MatrixBlock tmpRes = new MatrixBlock(1, ret.getNumColumns(), false); - final int maxV = preAggCGs.get(off).getNumValues(); - final MatrixBlock preA = new MatrixBlock(1, maxV, false); - // final DenseBlock db = preA.getDenseBlock(); - preA.allocateDenseBlock(); - final double[] preAV = preA.getDenseBlockValues(); - tmpRes.allocateDenseBlock(); + int cl, int cu, int off, int skip, double[] rowSum) throws Exception { + + final MatrixBlock preA = new MatrixBlock(); + final MatrixBlock fTmp = new MatrixBlock(); final SparseBlock sb = that.getSparseBlock(); + for(int j = off; j < preAggCGs.size(); j += skip) { // selected column groups for this thread. 
+ final int nCol = preAggCGs.get(j).getNumCols(); + final int nVal = preAggCGs.get(j).getNumValues(); + final APreAgg g = preAggCGs.get(j); - for(int j = off; j < preAggCGs.size(); j += skip) { for(int r = rl; r < ru; r++) { - if(sb.isEmpty(r)) - continue; - final int rcu = r + 1; - final int nCol = preAggCGs.get(j).getNumCols(); - final int nVal = preAggCGs.get(j).getNumValues(); - if(nCol == 1 || (sb.size(r) * nCol < sb.size(r) + nCol * nVal)) - LMMNoPreAgg(preAggCGs.get(j), that, ret, r, rcu); - else { - final APreAgg g = preAggCGs.get(j); - preA.reset(1, g.getPreAggregateSize(), false); - g.preAggregateSparse(sb, preAV, r, rcu); - g.mmWithDictionary(preA, tmpRes, ret, 1, r, rcu); - } + preAggSparseRow(that, ret, cl, cu, preA, fTmp, sb, nCol, nVal, g, r); } } - rowSumSparse(that.getSparseBlock(), rowSum, rl, ru, 0, that.getNumColumns()); + if(rowSum != null) + rowSumSparse(that.getSparseBlock(), rowSum, rl, ru, cl, cu); + + } + + private static void preAggSparseRow(MatrixBlock that, MatrixBlock ret, int cl, int cu, final MatrixBlock preA, + final MatrixBlock fTmp, final SparseBlock sb, final int nCol, final int nVal, final APreAgg g, int r) { + if(sb.isEmpty(r)) + return; + final int rcu = r + 1; + + // if(sb.size(r) * nCol < sb.size(r) + (long) nCol * nVal) { + // g.leftMultByMatrixNoPreAgg(that, ret, r, rcu, cl, cu); + // } + // else { + if(!preA.isAllocated()) { + preA.reset(1, nVal); + preA.allocateDenseBlock(); + } + else + preA.reset(1, nVal); + allocateOrResetTmpRes(ret, fTmp, 1); + + final double[] preAV = preA.getDenseBlockValues(); + preA.setNonZeros(g.getPreAggregateSize()); + fTmp.setNonZeros(1); + g.preAggregateSparse(sb, preAV, r, rcu, cl, cu); + g.mmWithDictionary(preA, fTmp, ret, 1, r, rcu); + // } + } + + private static void allocateOrResetTmpRes(final MatrixBlock ret, final MatrixBlock fTmp, int rows) { + if(!fTmp.isAllocated()) { + fTmp.reset(rows, ret.getNumColumns()); + fTmp.allocateDenseBlock(); + } + else + fTmp.reset(rows, 
ret.getNumColumns()); } - private static void LMMWithPreAggDense(List preAggCGs, MatrixBlock that, MatrixBlock ret, int rl, int ru, - int off, int skip, double[] rowSum) { + private static void LMMWithPreAggDense(final List preAggCGs, final MatrixBlock that, final MatrixBlock ret, + final int rl, final int ru, final int cl, final int cu, final int off, final int skip, final double[] rowSum) + throws InterruptedException, ExecutionException { + // Timing t = new Timing(); + // ExecutorService pool = CommonThreadPool.get(k); /** The column block size for preAggregating column groups */ - final int colBZ = 1024; + // final int colBZ = 1024; + final int colBZ = 2048; + // final int colBZ = Math.max(1024, lc/2); // The number of rows to process together final int rowBlockSize = 4; // The number of column groups to process together // the value should ideally be set so that the colGroups fits into cache together with a row block. // currently we only try to avoid having a dangling small number of column groups in the last block. - // final int colGroupBlocking = preAggCGs.size() ;// % 16 < 4 ? 20 : 16; - final int colGroupBlocking = 8; - // final int colGroupBlocking = 4; + // final int colGroupBlocking = preAggCGs.size();// % 16 < 4 ? 
20 : 16; + // final int colGroupBlocking = 8; + final int colGroupBlocking = 4; final int nColGroups = preAggCGs.size(); // Allocate pre Aggregate Array List - final MatrixBlock[] preAgg = populatePreAggregate(colGroupBlocking); + final double[][] preAgg = new double[colGroupBlocking][]; // Allocate temporary Result matrix // guaranteed to be large enough for all groups - final MatrixBlock tmpRes = new MatrixBlock(rowBlockSize, ret.getNumColumns(), false); + MatrixBlock tmpRes = new MatrixBlock(); - final int lc = that.getNumColumns(); // For each row block for(int rlt = rl; rlt < ru; rlt += rowBlockSize) { final int rut = Math.min(rlt + rowBlockSize, ru); // For each column group block for(int gl = off; gl < nColGroups; gl += colGroupBlocking * skip) { final int gu = Math.min(gl + (colGroupBlocking * skip), nColGroups); - // For each column group in the current block allocate the preaggregate array. - for(int j = gl, p = 0; j < gu; j += skip, p++) { - final int preAggNCol = preAggCGs.get(j).getPreAggregateSize(); - preAgg[p].reset(rut - rlt, preAggNCol, false); - } + // For each column group in the current block allocate the pre aggregate array. + // or reset the pre aggregate. + for(int j = gl, p = 0; j < gu; j += skip, p++) + preAllocate(preAggCGs, j, rut, rlt, preAgg, p); - // PreAggregate current block of column groups - for(int cl = 0; cl < lc; cl += colBZ) { - final int cu = Math.min(cl + colBZ, lc); + for(int clt = cl; clt < cu; clt += colBZ) { + final int cut = Math.min(clt + colBZ, cu); for(int j = gl, p = 0; j < gu; j += skip, p++) - preAggCGs.get(j).preAggregateDense(that, preAgg[p].getDenseBlockValues(), rlt, rut, cl, cu); + preAggregate(that, ret, preAggCGs, rut, rlt, clt, cut, j, preAgg, p); if(gu == nColGroups) - rowSum(that, rowSum, rlt, rut, cl, cu); + rowSum(that, rowSum, rlt, rut, clt, cut); } // Multiply out the PreAggregate to the output matrix. 
for(int j = gl, p = 0; j < gu; j += skip, p++) { final APreAgg cg = preAggCGs.get(j); - final MatrixBlock preAggThis = preAgg[p]; - cg.mmWithDictionary(preAggThis, tmpRes, ret, 1, rlt, rut); + if(cg.getDictionary() instanceof IdentityDictionary) + continue; + + allocateOrResetTmpRes(ret, tmpRes, rowBlockSize); + postMultiply(ret, tmpRes, preAgg, p, cg, rut, rlt); } } } + + // LOG.debug("SingleCallLMMTime: " + t.stop()); + } + + private static void preAllocate(List preAggCGs, int j, int rut, int rlt, double[][] preAgg, int p) { + final APreAgg cg = preAggCGs.get(j); + if(cg.getDictionary() instanceof IdentityDictionary) + return; + final int preAggNCol = cg.getPreAggregateSize(); + + final int len = (rut - rlt) * preAggNCol; + if(preAgg[p] == null || preAgg[p].length < len) + preAgg[p] = new double[len]; + else + Arrays.fill(preAgg[p], 0, (rut - rlt) * preAggNCol, 0); + } + + private static void preAggregate(MatrixBlock that, MatrixBlock ret, List preAggCGs, int rut, int rlt, + int clt, int cut, int j, double[][] preAgg, int p) { + final APreAgg cg = preAggCGs.get(j); + if(cg.getDictionary() instanceof IdentityDictionary) + cg.leftMMIdentityPreAggregateDense(that, ret, rlt, rut, clt, cut); + else + cg.preAggregateDense(that, preAgg[p], rlt, rut, clt, cut); + } + + private static void postMultiply(MatrixBlock ret, MatrixBlock tmpRes, double[][] preAgg, int p, APreAgg cg, int rut, + int rlt) { + final int preAggNCol = cg.getPreAggregateSize(); + final MatrixBlock preAggThis = new MatrixBlock((rut - rlt), preAggNCol, preAgg[p]); + cg.mmWithDictionary(preAggThis, tmpRes, ret, 1, rlt, rut); } public static double[] rowSum(MatrixBlock mb, int rl, int ru, int cl, int cu) { @@ -524,105 +826,98 @@ public static double[] rowSum(MatrixBlock mb, int rl, int ru, int cl, int cu) { } private static void rowSum(MatrixBlock mb, double[] rowSum, int rl, int ru, int cl, int cu) { - if(mb.isInSparseFormat()) + if(mb.isEmpty()) + throw new DMLCompressionException("Invalid empty block to 
rowsum"); + else if(rowSum == null) // no sum to make since the rowSum result is null. + return; + else if(mb.isInSparseFormat()) rowSumSparse(mb.getSparseBlock(), rowSum, rl, ru, cl, cu); else rowSumDense(mb, rowSum, rl, ru, cl, cu); } private static void rowSumSparse(SparseBlock sb, double[] rowSum, int rl, int ru, int cl, int cu) { - if(rowSum != null) { - for(int i = rl; i < ru; i++) { - if(sb.isEmpty(i)) - continue; - final int apos = sb.pos(i); - final int alen = sb.size(i) + apos; - final double[] aval = sb.values(i); - final int[] aix = sb.indexes(i); - if(cl == 0 && aix[alen - 1] < cu) - for(int j = apos; j < alen; j++) - rowSum[i] += aval[j]; - else { - int j = apos; - while(j < alen && aix[j] < cl) - j++; - while(j < alen && aix[j] < cu) - rowSum[i] += aval[j++]; - } - } - } + for(int i = rl; i < ru; i++) + rowSumSparseSingleRow(sb, rowSum, cl, cu, i); + } + + private static void rowSumSparseSingleRow(SparseBlock sb, double[] rowSum, int cl, int cu, int i) { + if(sb.isEmpty(i)) + return; + final int apos = sb.pos(i); + final int alen = sb.size(i) + apos; + final double[] aval = sb.values(i); + final int[] aix = sb.indexes(i); + int j = apos; + while(j < alen && aix[j] < cl) + j++; + if(aix[alen - 1] < cu) + while(j < alen) + rowSum[i] += aval[j++]; + else + while(j < alen && aix[j] < cu) + rowSum[i] += aval[j++]; } private static void rowSumDense(MatrixBlock that, double[] rowSum, int rl, int ru, int cl, int cu) { - if(rowSum != null) { - final DenseBlock db = that.getDenseBlock(); + + final DenseBlock db = that.getDenseBlock(); + if(db.isContiguous()) { + final double[] thatV = db.values(0); + for(int r = rl; r < ru; r++) + rowSumDenseSingleRow(rowSum, cl, cu, db, thatV, r); + } + else { for(int r = rl; r < ru; r++) { final double[] thatV = db.values(r); - final int rowOff = db.pos(r); - for(int c = rowOff + cl; c < rowOff + cu; c++) - rowSum[r] += thatV[c]; + rowSumDenseSingleRow(rowSum, cl, cu, db, thatV, r); } } + } - private static MatrixBlock[] 
populatePreAggregate(int colGroupBlocking) { - final MatrixBlock[] preAgg = new MatrixBlock[colGroupBlocking]; - // populate the preAgg array. - for(int j = 0; j < colGroupBlocking; j++) { - final MatrixBlock m = new MatrixBlock(1, 1, false); - m.allocateDenseBlock(); - preAgg[j] = m; - } - return preAgg; + private static void rowSumDenseSingleRow(double[] rowSum, int cl, int cu, final DenseBlock db, final double[] thatV, + int r) { + final int rowOff = db.pos(r); + double tmp = 0; + for(int c = rowOff + cl; c < rowOff + cu; c++) + tmp += thatV[c]; + rowSum[r] += tmp; } private static class LMMPreAggTask implements Callable { private final List _pa; private final MatrixBlock _that; - private final MatrixBlock _ret; + private final int _retR; + private final int _retC; private final int _rl; private final int _ru; + private final int _cl; + private final int _cu; private final double[] _rowSums; private final int _off; private final int _skip; - private final int _k; - protected LMMPreAggTask(List pa, MatrixBlock that, int retR, int retC, int rl, int ru, int off, int skip, - double[] rowSums, int k) { + protected LMMPreAggTask(List pa, MatrixBlock that, int retR, int retC, int rl, int ru, int cl, int cu, + int off, int skip, double[] rowSums) { _pa = pa; _that = that; - _ret = new MatrixBlock(retR, retC, false); - _ret.allocateDenseBlock(); - _rl = rl; - _ru = ru; - _rowSums = rowSums; - _off = off; - _skip = skip; - _k = k; - } - - protected LMMPreAggTask(List pa, MatrixBlock that, MatrixBlock ret, int rl, int ru, int off, int skip, - double[] rowSums, int k) { - _pa = pa; - _that = that; - _ret = ret; + _retR = retR; + _retC = retC; _rl = rl; _ru = ru; + _cl = cl; + _cu = cu; _rowSums = rowSums; _off = off; _skip = skip; - _k = k; } @Override - public MatrixBlock call() { - try { - LMMWithPreAgg(_pa, _that, _ret, _rl, _ru, _off, _skip, _rowSums, _k); - } - catch(Exception e) { - e.printStackTrace(); - throw new DMLRuntimeException(e); - } + public MatrixBlock 
call() throws Exception { + final double[] tmpArr = new double[_retR * _retC]; + MatrixBlock _ret = new MatrixBlock(_retR, _retC, tmpArr); + LMMWithPreAgg(_pa, _that, _ret, _rl, _ru, _cl, _cu, _off, _skip, _rowSums); return _ret; } } @@ -643,23 +938,9 @@ protected LMMNoPreAggTask(AColGroup cg, MatrixBlock that, int retR, int retC, in _ru = ru; } - protected LMMNoPreAggTask(AColGroup cg, MatrixBlock that, MatrixBlock ret, int rl, int ru) { - _cg = cg; - _that = that; - _ret = ret; - _rl = rl; - _ru = ru; - } - @Override - public MatrixBlock call() { - try { - LMMNoPreAgg(_cg, _that, _ret, _rl, _ru); - } - catch(Exception e) { - e.printStackTrace(); - throw new DMLRuntimeException(e); - } + public MatrixBlock call() throws Exception { + _cg.leftMultByMatrixNoPreAgg(_that, _ret, _rl, _ru, 0, _that.getNumColumns()); return _ret; } } @@ -678,14 +959,11 @@ protected LMMRowSums(MatrixBlock that, int rl, int ru, double[] rowSums) { } @Override - public MatrixBlock call() { - try { + public MatrixBlock call() throws Exception { + if(_that.isInSparseFormat()) + rowSumSparse(_that.getSparseBlock(), _rowSums, _rl, _ru, 0, _that.getNumColumns()); + else rowSumDense(_that, _rowSums, _rl, _ru, 0, _that.getNumColumns()); - } - catch(Exception e) { - e.printStackTrace(); - throw new DMLRuntimeException(e); - } return null; } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMatrixMult.java index 5b392e8b7e0..92594000458 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMatrixMult.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMatrixMult.java @@ -35,7 +35,7 @@ private CLALibMatrixMult(){ // private constructor } - public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) { + public static MatrixBlock matrixMultiply(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) { return matrixMultiply(m1, m2, ret, k, 
false, false); } @@ -117,7 +117,7 @@ else if(!transposeLeft && transposeRight) { // either compressed matrix. } else { - ret = CLALibMatrixMult.matrixMult(m2, m1, ret, k); + ret = CLALibMatrixMult.matrixMultiply(m2, m1, ret, k); ReorgOperator r_op = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k); return ret.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSelectionMult.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSelectionMult.java new file mode 100644 index 00000000000..9bac3b254ee --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSelectionMult.java @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + + package org.apache.sysds.runtime.compress.lib; + + import java.util.ArrayList; + import java.util.List; + import java.util.concurrent.ExecutionException; + import java.util.concurrent.ExecutorService; + import java.util.concurrent.Future; + + import org.apache.commons.logging.Log; + import org.apache.commons.logging.LogFactory; + import org.apache.sysds.runtime.compress.CompressedMatrixBlock; + import org.apache.sysds.runtime.compress.DMLCompressionException; + import org.apache.sysds.runtime.compress.colgroup.AColGroup; + import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType; + import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils; + import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; + import org.apache.sysds.runtime.compress.utils.IntArrayList; + import org.apache.sysds.runtime.data.DenseBlock; + import org.apache.sysds.runtime.data.SparseBlock; + import org.apache.sysds.runtime.data.SparseBlockCSR; + import org.apache.sysds.runtime.matrix.data.MatrixBlock; + import org.apache.sysds.runtime.util.CommonThreadPool; + + /** + * This lib is responsible for selecting and extracting specific rows or columns from a compressed matrix. + * + * The operation performed is like a left matrix multiplication where the left side has at most one non-zero per row. + * + */ +public interface CLALibSelectionMult { + public final Log LOG = LogFactory.getLog(CLALibSelectionMult.class.getName()); + + /** + * Left selection where the left matrix is sparse with at most one non-zero per row, and that non-zero is a 1. + * + * @param right Right hand side compressed matrix + * @param left Left hand side matrix + * @param ret Output matrix to put the result into. + * @param k The parallelization degree.
+ * @return The selected rows and columns of the input matrix + */ + public static MatrixBlock leftSelection(CompressedMatrixBlock right, MatrixBlock left, MatrixBlock ret, int k) { + try { + + if(right.getNonZeros() <= -1) + right.recomputeNonZeros(); + + boolean sparseOut = right.getSparsity() < 0.3; + ret = allocateReturn(right, left, ret, sparseOut); + + final List preFilter = right.getColGroups(); + final boolean shouldFilter = CLALibUtils.shouldPreFilter(preFilter); + if(shouldFilter) + filteredLeftSelection(left, ret, k, sparseOut, preFilter); + else + normalLeftSelection(left, ret, k, sparseOut, preFilter); + + ret.recomputeNonZeros(k); + + return ret; + } + catch(Exception e) { + throw new DMLCompressionException("Failed left selection Multiplication", e); + } + } + + /** + * Analyze if the given matrix is a selection matrix if on the left side of a matrix multiplication. + * + * @param mb The given matrix that should be on the left side + * @return If it is selective + */ + public static boolean isSelectionMatrix(MatrixBlock mb) { + // See if the input is potentially only containing one nonzero per row. + if(mb.isEmpty()) + return false; + else if(mb.getNonZeros() <= mb.getNumRows() && mb.isInSparseFormat()) { + + SparseBlock sb = mb.getSparseBlock(); + // verify every row only contain one 1 value. 
+ for(int i = 0; i < mb.getNumRows(); i++) { + if(sb.isEmpty(i)) + continue; + else if(sb.size(i) != 1) + return false; + else if(!(sb instanceof SparseBlockCSR)) { + double[] values = sb.values(i); + final int spos = sb.pos(i); + final int sEnd = spos + sb.size(i); + for(int j = spos; j < sEnd; j++) { + if(values[j] != 1) { + return false; + } + } + } + } + if(sb instanceof SparseBlockCSR) { + for(double d : sb.values(0)) + if(d != 1) + return false; + } + + return true; + } + return false; + } + + private static MatrixBlock allocateReturn(CompressedMatrixBlock right, MatrixBlock left, MatrixBlock ret, + boolean sparseOut) { + if(ret == null) + ret = new MatrixBlock(); + // sparseOut = false; + ret.reset(left.getNumRows(), right.getNumColumns(), sparseOut); + ret.allocateBlock(); + return ret; + } + + private static void normalLeftSelection(MatrixBlock left, MatrixBlock ret, int k, boolean sparseOut, + final List preFilter) throws Exception { + final int rowLeft = left.getNumRows(); + final boolean pointsNeeded = areSortedCoordinatesNeeded(preFilter); + if(k <= 1 || rowLeft < 1000) + leftSelectionSingleThread(preFilter, left, ret, rowLeft, pointsNeeded, sparseOut); + else + leftSelectionParallel(preFilter, left, ret, k, rowLeft, pointsNeeded, sparseOut); + } + + private static void filteredLeftSelection(MatrixBlock left, MatrixBlock ret, int k, boolean sparseOut, + final List preFilter) throws Exception { + final double[] constV = new double[ret.getNumColumns()]; + final List morphed = CLALibUtils.filterGroups(preFilter, constV); + normalLeftSelection(left, ret, k, sparseOut, morphed); + double[] rowSums = left.rowSum(k).getDenseBlockValues(); + + outerProduct(rowSums, constV, ret, sparseOut); + } + + private static void leftSelectionSingleThread(List right, MatrixBlock left, MatrixBlock ret, + final int rowLeft, final boolean pointsNeeded, final boolean sparseOut) { + P[] points = pointsNeeded ? 
ColGroupUtils.getSortedSelection(left.getSparseBlock(), 0, rowLeft) : null; + for(AColGroup g : right) + g.selectionMultiply(left, points, ret, 0, rowLeft); + if(sparseOut) + ret.getSparseBlock().sort(); + } + + private static void leftSelectionParallel(List right, MatrixBlock left, MatrixBlock ret, int k, + final int rowLeft, final boolean pointsNeeded, final boolean sparseOut) + throws InterruptedException, ExecutionException { + final ExecutorService pool = CommonThreadPool.get(k); + try { + + List> tasks = new ArrayList<>(); + final int blkz = Math.max(rowLeft / k, 1000); + for(int i = 0; i < rowLeft; i += blkz) { + final int start = i; + final int end = Math.min(rowLeft, i + blkz); + P[] points = pointsNeeded ? ColGroupUtils.getSortedSelection(left.getSparseBlock(), start, end) : null; + tasks.add(pool.submit(() -> { + for(AColGroup g : right) + g.selectionMultiply(left, points, ret, start, end); + if(sparseOut) { + SparseBlock sb = ret.getSparseBlock(); + for(int j = start; j < end; j++) { + if(!sb.isEmpty(j)) + sb.sort(j); + } + } + })); + } + + for(Future t : tasks) + t.get(); + } + finally { + pool.shutdown(); + } + } + + private static boolean areSortedCoordinatesNeeded(List right) { + for(AColGroup g : right) { + if(g.getCompType() == CompressionType.SDC) + return true; + } + return false; + } + + private static void outerProduct(double[] rows, double[] cols, MatrixBlock ret, boolean sparse) { + if(sparse) + outerProductSparse(rows, cols, ret); + else + outerProductDense(rows, cols, ret); + } + + private static void outerProductDense(double[] rows, double[] cols, MatrixBlock ret) { + DenseBlock db = ret.getDenseBlock(); + for(int r = 0; r < rows.length; r++) { + final double rv = rows[r]; + final double[] dbV = db.values(r); + final int pos = db.pos(r); + if(rv != 0) + for(int c = 0; c < cols.length; c++) + dbV[pos + c] += rv * cols[c]; + } + } + + private static void outerProductSparse(double[] rows, double[] cols, MatrixBlock ret) { + final SparseBlock 
sb = ret.getSparseBlock(); + + final IntArrayList skipCols = new IntArrayList(); + for(int c = 0; c < cols.length; c++) + if(cols[c] != 0) + skipCols.appendValue(c); + + final int skipSz = skipCols.size(); + if(skipSz == 0) + return; + + final int[] skipC = skipCols.extractValues(); + for(int r = 0; r < rows.length; r++) { + final double rv = rows[r]; + if(rv != 0) { + for(int ci = 0; ci < skipSz; ci++) { + final int c = skipC[ci]; + sb.add(r, c, rv * cols[c]); + } + } + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelection.java b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelection.java index bcf57e5a764..e087525bbbd 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelection.java +++ b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelection.java @@ -30,27 +30,24 @@ public abstract class ReaderColumnSelection { protected static final Log LOG = LogFactory.getLog(ReaderColumnSelection.class.getName()); + /** The column indexes to read from the matrix */ protected final IColIndex _colIndexes; + /** Pointer to the wrapping reusable return DblArray */ protected final DblArray reusableReturn; + /** A reusable array that is stored inside the DblArray */ protected final double[] reusableArr; /** The row index to stop the reading at */ protected final int _ru; - /** rl is used as a pointer to current row */ + /** rl is used as a pointer to current row, that increment on calls to nextRow */ protected int _rl; protected ReaderColumnSelection(IColIndex colIndexes, int rl, int ru) { _colIndexes = colIndexes; _rl = rl; _ru = ru; - if(colIndexes != null) { - reusableArr = new double[colIndexes.size()]; - reusableReturn = new DblArray(reusableArr); - } - else { - reusableArr = null; - reusableReturn = null; - } + reusableArr = new double[colIndexes.size()]; + reusableReturn = new DblArray(reusableArr); } /** @@ -68,24 +65,59 @@ public final DblArray 
nextRow() { return ret; } + /** + * Get the next row as a DblArray, returns null if no more rows. This method is used internally and not supposed to + * be called from the outside, instead use nextRow. + * + * @return The next row. + */ protected abstract DblArray getNextRow(); + /** + * Get the current row index that the reader is at. + * + * @return The row index + */ public int getCurrentRowIndex() { return _rl; } + /** + * Create a reader of the matrix block that is able to iterate through all the rows and return as dense double + * arrays. + * + * Note the reader reuses the returned DblArray, therefore if needed for something please copy the returned rows. + * + * @param rawBlock The block to iterate through + * @param colIndices The column indexes to extract and insert into the double array + * @param transposed If the raw block should be treated as transposed + * @return A reader of the columns specified + */ public static ReaderColumnSelection createReader(MatrixBlock rawBlock, IColIndex colIndices, boolean transposed) { final int rl = 0; final int ru = transposed ? rawBlock.getNumColumns() : rawBlock.getNumRows(); return createReader(rawBlock, colIndices, transposed, rl, ru); } + /** + * Create a reader of the matrix block that is able to iterate through all the rows and return as dense double + * arrays. + * + * Note the reader reuses the returned DblArray, therefore if needed for something please copy the returned rows. 
+ * + * @param rawBlock The block to iterate through + * @param colIndices The column indexes to extract and insert into the double array + * @param transposed If the raw block should be treated as transposed + * @param rl The row to start at + * @param ru The row to end at (not inclusive) + * @return A reader of the columns specified + */ public static ReaderColumnSelection createReader(MatrixBlock rawBlock, IColIndex colIndices, boolean transposed, int rl, int ru) { - checkInput(rawBlock, colIndices, rl, ru); + checkInput(rawBlock, colIndices, rl, ru, transposed); rl = rl - 1; if(rawBlock.isEmpty()) { - LOG.warn("It is likely an error occurred when reading an empty block. But we do support it!"); + LOG.warn("It is likely an error occurred when reading an empty block, but we do support it!"); return new ReaderColumnSelectionEmpty(rawBlock, colIndices, rl, ru, transposed); } @@ -104,11 +136,18 @@ else if(rawBlock.getDenseBlock().numBlocks() > 1) return new ReaderColumnSelectionDenseSingleBlock(rawBlock, colIndices, rl, ru); } - private static void checkInput(final MatrixBlock rawBlock, final IColIndex colIndices, final int rl, final int ru) { + private static void checkInput(final MatrixBlock rawBlock, final IColIndex colIndices, final int rl, final int ru, + final boolean transposed) { if(colIndices.size() <= 1) throw new DMLCompressionException( "Column selection reader should not be done on single column groups: " + colIndices); else if(rl >= ru) throw new DMLCompressionException("Invalid inverse range for reader " + rl + " to " + ru); + + final int finalColIndex = colIndices.get(colIndices.size() - 1); + final int finalBlockCol = transposed ? 
rawBlock.getNumRows() : rawBlock.getNumColumns(); + if(finalColIndex >= finalBlockCol) + throw new DMLCompressionException("Invalid columns to extract outside the given block: index: " + finalColIndex + + " must be smaller than : " + finalBlockCol); } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseMultiBlock.java b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseMultiBlock.java index d82cb5e23b1..dad3212f902 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseMultiBlock.java +++ b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseMultiBlock.java @@ -33,7 +33,6 @@ protected ReaderColumnSelectionDenseMultiBlock(MatrixBlock data, IColIndex colIn } protected DblArray getNextRow() { - _rl++; for(int i = 0; i < _colIndexes.size(); i++) reusableArr[i] = _data.get(_rl, _colIndexes.get(i)); diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java index 8e2e7f06657..b0e9eb2ea06 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java @@ -230,6 +230,7 @@ public MatrixBlock(int rl, int cl, DenseBlock dBlock){ clen = cl; sparse = false; denseBlock = dBlock; + nonZeros = -1; } public MatrixBlock(int rl, int cl, double[] vals){ @@ -5529,6 +5530,25 @@ public static MatrixBlock randOperations(RandomMatrixGenerator rgen, long seed, return out; } + + /** + * Transpose this MatrixBlock + * + * @return The transpose MatrixBlock + */ + public final MatrixBlock transpose() { + return transpose(1); + } + + /** + * Transpose this MatrixBlock leveraging parallelization degree k + * + * @param k Parallelization degree allowed + * @return The transpose MatrixBlock + */ + public MatrixBlock transpose(int k) { + return LibMatrixReorg.transpose(this, k); + } 
/** * Function to generate a matrix of random numbers. This is invoked both diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java index 1a0cbdfa879..bbb29deb4a3 100644 --- a/src/test/java/org/apache/sysds/test/TestUtils.java +++ b/src/test/java/org/apache/sysds/test/TestUtils.java @@ -1087,6 +1087,15 @@ public static void compareMatricesBitAvgDistance(MatrixBlock expectedMatrix, Mat public static void compareMatricesBitAvgDistance(MatrixBlock expectedMatrix, MatrixBlock actualMatrix, long maxUnitsOfLeastPrecision, long maxAvgDistance, String message) { + + final int rows = expectedMatrix.getNumRows(); + final int cols = expectedMatrix.getNumColumns(); + + if(rows != actualMatrix.getNumRows()) + fail(message + "\nnot same number of rows: " + rows + " vs " + actualMatrix.getNumRows()); + if(cols != actualMatrix.getNumColumns()) + fail(message + "\nnot same number of cols: " + cols + " vs " + actualMatrix.getNumColumns()); + if(expectedMatrix instanceof CompressedMatrixBlock) expectedMatrix = ((CompressedMatrixBlock) expectedMatrix).decompress(); if(actualMatrix instanceof CompressedMatrixBlock) @@ -1123,13 +1132,11 @@ else if(expectedMatrix.isInSparseFormat() && actualMatrix.isInSparseFormat()) { maxUnitsOfLeastPrecision, maxAvgDistance, message, actualMatrix.getNumColumns()); return; } - final int rows = expectedMatrix.getNumRows(); - final int cols = actualMatrix.getNumColumns(); int countErrors = 0; long sumDistance = 0; - for(int i = 0; i < rows && countErrors < 20; i++) { - for(int j = 0; j < cols && countErrors < 20; j++) { + for(int i = 0; i < rows && countErrors < 5; i++) { + for(int j = 0; j < cols && countErrors < 5; j++) { final double v1 = expectedMatrix.get(i, j); final double v2 = actualMatrix.get(i, j); if(v1 == 0 && v2 == 0) diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupFactoryTest.java 
b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupFactoryTest.java index 8d24d4bb64e..0468de4dc04 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupFactoryTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupFactoryTest.java @@ -103,6 +103,8 @@ public static Collection data() { private static void addDenseMultiBlock(ArrayList tests, int nRows, int nCols, int min, int max, double sparsity, int seed) { + if(nCols <= 1) + nCols += 1; MatrixBlock mb = TestUtils.generateTestMatrixBlock(nRows, nCols, min, max, sparsity, seed); mb = TestUtils.ceil(mb); @@ -113,7 +115,7 @@ private static void addDenseMultiBlock(ArrayList tests, int nRows, int mbt = new MatrixBlock(mbt.getNumRows(), mbt.getNumColumns(), new DenseBlockFP64Mock(new int[] {mbt.getNumRows(), mbt.getNumColumns()}, mbt.getDenseBlockValues())); - add(tests, nCols + 3, mb, mbt); + add(tests, nCols , mb, mbt); } private static void addWithEmpty(ArrayList tests, int nRows, int nCols, int min, int max, double sparsity, diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java index dd5a65a0f77..6f33851ff7b 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java +++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java @@ -37,14 +37,16 @@ import org.apache.sysds.runtime.compress.colgroup.ColGroupRLE; import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCSingleZeros; import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCZeros; -import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; +import 
org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; +import org.apache.sysds.runtime.compress.lib.CLALibLeftMultBy; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Builtin; @@ -140,6 +142,11 @@ public void invalidPreAggregateClass() { } } + @Test(expected = Exception.class) + public void invalidEmptyRowSum() { + CLALibLeftMultBy.rowSum(new MatrixBlock(10, 10, true), 0, 10, 0, 10); + } + private class FakeIndexing extends IndexFunction { private static final long serialVersionUID = -4099420257856761251L; @@ -198,7 +205,7 @@ public void preAggregateDense(MatrixBlock m, double[] preAgg, int rl, int ru, in } @Override - public void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru) { + public void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru, int cl, int cu) { } @@ -393,6 +400,24 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) { // TODO Auto-generated method stub throw new UnsupportedOperationException("Unimplemented method 'fixColIndexes'"); } + + @Override + public void leftMMIdentityPreAggregateDense(MatrixBlock that, MatrixBlock ret, int rl, int ru, int cl, int cu) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'leftMMIdentityPreAggregateDense'"); + } + + @Override + public void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + // TODO Auto-generated method stub + throw new 
UnsupportedOperationException("Unimplemented method 'sparseSelection'"); + } + + @Override + protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'denseSelection'"); + } } private class FakeDictBasedColGroup extends ADictBasedColGroup { @@ -643,5 +668,17 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) { // TODO Auto-generated method stub throw new UnsupportedOperationException("Unimplemented method 'fixColIndexes'"); } + + @Override + public void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'sparseSelection'"); + } + + @Override + protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'denseSelection'"); + } } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java index 3f7d5e37e63..d4c54780da9 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java @@ -20,6 +20,7 @@ package org.apache.sysds.test.component.compress.colgroup; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -33,6 +34,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import 
org.apache.sysds.runtime.compress.CompressionSettings; import org.apache.sysds.runtime.compress.CompressionSettingsBuilder; import org.apache.sysds.runtime.compress.DMLCompressionException; @@ -46,6 +48,8 @@ import org.apache.sysds.runtime.compress.colgroup.ColGroupRLE; import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCSingleZeros; import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; @@ -74,6 +78,7 @@ import org.apache.sysds.runtime.matrix.operators.ScalarOperator; import org.apache.sysds.runtime.matrix.operators.UnaryOperator; import org.apache.sysds.test.TestUtils; +import org.apache.sysds.test.component.compress.lib.CLALibSelectionMultTest; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -1118,7 +1123,7 @@ public void getMax() { @Test public void tsmm() { - try{ + try { final MatrixBlock bt = new MatrixBlock(maxCol, maxCol, false); final MatrixBlock ot = new MatrixBlock(maxCol, maxCol, false); @@ -1128,7 +1133,7 @@ public void tsmm() { other.tsmm(ot, nRow); compare(ot, bt); } - catch(Exception e){ + catch(Exception e) { e.printStackTrace(); fail(e.getMessage()); } @@ -1294,6 +1299,77 @@ public void leftMultNoPreAgg(int nRowLeft, int rl, int ru, int cl, int cu, Matri } } + @Test + public void sparseSelection() { + MatrixBlock mb = CLALibSelectionMultTest.createSelectionMatrix(nRow, 5, false); + mb = CompressedMatrixBlock.getUncompressed(mb); + MatrixBlock ret = new MatrixBlock(5, maxCol, true); + ret.allocateSparseRowsBlock(); + ret.setNonZeros(-1); + selection(mb, ret); + } + + @Test + public void denseSelection() { + MatrixBlock mb = 
CLALibSelectionMultTest.createSelectionMatrix(nRow, 5, false); + mb = CompressedMatrixBlock.getUncompressed(mb); + MatrixBlock ret = new MatrixBlock(5, maxCol, false); + ret.allocateDenseBlock(); + ret.setNonZeros(-1); + assertFalse(ret.isInSparseFormat()); + selection(mb, ret); + } + + + @Test + public void sparseSelectionEmptyRows() { + MatrixBlock mb = CLALibSelectionMultTest.createSelectionMatrix(nRow, 50, true); + mb = CompressedMatrixBlock.getUncompressed(mb); + MatrixBlock ret = new MatrixBlock(50, maxCol, true); + ret.allocateSparseRowsBlock(); + ret.setNonZeros(-1); + selection(mb, ret); + } + + @Test + public void denseSelectionEmptyRows() { + MatrixBlock mb = CLALibSelectionMultTest.createSelectionMatrix(nRow, 50, true); + mb = CompressedMatrixBlock.getUncompressed(mb); + MatrixBlock ret = new MatrixBlock(50, maxCol, false); + ret.allocateDenseBlock(); + ret.setNonZeros(-1); + assertFalse(ret.isInSparseFormat()); + selection(mb, ret); + } + + public void selection(MatrixBlock selection, MatrixBlock ret) { + P[] points = ColGroupUtils.getSortedSelection(selection.getSparseBlock(), 0, selection.getNumRows()); + MatrixBlock ret1 = new MatrixBlock(ret.getNumRows(), ret.getNumColumns(), ret.isInSparseFormat()); + ret1.allocateBlock(); + + + MatrixBlock ret2 = new MatrixBlock(ret.getNumRows(), ret.getNumColumns(), ret.isInSparseFormat()); + ret2.allocateBlock(); + + try { + + base.selectionMultiply(selection, points, ret1, 0, selection.getNumRows()); + other.selectionMultiply(selection, points, ret2, 0, selection.getNumRows()); + + TestUtils.compareMatricesBitAvgDistance(ret1, ret2, 0, 0, base.getClass().getSimpleName() + " vs " + other.getClass().getSimpleName()); + + + } + catch(NotImplementedException e) { + // okay + } + catch(Exception e){ + e.printStackTrace(); + fail(e.getMessage()); + } + + } + @Test public void preAggLeftMult() { preAggLeftMult(new MatrixBlock(1, nRow, 1.0), 0, 1); @@ -1465,8 +1541,7 @@ public void preAggLeftMultSecondRowDenseEnd() { 
preAggLeftMultDense(new MatrixBlock(2, nRow, 2.0), 1, 2, nRow - 10, nRow - 3); } - @Test(expected = NotImplementedException.class) - // @Test + @Test public void preAggLeftMultDenseNonContiguous() { try { @@ -1476,14 +1551,12 @@ public void preAggLeftMultDenseNonContiguous() { preAggLeftMultDense(new MatrixBlock(1, nRow, mock), 0, 1, 3, nRow - 3); } catch(NotImplementedException e) { - throw e; + // valid } catch(Exception e) { e.printStackTrace(); fail(e.getMessage()); } - throw new NotImplementedException("Throw since we passed the test"); - // if we get here throw the exception } private static class DenseBlockFP64Mock extends DenseBlockFP64 { @@ -1505,51 +1578,43 @@ public int numBlocks() { } public void preAggLeftMultDense(MatrixBlock mb, int rl, int ru, int cl, int cu) { - MatrixBlock retB = null; - MatrixBlock retO = null; - final double[] rowSum = CLALibLeftMultBy.rowSum(mb, rl, ru, cl, cu); - if(base instanceof AMorphingMMColGroup) { - double[] cb = new double[maxCol]; - AColGroup b = ((AMorphingMMColGroup) base).extractCommon(cb); - retB = mmPreAggDense((APreAgg) b, mb, cb, rowSum, rl, ru, cl, cu); - } - else if(base instanceof APreAgg) - retB = mmPreAggDense((APreAgg) base, mb, null, rowSum, rl, ru, cl, cu); - else if(base instanceof ColGroupConst) { - double[] cb = new double[maxCol]; - ((ColGroupConst) base).addToCommon(cb); - retO = mmRowSum(cb, rowSum, rl, ru, cl, cu); - } + final MatrixBlock retB = morphingLLM(base, mb, rl, ru, cl, cu, rowSum); + final MatrixBlock retO = morphingLLM(other, mb, rl, ru, cl, cu, rowSum); + + retB.recomputeNonZeros(); + retO.recomputeNonZeros(); - if(other instanceof AMorphingMMColGroup) { + compare(retB, retO); + + } + + private MatrixBlock lmmNoAgg(AColGroup g, MatrixBlock mb, int rl, int ru, int cl, int cu) { + MatrixBlock tmpB = new MatrixBlock(ru, maxCol, false); + tmpB.allocateDenseBlock(); + g.leftMultByMatrixNoPreAgg(mb, tmpB, rl, ru, cl, cu); + return tmpB; + } + + private MatrixBlock morphingLLM(AColGroup g, 
MatrixBlock mb, int rl, int ru, int cl, int cu, final double[] rowSum) { + final MatrixBlock retB; + if(g instanceof AMorphingMMColGroup) { double[] cb = new double[maxCol]; - AColGroup b = ((AMorphingMMColGroup) other).extractCommon(cb); - retO = mmPreAggDense((APreAgg) b, mb, cb, rowSum, rl, ru, cl, cu); + AColGroup b = ((AMorphingMMColGroup) g).extractCommon(cb); + retB = mmPreAggDense((APreAgg) b, mb, cb, rowSum, rl, ru, cl, cu); } - else if(other instanceof APreAgg) - retO = mmPreAggDense((APreAgg) other, mb, null, rowSum, rl, ru, cl, cu); - else if(other instanceof ColGroupConst) { + else if(g instanceof APreAgg) + retB = mmPreAggDense((APreAgg) g, mb, null, rowSum, rl, ru, cl, cu); + else if(g instanceof ColGroupConst) { double[] cb = new double[maxCol]; - ((ColGroupConst) other).addToCommon(cb); - retO = mmRowSum(cb, rowSum, rl, ru, cl, cu); + ((ColGroupConst) g).addToCommon(cb); + retB = mmRowSum(cb, rowSum, rl, ru, cl, cu); } + else + retB = lmmNoAgg(g, mb, rl, ru, cl, cu); - if(retB == null) { - retB = new MatrixBlock(ru, maxCol, false); - retB.allocateDenseBlock(); - base.leftMultByMatrixNoPreAgg(mb, retB, rl, ru, cl, cu); - } - - if(retO == null) { - retO = new MatrixBlock(ru, maxCol, false); - retO.allocateDenseBlock(); - other.leftMultByMatrixNoPreAgg(mb, retO, rl, ru, cl, cu); - } - - compare(retB, retO); - + return retB; } private MatrixBlock mmPreAggDense(APreAgg g, MatrixBlock mb, double[] cv, double[] rowSum, int rl, int ru, int cl, @@ -2289,7 +2354,7 @@ private void appendSelfVerification(AColGroup g) { try { AColGroup g2 = g.append(g); - AColGroup g2n = AColGroup.appendN(new AColGroup[] {g, g}, nRow, nRow*2); + AColGroup g2n = AColGroup.appendN(new AColGroup[] {g, g}, nRow, nRow * 2); if(g2 != null && g2n != null) { double s2 = g2.getSum(nRow * 2); double s = g.getSum(nRow) * 2; @@ -2300,7 +2365,7 @@ private void appendSelfVerification(AColGroup g) { UA_ROW(InstructionUtils.parseBasicAggregateUnaryOperator("uar+", 1), 0, nRow * 2, g2, g2n, 
nRow * 2); } } - catch(NotImplementedException e){ + catch(NotImplementedException e) { // okay } catch(Exception e) { diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/scheme/SchemeTestBase.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/scheme/SchemeTestBase.java index a0b45351ba7..16d248c5459 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/scheme/SchemeTestBase.java +++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/scheme/SchemeTestBase.java @@ -340,6 +340,7 @@ public void testUpdateEmpty() { @Test public void testUpdateEmptyT() { MatrixBlock in = new MatrixBlock(src.getNumColumns(), 5, 0.0); + // 5 rows to encode transposed try { sh.encodeT(in); } @@ -353,10 +354,13 @@ public void testUpdateEmptyT() { shc = shc.updateT(in); AColGroup out = shc.encodeT(in); // should be possible now. - MatrixBlock d = new MatrixBlock(in.getNumRows(), src.getNumColumns(), false); + + // now we learned how to encode. lets decompress the encoded. 
+ + MatrixBlock d = new MatrixBlock( in.getNumColumns(), in.getNumRows(), false); d.allocateBlock(); - out.decompressToDenseBlock(d.getDenseBlock(), 0, in.getNumRows()); - MatrixBlock inSlice = in.slice(0, src.getNumColumns() - 1, 0, in.getNumColumns() - 1); + out.decompressToDenseBlock(d.getDenseBlock(), 0, in.getNumColumns()); + MatrixBlock inSlice = in.slice(0, in.getNumRows() - 1, 0, in.getNumColumns() - 1); d.recomputeNonZeros(); TestUtils.compareMatricesBitAvgDistance(inSlice, LibMatrixReorg.transpose(d), 0, 0); } @@ -395,7 +399,7 @@ public void testUpdateEmptyMyCols() { @Test public void testUpdateEmptyMyColsT() { MatrixBlock in = new MatrixBlock(src.getNumColumns(), 5, 0.0); - in = in.append(new MatrixBlock(1, 5, 1.0), false); + in = in.append(new MatrixBlock(src.getNumColumns(), 1, 1.0), true); try { sh.encodeT(in); } @@ -409,10 +413,17 @@ public void testUpdateEmptyMyColsT() { shc = shc.updateT(in); AColGroup out = shc.encodeT(in); // should be possible now. - MatrixBlock d = new MatrixBlock(in.getNumRows(), src.getNumColumns(), false); + // MatrixBlock d = new MatrixBlock(in.getNumRows(), src.getNumColumns(), false); + // d.allocateBlock(); + // out.decompressToDenseBlock(d.getDenseBlock(), 0, in.getNumRows()); + // MatrixBlock inSlice = in.slice(0, src.getNumColumns() - 1, 0, in.getNumColumns() - 1); + // d.recomputeNonZeros(); + // TestUtils.compareMatricesBitAvgDistance(inSlice, LibMatrixReorg.transpose(d), 0, 0); + + MatrixBlock d = new MatrixBlock( in.getNumColumns(), in.getNumRows(), false); d.allocateBlock(); - out.decompressToDenseBlock(d.getDenseBlock(), 0, in.getNumRows()); - MatrixBlock inSlice = in.slice(0, src.getNumColumns() - 1, 0, in.getNumColumns() - 1); + out.decompressToDenseBlock(d.getDenseBlock(), 0, in.getNumColumns()); + MatrixBlock inSlice = in.slice(0, in.getNumRows() - 1, 0, in.getNumColumns() - 1); d.recomputeNonZeros(); TestUtils.compareMatricesBitAvgDistance(inSlice, LibMatrixReorg.transpose(d), 0, 0); } diff --git 
a/src/test/java/org/apache/sysds/test/component/compress/dictionary/CombineTest.java b/src/test/java/org/apache/sysds/test/component/compress/dictionary/CombineTest.java index f9e0ac1ee42..119412fbb69 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/dictionary/CombineTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/dictionary/CombineTest.java @@ -40,9 +40,9 @@ import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC; import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; import org.apache.sysds.runtime.compress.colgroup.ColGroupSDC; -import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; @@ -72,6 +72,43 @@ public void singleBothSides() { } } + @Test + public void singleBothSidesFilter() { + try { + + IDictionary a = Dictionary.create(new double[] {1.2}); + IDictionary b = Dictionary.create(new double[] {1.4}); + Map filter = new HashMap<>(); + filter.put(0, 0); + IDictionary c = DictionaryFactory.combineFullDictionaries(a, 1, b, 1, filter); + + assertEquals(c.getValue(0, 0, 2), 1.2, 0.0); + assertEquals(c.getValue(0, 1, 2), 1.4, 0.0); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void singleBothSidesFilter2() { + try { + + IDictionary a = Dictionary.create(new double[] {1.2}); + IDictionary b = Dictionary.create(new double[] {1.4}); + Map filter = new HashMap<>(); + // filter.put(0, 0); + IDictionary c = DictionaryFactory.combineFullDictionaries(a, 1, b, 1, filter); + + assertEquals(c, null); + } + 
catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + @Test public void singleOneSideBothSides() { try { @@ -91,6 +128,25 @@ public void singleOneSideBothSides() { } } + @Test + public void singleOneSideOtherSides() { + try { + IDictionary a = Dictionary.create(new double[] {1.2}); + IDictionary b = Dictionary.create(new double[] {1.3, 1.4}); + + IDictionary c = DictionaryFactory.combineFullDictionaries(a, 1, b, 1); + + assertEquals(c.getValue(0, 0, 2), 1.2, 0.0); + assertEquals(c.getValue(0, 1, 2), 1.3, 0.0); + assertEquals(c.getValue(1, 0, 2), 1.2, 0.0); + assertEquals(c.getValue(1, 1, 2), 1.4, 0.0); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + @Test public void twoBothSides() { try { @@ -114,6 +170,87 @@ public void twoBothSides() { } } + @Test + public void twoBothSidesFilter() { + try { + IDictionary a = Dictionary.create(new double[] {1.2, 1.3}); + IDictionary b = Dictionary.create(new double[] {1.4, 1.5}); + Map filter = new HashMap<>(); + filter.put(0,0); + + IDictionary c = DictionaryFactory.combineFullDictionaries(a, 1, b, 1, filter); + + assertEquals(1, c.getNumberOfValues(2)); + assertEquals(c.getValue(0, 0, 2), 1.2, 0.0); + assertEquals(c.getValue(0, 1, 2), 1.4, 0.0); + // assertEquals(c.getValue(1, 0, 2), 1.3, 0.0); + // assertEquals(c.getValue(1, 1, 2), 1.4, 0.0); + // assertEquals(c.getValue(2, 0, 2), 1.2, 0.0); + // assertEquals(c.getValue(2, 1, 2), 1.5, 0.0); + // assertEquals(c.getValue(3, 0, 2), 1.3, 0.0); + // assertEquals(c.getValue(3, 1, 2), 1.5, 0.0); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + + @Test + public void twoBothSidesFilter2() { + try { + IDictionary a = Dictionary.create(new double[] {1.2, 1.3}); + IDictionary b = Dictionary.create(new double[] {1.4, 1.5}); + Map filter = new HashMap<>(); + filter.put(3,0); + + IDictionary c = DictionaryFactory.combineFullDictionaries(a, 1, b, 1, filter); + + assertEquals(1, 
c.getNumberOfValues(2)); + assertEquals(c.getValue(0, 0, 2), 1.3, 0.0); + assertEquals(c.getValue(0, 1, 2), 1.5, 0.0); + // assertEquals(c.getValue(1, 0, 2), 1.3, 0.0); + // assertEquals(c.getValue(1, 1, 2), 1.4, 0.0); + // assertEquals(c.getValue(2, 0, 2), 1.2, 0.0); + // assertEquals(c.getValue(2, 1, 2), 1.5, 0.0); + // assertEquals(c.getValue(3, 0, 2), 1.3, 0.0); + // assertEquals(c.getValue(3, 1, 2), 1.5, 0.0); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + + @Test + public void twoBothSidesFilter3() { + try { + IDictionary a = Dictionary.create(new double[] {1.2, 1.3}); + IDictionary b = Dictionary.create(new double[] {1.4, 1.5}); + Map filter = new HashMap<>(); + filter.put(3,0); + filter.put(1,1); + + IDictionary c = DictionaryFactory.combineFullDictionaries(a, 1, b, 1, filter); + + assertEquals(2, c.getNumberOfValues(2)); + assertEquals(c.getValue(0, 0, 2), 1.3, 0.0); + assertEquals(c.getValue(0, 1, 2), 1.5, 0.0); + assertEquals(c.getValue(1, 0, 2), 1.3, 0.0); + assertEquals(c.getValue(1, 1, 2), 1.4, 0.0); + // assertEquals(c.getValue(2, 0, 2), 1.2, 0.0); + // assertEquals(c.getValue(2, 1, 2), 1.5, 0.0); + // assertEquals(c.getValue(3, 0, 2), 1.3, 0.0); + // assertEquals(c.getValue(3, 1, 2), 1.5, 0.0); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + @Test public void sparseSparse() { try { @@ -395,51 +532,51 @@ public void combineNotImplementedSparse6() { DictionaryFactory.combineDictionariesSparse(m, s); } - @Test - public void sparseSparseConst1() { - try { - IDictionary a = Dictionary.create(new double[] {3, 2, 7, 8}); - // IDictionary b = Dictionary.create(new double[] {4, 4, 9, 5}); - - double[] bd = new double[] {0, 2}; - - IDictionary c = DictionaryFactory.combineSparseConstSparseRet(a, 2, bd); - MatrixBlock ret = c.getMBDict(2).getMatrixBlock(); - - MatrixBlock exp = new MatrixBlock(2, 4, new double[] {// - 3, 2, 0, 2, // - 7, 8, 0, 2,}); - 
TestUtils.compareMatricesBitAvgDistance(ret, exp, 0, 0); - } - catch(Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } - - @Test - public void sparseSparseConst2() { - try { - IDictionary a = Dictionary.create(new double[] {3, 2, 7, 8}); - // IDictionary b = Dictionary.create(new double[] {4, 4, 9, 5}); - - double[] bd = new double[] {0, 2}; - - IDictionary c = DictionaryFactory.combineSparseConstSparseRet(a, 1, bd); - MatrixBlock ret = c.getMBDict(2).getMatrixBlock(); - - MatrixBlock exp = new MatrixBlock(2, 3, new double[] {// - 3, 0, 2, // - 2, 0, 2, // - 7, 0, 2, // - 8, 0, 2,}); - TestUtils.compareMatricesBitAvgDistance(ret, exp, 0, 0); - } - catch(Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } - } + // @Test + // public void sparseSparseConst1() { + // try { + // IDictionary a = Dictionary.create(new double[] {3, 2, 7, 8}); + // // IDictionary b = Dictionary.create(new double[] {4, 4, 9, 5}); + + // double[] bd = new double[] {0, 2}; + + // IDictionary c = DictionaryFactory.combineSparseConstSparseRet(a, 2, bd); + // MatrixBlock ret = c.getMBDict(2).getMatrixBlock(); + + // MatrixBlock exp = new MatrixBlock(2, 4, new double[] {// + // 3, 2, 0, 2, // + // 7, 8, 0, 2,}); + // TestUtils.compareMatricesBitAvgDistance(ret, exp, 0, 0); + // } + // catch(Exception e) { + // e.printStackTrace(); + // fail(e.getMessage()); + // } + // } + + // @Test + // public void sparseSparseConst2() { + // try { + // IDictionary a = Dictionary.create(new double[] {3, 2, 7, 8}); + // // IDictionary b = Dictionary.create(new double[] {4, 4, 9, 5}); + + // double[] bd = new double[] {0, 2}; + + // IDictionary c = DictionaryFactory.combineSparseConstSparseRet(a, 1, bd); + // MatrixBlock ret = c.getMBDict(2).getMatrixBlock(); + + // MatrixBlock exp = new MatrixBlock(4, 3, new double[] {// + // 3, 0, 2, // + // 2, 0, 2, // + // 7, 0, 2, // + // 8, 0, 2,}); + // TestUtils.compareMatricesBitAvgDistance(exp, ret, 0, 0); + // } + // catch(Exception e) 
{ + // e.printStackTrace(); + // fail(e.getMessage()); + // } + // } @Test public void testEmpty() { diff --git a/src/test/java/org/apache/sysds/test/component/compress/dictionary/CustomDictionaryTest.java b/src/test/java/org/apache/sysds/test/component/compress/dictionary/CustomDictionaryTest.java index 8dd8a1165bf..bfacec50e16 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/dictionary/CustomDictionaryTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/dictionary/CustomDictionaryTest.java @@ -19,16 +19,27 @@ package org.apache.sysds.test.component.compress.dictionary; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import java.util.Arrays; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.DMLCompressionException; +import org.apache.sysds.runtime.compress.bitmap.ABitmap; +import org.apache.sysds.runtime.compress.bitmap.BitmapEncoder; import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; +import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; +import org.apache.sysds.runtime.compress.utils.DblArray; +import org.apache.sysds.runtime.compress.utils.DblArrayCountHashMap; +import org.apache.sysds.runtime.compress.utils.DoubleCountHashMap; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.junit.Test; @@ -153,4 +164,312 @@ public void createZeroRowMatrixBlock() { MatrixBlockDictionary.create(new MatrixBlock(0, 10, 10.0)); } + @Test + public void bitMapConstructor() { 
+ MatrixBlock mb = new MatrixBlock(10, 10, 1.0); + mb.set(5, 5, 2.0); + mb.set(7, 5, 2.0); + final ABitmap ubm = BitmapEncoder.extractBitmap(ColIndexFactory.create(10), mb, true, 1, true); + final double[] defaultTuple = new double[10]; + + IDictionary dict = DictionaryFactory.create(ubm, 0, defaultTuple, 1.0, ubm.getNumZeros() > 0); + assertEquals(dict, Dictionary.create(new double[] {// + 1, 1, 1, 1, 1, 2, 1, 2, 1, 1})); + assertTrue(Arrays.equals(defaultTuple, new double[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1})); + } + + @Test + public void bitMapConstructor2() { + MatrixBlock mb = new MatrixBlock(10, 10, 1.0); + mb.set(5, 5, 2.0); + mb.set(7, 7, 2.0); + final ABitmap ubm = BitmapEncoder.extractBitmap(ColIndexFactory.create(10), mb, true, 1, true); + final double[] defaultTuple = new double[10]; + + IDictionary dict = DictionaryFactory.create(ubm, 0, defaultTuple, 1.0, ubm.getNumZeros() > 0); + assertEquals(dict, Dictionary.create(new double[] {// + 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, // + 1, 1, 1, 1, 1, 2, 1, 1, 1, 1})); + assertTrue(Arrays.equals(defaultTuple, new double[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1})); + } + + @Test + public void bitMapConstructor3() { + MatrixBlock mb = new MatrixBlock(10, 10, 1.0); + mb.set(5, 5, 2.0); + mb.set(7, 7, 2.0); + final ABitmap ubm = BitmapEncoder.extractBitmap(ColIndexFactory.create(10), mb, true, 1, true); + final double[] defaultTuple = new double[10]; + + IDictionary dict = DictionaryFactory.create(ubm, 1, defaultTuple, 1.0, ubm.getNumZeros() > 0); + assertEquals(dict, Dictionary.create(new double[] {// + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // + 1, 1, 1, 1, 1, 2, 1, 1, 1, 1})); + assertTrue(Arrays.equals(defaultTuple, new double[] {1, 1, 1, 1, 1, 1, 1, 2, 1, 1})); + } + + @Test + public void bitMapConstructor4_Sparse() { + MatrixBlock mb = new MatrixBlock(10, 10, 0.0); + mb.set(5, 5, 2.0); + mb.set(7, 7, 2.0); + mb.set(8, 8, 2.0); + final ABitmap ubm = BitmapEncoder.extractBitmap(ColIndexFactory.create(10), mb, true, 1, true); + final 
double[] defaultTuple = new double[10]; + + IDictionary dict = DictionaryFactory.create(ubm, 1, defaultTuple, 0.1, ubm.getNumZeros() > 0); + assertEquals(dict, Dictionary.create(new double[] {// + 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + })); + assertTrue(Arrays.equals(defaultTuple, new double[] {0, 0, 0, 0, 0, 0, 0, 2, 0, 0})); + } + + @Test + public void bitMapConstructorVector() { + MatrixBlock mb = new MatrixBlock(10, 1, 1.0); // 10x1 vector: column 0 is the only valid column + mb.set(5, 0, 2.0); + mb.set(7, 0, 3.0); + final ABitmap ubm = BitmapEncoder.extractBitmap(ColIndexFactory.create(1), mb, false, 1, true); + final double[] defaultTuple = new double[1]; + + IDictionary dict = DictionaryFactory.create(ubm, 0, defaultTuple, 1.0, ubm.getNumZeros() > 0); + assertEquals(dict, Dictionary.create(new double[] {// + 2, 3})); + assertTrue(Arrays.equals(defaultTuple, new double[] {1})); + } + + @Test + public void bitMapConstructorVector2() { + MatrixBlock mb = new MatrixBlock(10, 1, 1.0); // 10x1 vector: column 0 is the only valid column + mb.set(5, 0, 2.0); + mb.set(7, 0, 3.0); + mb.set(8, 0, 3.0); + final ABitmap ubm = BitmapEncoder.extractBitmap(ColIndexFactory.create(1), mb, false, 1, true); + final double[] defaultTuple = new double[1]; + + IDictionary dict = DictionaryFactory.create(ubm, 0, defaultTuple, 1.0, ubm.getNumZeros() > 0); + assertEquals(dict, Dictionary.create(new double[] {// + 3, 2})); + assertTrue(Arrays.equals(defaultTuple, new double[] {1})); + } + + @Test + public void bitMapConstructorVector3() { + MatrixBlock mb = new MatrixBlock(10, 1, 1.0); // 10x1 vector: column 0 is the only valid column + mb.set(5, 0, 2.0); + mb.set(7, 0, 3.0); + mb.set(8, 0, 3.0); + final ABitmap ubm = BitmapEncoder.extractBitmap(ColIndexFactory.create(1), mb, false, 1, true); + final double[] defaultTuple = new double[1]; + + IDictionary dict = DictionaryFactory.create(ubm, 0, defaultTuple, 1.0, true); + assertEquals(dict, Dictionary.create(new double[] {// + 3, 2, 0.0})); + assertTrue(Arrays.equals(defaultTuple, new double[] {1})); + } + + @Test +
public void bitMapConstruct() { + MatrixBlock mb = new MatrixBlock(10, 1, 1.0); + mb.set(5, 0, 2.0); + mb.set(7, 0, 3.0); + mb.set(8, 0, 3.0); + final ABitmap ubm = BitmapEncoder.extractBitmap(ColIndexFactory.create(1), mb, false, 1, true); + + IDictionary dict = DictionaryFactory.create(ubm); + assertEquals(dict, Dictionary.create(new double[] {// + 1, 3, 2})); + } + + @Test + public void bitMapConstruct2() { + MatrixBlock mb = new MatrixBlock(10, 1, 1.0); + mb.set(5, 0, 2.0); + mb.set(7, 0, 3.0); + mb.set(8, 0, 2.0); + final ABitmap ubm = BitmapEncoder.extractBitmap(ColIndexFactory.create(1), mb, false, 1, true); + + IDictionary dict = DictionaryFactory.create(ubm); + assertEquals(dict, Dictionary.create(new double[] {// + 1, 2, 3})); + } + + @Test + public void bitMapConstruct3() { + MatrixBlock mb = new MatrixBlock(10, 2, 1.0); + mb.set(5, 0, 2.0); + mb.set(7, 0, 3.0); + mb.set(8, 0, 2.0); + final ABitmap ubm = BitmapEncoder.extractBitmap(ColIndexFactory.create(2), mb, false, 1, true); + + IDictionary dict = DictionaryFactory.create(ubm); + assertEquals(dict, Dictionary.create(new double[] {// + 1, 1, // + 2, 1, // + 3, 1,// + })); + } + + @Test + public void bitMapConstruct4Sparse() { + MatrixBlock mb = new MatrixBlock(10, 5, 0.0); + mb.set(5, 0, 2.0); + mb.set(7, 0, 3.0); + mb.set(8, 0, 2.0); + final ABitmap ubm = BitmapEncoder.extractBitmap(ColIndexFactory.create(5), mb, false, 1, true); + + IDictionary dict = DictionaryFactory.create(ubm, 0.1); + assertEquals(dict, Dictionary.create(new double[] {// + 2, 0, 0, 0, 0, // + 3, 0, 0, 0, 0,// + })); + } + + @Test + public void bitMapConstruct4Sparse2() { + MatrixBlock mb = new MatrixBlock(10, 3, 0.0); + mb.set(5, 0, 2.0); + mb.set(7, 0, 3.0); + mb.set(8, 0, 2.0); + final ABitmap ubm = BitmapEncoder.extractBitmap(ColIndexFactory.create(3), mb, false, 1, true); + + IDictionary dict = DictionaryFactory.create(ubm, 0.1); + assertEquals(dict, Dictionary.create(new double[] {// + 2, 0, 0, // + 3, 0, 0,// + })); + } + 
+ @Test + public void getInMemorySize() { + long s = DictionaryFactory.getInMemorySize(100, 100, 1.0, false); + long s2 = Dictionary.getInMemorySize(100 * 100); + assertTrue(s <= s2); + } + + @Test + public void getInMemorySize2() { + long s = DictionaryFactory.getInMemorySize(100, 100, 0.1, false); + long s2 = MatrixBlockDictionary.getInMemorySize(100, 100, 0.1); + assertTrue(s <= s2); + } + + @Test + public void getInMemorySize3() { + long s = DictionaryFactory.getInMemorySize(100, 100, 1.0, true); + long s2 = Dictionary.getInMemorySize(100 * 100); + assertTrue(s <= s2); + } + + @Test + public void getInMemorySize4() { + long s = DictionaryFactory.getInMemorySize(100, 1, 1.0, true); + long s2 = Dictionary.getInMemorySize(100); + assertTrue(s <= s2); + } + + @Test + public void getInMemorySize5() { + long s = DictionaryFactory.getInMemorySize(100, 1, 1.0, false); + long s2 = Dictionary.getInMemorySize(100); + assertTrue(s <= s2); + } + + @Test + public void createDblArrayCount() { + + DblArrayCountHashMap m = new DblArrayCountHashMap(3); + m.increment(new DblArray(new double[] {1, 2, 3})); + m.increment(new DblArray(new double[] {1, 2, 4})); + m.increment(new DblArray(new double[] {1, 2, 3})); + IDictionary d = DictionaryFactory.create(m, 3, false, 1.0); + + assertEquals(Dictionary.create(new double[] {// + 1, 2, 3, // + 1, 2, 4,// + }), d); + } + + @Test + public void createDblArrayCount2() { + + DblArrayCountHashMap m = new DblArrayCountHashMap(3); + m.increment(new DblArray(new double[] {1, 2, 3})); + m.increment(new DblArray(new double[] {1, 2, 4})); + m.increment(new DblArray(new double[] {1, 2, 5})); + IDictionary d = DictionaryFactory.create(m, 3, false, 1.0); + + assertEquals(Dictionary.create(new double[] {// + 1, 2, 3, // + 1, 2, 4, // + 1, 2, 5,// + }), d); + } + + @Test + public void createDblArrayCountSparse() { + + DblArrayCountHashMap m = new DblArrayCountHashMap(3); + m.increment(new DblArray(new double[] {1, 2, 3})); + m.increment(new DblArray(new 
double[] {1, 2, 4})); + m.increment(new DblArray(new double[] {1, 2, 5})); + IDictionary d = DictionaryFactory.create(m, 3, false, 0.2); + + assertEquals(Dictionary.create(new double[] {// + 1, 2, 3, // + 1, 2, 4, // + 1, 2, 5,// + }), d); + } + + @Test + public void createDblArrayCountSparse2() { + + DblArrayCountHashMap m = new DblArrayCountHashMap(3); + m.increment(new DblArray(new double[] {1, 2, 3, 1, 1})); + m.increment(new DblArray(new double[] {1, 2, 4, 1, 1})); + m.increment(new DblArray(new double[] {1, 2, 5, 1, 1})); + IDictionary d = DictionaryFactory.create(m, 5, false, 0.2); + + assertEquals(Dictionary.create(new double[] {// + 1, 2, 3, 1, 1, // + 1, 2, 4, 1, 1, // + 1, 2, 5, 1, 1,// + }), d); + } + + @Test + public void createDblArrayCountSparse3() { + + DblArrayCountHashMap m = new DblArrayCountHashMap(3); + m.increment(new DblArray(new double[] {0, 2, 3, 0, 0})); + m.increment(new DblArray(new double[] {0, 2, 4, 0, 0})); + m.increment(new DblArray(new double[] {1, 2, 5, 1, 1})); + m.increment(new DblArray(new double[] {1, 2, 5, 1, 1})); + m.increment(new DblArray(new double[] {1, 2, 5, 1, 1})); + IDictionary d = DictionaryFactory.create(m, 5, true, 0.2); + + assertEquals(Dictionary.create(new double[] {// + 0, 2, 3, 0, 0, // + 0, 2, 4, 0, 0, // + 1, 2, 5, 1, 1, // + 0, 0, 0, 0, 0}), d); + } + + @Test + public void createDoubleCountHashMap() { + + DoubleCountHashMap m = new DoubleCountHashMap(3); + m.increment(1); + m.increment(2); + m.increment(4); + m.increment(6); + m.increment(1); + IDictionary d = DictionaryFactory.create(m); + + assertEquals(Dictionary.create(new double[] {// + 1, 2, 4, 6,}), d); + } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java b/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java index 9307930f1d2..587b724c77f 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java +++ 
b/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java @@ -24,6 +24,11 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -33,6 +38,7 @@ import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; @@ -95,6 +101,14 @@ public static Collection data() { 0, 0, 1, 0, // 0, 0, 0, 1,// }), 4, 4}); + + tests.add(new Object[] {new IdentityDictionary(4).sliceOutColumnRange(1, 4, 4), // + Dictionary.create(new double[] {// + 0, 0, 0, // + 1, 0, 0, // + 0, 1, 0, // + 0, 0, 1,// + }), 4, 3}); tests.add(new Object[] {new IdentityDictionary(4, true), // Dictionary.create(new double[] {// 1, 0, 0, 0, // @@ -104,6 +118,15 @@ public static Collection data() { 0, 0, 0, 0}), 5, 4}); + tests.add(new Object[] {new IdentityDictionary(4, true).sliceOutColumnRange(1, 4, 4), // + Dictionary.create(new double[] {// + 0, 0, 0, // + 1, 0, 0, // + 0, 1, 0, // + 0, 0, 1, // + 0, 0, 0}), + 5, 3}); + create(tests, 30, 300, 0.2); } catch(Exception e) { @@ -439,7 +462,13 @@ public void contains1WithReferenceMinus1() { @Test public void equalsEl() { - assertEquals(a, b); + try { + assertEquals(a, b); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } } @Test @@ -609,11 +638,16 @@ public void 
containsValueWithReference(double value, double[] reference) { private static void compare(IDictionary a, IDictionary b, int nRow, int nCol) { try { - - String errorM = a.getClass().getSimpleName() + " " + b.getClass().getSimpleName(); - for(int i = 0; i < nRow; i++) - for(int j = 0; j < nCol; j++) - assertEquals(errorM, a.getValue(i, j, nCol), b.getValue(i, j, nCol), 0.0001); + if(a == null && b == null) + return; + else if(a == null || b == null) + fail("both outputs should be null if one is: \n" + a + " \n " + b); + else { + String errorM = a.getClass().getSimpleName() + " " + b.getClass().getSimpleName(); + for(int i = 0; i < nRow; i++) + for(int j = 0; j < nCol; j++) + assertEquals(errorM, a.getValue(i, j, nCol), b.getValue(i, j, nCol), 0.0001); + } } catch(Exception e) { e.printStackTrace(); @@ -677,4 +711,54 @@ private static double[] getReference(int nCol, int seed, double min, double max) reference[i] = r.nextDouble() * diff - min; return reference; } + + @Test + public void testSerialization() { + try { + // Serialize out + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + DataOutputStream fos = new DataOutputStream(bos); + a.write(fos); + + // Serialize in + ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray()); + DataInputStream fis = new DataInputStream(bis); + + IDictionary n = DictionaryFactory.read(fis); + + compare(a, n, nRow, nCol); + } + catch(IOException e) { + throw new RuntimeException("Error in io", e); + } + catch(Exception e) { + e.printStackTrace(); + throw e; + } + } + + @Test + public void testSerializationB() { + try { + // Serialize out + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + DataOutputStream fos = new DataOutputStream(bos); + b.write(fos); + + // Serialize in + ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray()); + DataInputStream fis = new DataInputStream(bis); + + IDictionary n = DictionaryFactory.read(fis); + + compare(b, n, nRow, nCol); + } + catch(IOException e) { 
+ throw new RuntimeException("Error in io", e); + } + catch(Exception e) { + e.printStackTrace(); + throw e; + } + } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/dictionary/PlaceHolderDictTest.java b/src/test/java/org/apache/sysds/test/component/compress/dictionary/PlaceHolderDictTest.java new file mode 100644 index 00000000000..88e5d8adcc3 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/compress/dictionary/PlaceHolderDictTest.java @@ -0,0 +1,520 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.component.compress.dictionary; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; + +import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.PlaceHolderDict; +import org.junit.Test; + +public class PlaceHolderDictTest { + + PlaceHolderDict d = new PlaceHolderDict(1); + + @Test(expected = Exception.class) + public void getValues() { + d.getValues(); + } + + @Test(expected = Exception.class) + public void getValue() { + d.getValue(1); + } + + @Test(expected = Exception.class) + public void getValue2() { + d.getValue(1, 2, 3); + } + + @Test + public void getInMemorySize() { + assertEquals(16 + 4, d.getInMemorySize()); + } + + @Test(expected = Exception.class) + public void aggregate() { + d.aggregate(1, null); + } + + @Test(expected = Exception.class) + public void aggregateWithReference() { + d.aggregateWithReference(1, null, null, true); + } + + @Test(expected = Exception.class) + public void aggregateRows() { + d.aggregateRows(null, 1); + } + + @Test(expected = Exception.class) + public void aggregateRowsWithDefault() { + d.aggregateRowsWithDefault(null, null); + } + + @Test(expected = Exception.class) + public void aggregateRowsWithReference() { + d.aggregateRowsWithReference(null, null); + } + + @Test(expected = Exception.class) + public void aggregateCols() { + d.aggregateCols(null, null, null); + } + + @Test(expected = Exception.class) + public void aggregateColsWithReference() { + d.aggregateColsWithReference(null, null, null, null, false); + } + + @Test(expected = 
Exception.class) + public void applyScalarOp() { + d.applyScalarOp(null); + } + + @Test(expected = Exception.class) + public void applyScalarOpAndAppend() { + d.applyScalarOpAndAppend(null, 1, 1); + } + + @Test(expected = Exception.class) + public void applyUnaryOp() { + d.applyUnaryOp(null); + } + + @Test(expected = Exception.class) + public void applyUnaryOpAndAppend() { + d.applyUnaryOpAndAppend(null, 1, 1); + } + + @Test(expected = Exception.class) + public void applyScalarOpWithReference() { + d.applyScalarOpWithReference(null, null, null); + } + + @Test(expected = Exception.class) + public void applyUnaryOpWithReference() { + d.applyUnaryOpWithReference(null, null, null); + } + + @Test(expected = Exception.class) + public void binOpLeft() { + d.binOpLeft(null, null, null); + } + + @Test(expected = Exception.class) + public void binOpLeftAndAppend() { + d.binOpLeftAndAppend(null, null, null); + } + + @Test(expected = Exception.class) + public void binOpLeftWithReference() { + d.binOpLeftWithReference(null, null, null, null, null); + } + + @Test(expected = Exception.class) + public void binOpRight() { + d.binOpRight(null, null, null); + } + + @Test(expected = Exception.class) + public void binOpRightAndAppend() { + d.binOpRightAndAppend(null, null, null); + } + + @Test(expected = Exception.class) + public void binOpRight2() { + d.binOpRight(null, null); + } + + @Test(expected = Exception.class) + public void binOpRightWithReference() { + d.binOpRightWithReference(null, null, null, null, null); + } + + @Test + public void getExactSizeOnDisk() { + assertEquals(5, d.getExactSizeOnDisk()); + } + + @Test(expected = Exception.class) + public void getDictType() { + d.getDictType(); + } + + @Test + public void getNumberOfValues() { + assertEquals(1, d.getNumberOfValues(1)); + } + + @Test(expected = Exception.class) + public void sumAllRowsToDouble() { + d.sumAllRowsToDouble(3); + } + + @Test(expected = Exception.class) + public void sumAllRowsToDoubleWithDefault() { + 
d.sumAllRowsToDoubleWithDefault(null); + } + + @Test(expected = Exception.class) + public void sumAllRowsToDoubleWithReference() { + d.sumAllRowsToDoubleWithReference(null); + } + + @Test(expected = Exception.class) + public void sumAllRowsToDoubleSq() { + d.sumAllRowsToDoubleSq(1); + } + + @Test(expected = Exception.class) + public void sumAllRowsToDoubleSqWithDefault() { + d.sumAllRowsToDoubleSqWithDefault(null); + } + + @Test(expected = Exception.class) + public void sumAllRowsToDoubleSqWithReference() { + d.sumAllRowsToDoubleSqWithReference(null); + } + + @Test(expected = Exception.class) + public void productAllRowsToDouble() { + d.productAllRowsToDouble(22); + } + + @Test(expected = Exception.class) + public void productAllRowsToDoubleWithDefault() { + d.productAllRowsToDoubleWithDefault(null); + } + + @Test(expected = Exception.class) + public void productAllRowsToDoubleWithReference() { + d.productAllRowsToDoubleWithReference(null); + } + + @Test(expected = Exception.class) + public void colSum() { + d.colSum(null, null, null); + } + + @Test(expected = Exception.class) + public void colSumSq() { + d.colSumSq(null, null, null); + } + + @Test(expected = Exception.class) + public void colSumSqWithReference() { + d.colSumSqWithReference(null, null, null, null); + } + + @Test(expected = Exception.class) + public void sum() { + d.sum(null, 1); + } + + @Test(expected = Exception.class) + public void sumSq() { + d.sumSq(null, 1); + } + + @Test(expected = Exception.class) + public void sumSqWithReference() { + d.sumSqWithReference(null, null); + } + + @Test + public void getString() { + assertEquals("", d.getString(1)); + } + + @Test(expected = Exception.class) + public void sliceOutColumnRange() { + d.sliceOutColumnRange(1, 1, 1); + } + + @Test(expected = Exception.class) + public void containsValue() { + d.containsValue(1); + } + + @Test(expected = Exception.class) + public void containsValueWithReference() { + d.containsValueWithReference(1, null); + } + + @Test 
+ public void getNumberNonZeros() { + assertEquals(-1, d.getNumberNonZeros(null, 1)); + } + + @Test(expected = Exception.class) + public void getNumberNonZerosWithReference() { + d.getNumberNonZerosWithReference(null, null, 1); + } + + @Test(expected = Exception.class) + public void addToEntry() { + d.addToEntry(null, 1, 1, 1); + } + + @Test(expected = Exception.class) + public void addToEntry2() { + d.addToEntry(null, 1, 1, 1, 2); + } + + @Test(expected = Exception.class) + public void addToEntryVectorized() { + d.addToEntryVectorized(null, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1); + } + + @Test(expected = Exception.class) + public void subtractTuple() { + d.subtractTuple(null); + } + + @Test(expected = Exception.class) + public void getMBDict() { + d.getMBDict(1); + } + + @Test(expected = Exception.class) + public void scaleTuples() { + d.scaleTuples(null, 1); + } + + @Test(expected = Exception.class) + public void preaggValuesFromDense() { + d.preaggValuesFromDense(1, null, null, null, 1); + } + + @Test(expected = Exception.class) + public void replace() { + d.replace(1, 1, 1); + } + + @Test(expected = Exception.class) + public void replaceWithReference() { + d.replaceWithReference(1, 1, null); + } + + @Test(expected = Exception.class) + public void product() { + d.product(null, null, 1); + } + + @Test(expected = Exception.class) + public void productWithDefault() { + d.productWithDefault(null, null, null, 1); + } + + @Test(expected = Exception.class) + public void productWithReference() { + d.productWithReference(null, null, null, 1); + } + + @Test(expected = Exception.class) + public void colProduct() { + d.colProduct(null, null, null); + } + + @Test(expected = Exception.class) + public void colProductWithReference() { + d.colProductWithReference(null, null, null, null); + } + + @Test(expected = Exception.class) + public void centralMoment() { + d.centralMoment(null, null, 1); + } + + @Test(expected = Exception.class) + public void centralMoment2() { 
+ d.centralMoment(null, null, null, 1); + } + + @Test(expected = Exception.class) + public void centralMomentWithDefault() { + d.centralMomentWithDefault(null, null, 1, 1); + } + + @Test(expected = Exception.class) + public void centralMomentWithDefault2() { + d.centralMomentWithDefault(null, null, null, 1, 1); + } + + @Test(expected = Exception.class) + public void centralMomentWithReference() { + d.centralMomentWithReference(null, null, 1, 1); + } + + @Test(expected = Exception.class) + public void centralMomentWithReference2() { + d.centralMomentWithReference(null, null, null, 1, 1); + } + + @Test(expected = Exception.class) + public void rexpandCols() { + d.rexpandCols(1, false, false, 1); + } + + @Test(expected = Exception.class) + public void rexpandColsWithReference() { + d.rexpandColsWithReference(1, false, false, 1); + } + + @Test(expected = Exception.class) + public void getSparsity() { + d.getSparsity(); + } + + @Test(expected = Exception.class) + public void multiplyScalar() { + d.multiplyScalar(1, null, 1, 1, null); + } + + @Test(expected = Exception.class) + public void TSMMWithScaling() { + d.TSMMWithScaling(null, null, null, null); + } + + @Test(expected = Exception.class) + public void MMDict() { + d.MMDict(null, null, null, null); + } + + @Test(expected = Exception.class) + public void MMDictDense() { + d.MMDictDense(null, null, null, null); + } + + @Test(expected = Exception.class) + public void MMDictSparse() { + d.MMDictSparse(null, null, null, null); + } + + @Test(expected = Exception.class) + public void TSMMToUpperTriangle() { + d.TSMMToUpperTriangle(null, null, null, null); + } + + @Test(expected = Exception.class) + public void TSMMToUpperTriangleDense() { + d.TSMMToUpperTriangleDense(null, null, null, null); + } + + @Test(expected = Exception.class) + public void TSMMToUpperTriangleSparse() { + d.TSMMToUpperTriangleSparse(null, null, null, null); + } + + @Test(expected = Exception.class) + public void TSMMToUpperTriangleScaling() { + 
d.TSMMToUpperTriangleScaling(null, null, null, null, null); + } + + @Test(expected = Exception.class) + public void TSMMToUpperTriangleDenseScaling() { + d.TSMMToUpperTriangleDenseScaling(null, null, null, null, null); + } + + @Test(expected = Exception.class) + public void TSMMToUpperTriangleSparseScaling() { + d.TSMMToUpperTriangleSparseScaling(null, null, null, null, null); + } + + @Test(expected = Exception.class) + public void cbind() { + d.cbind(null, 1); + } + + @Test + public void equals() { + assertTrue(!d.equals(Dictionary.create(new double[2]))); + assertTrue(d.equals(new PlaceHolderDict(2))); + } + + @Test + public void equals2() { + assertTrue(!d.equals(new double[3])); + } + + @Test(expected = Exception.class) + public void reorder() { + d.reorder(null); + } + + @Test + public void cloneTest() { + assertTrue(d.equals(d.clone())); + } + + @Test(expected = Exception.class) + public void MMDictScaling() { + d.MMDictScaling(null, null, null, null, null); + } + + @Test(expected = Exception.class) + public void MMDictScalingDense() { + d.MMDictScalingDense(null, null, null, null, null); + } + + @Test(expected = Exception.class) + public void MMDictScalingSparse() { + d.MMDictScalingSparse(null, null, null, null, null); + } + + @Test(expected = Exception.class) + public void put() { + d.put(null, 1, 1, 1, null); + } + + @Test + public void testSerialization() { + try { + // Serialize out + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + DataOutputStream fos = new DataOutputStream(bos); + d.write(fos); + + // Serialize in + ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray()); + DataInputStream fis = new DataInputStream(bis); + + IDictionary d2 = DictionaryFactory.read(fis); + assertTrue(d.equals(d2)); + } + catch(IOException e) { + throw new RuntimeException("Error in io", e); + } + catch(Exception e) { + e.printStackTrace(); + throw e; + } + } +} diff --git 
a/src/test/java/org/apache/sysds/test/component/compress/functional/LinearRegNegativeTest.java b/src/test/java/org/apache/sysds/test/component/compress/functional/LinearRegNegativeTest.java new file mode 100644 index 00000000000..bb030b352eb --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/compress/functional/LinearRegNegativeTest.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.component.compress.functional; + +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +import org.apache.sysds.runtime.compress.colgroup.functional.LinearRegression; +import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; +import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.junit.Test; + +public class LinearRegNegativeTest { + + @Test(expected = Exception.class) + public void invalidRows() { + LinearRegression.regressMatrixBlock(new MatrixBlock(-1, -1, 132), null, false); + } + + @Test(expected = Exception.class) + public void invalidCols() { + + IColIndex spy = spy(ColIndexFactory.create(10)); + when(spy.size()).thenReturn(-1); + + LinearRegression.regressMatrixBlock(new MatrixBlock(10, 10, 132), spy, false); + } + + +} diff --git a/src/test/java/org/apache/sysds/test/component/compress/functional/LinearRegressionTests.java b/src/test/java/org/apache/sysds/test/component/compress/functional/LinearRegressionTests.java index f34a6c2bc8f..c321e96d514 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/functional/LinearRegressionTests.java +++ b/src/test/java/org/apache/sysds/test/component/compress/functional/LinearRegressionTests.java @@ -126,5 +126,4 @@ public void testLineratRegressionEquivalentTransposed() { assertEquals(expectedException.getMessage(), e.getMessage()); } } - } diff --git a/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibLMMTest.java b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibLMMTest.java new file mode 100644 index 00000000000..905ef59de25 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibLMMTest.java @@ -0,0 +1,435 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.compress.lib;
+
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Random;
+
+import org.apache.commons.lang3.NotImplementedException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
+import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
+import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary;
+import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
+import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
+import org.apache.sysds.runtime.compress.lib.CLALibLeftMultBy;
+import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult;
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.test.component.compress.mapping.MappingTestUtil;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+/** Parameterized tests comparing compressed left matrix multiplication with the uncompressed LibMatrixMult result. */
+@RunWith(value = Parameterized.class)
+public class CLALibLMMTest {
+	protected static final Log LOG = LogFactory.getLog(CLALibLMMTest.class.getName());
+
+	@Parameterized.Parameter(0)
+	public String s;
+	@Parameterized.Parameter(1)
+	public MatrixBlock mb;
+	@Parameterized.Parameter(2)
+	public CompressedMatrixBlock cmb;
+	@Parameterized.Parameter(3)
+	public MatrixBlock mb2;
+
+	@Parameterized.Parameter(4)
+	public MatrixBlock tcmb2;
+
+	@BeforeClass
+	public static void setup() {
+		Thread.currentThread().setName("main_test_" + Thread.currentThread().getId());
+	}
+	/** Builds every (input, left-hand matrix) parameter set; fails the run if construction throws. */
+	@Parameters(name = "{0}")
+	public static Collection<Object[]> data() {
+		List<Object[]> tests = new ArrayList<>();
+		MatrixBlock mb;
+		CompressedMatrixBlock cmb;
+		List<AColGroup> gs;
+
+
+		// HACK
+		// because we have a sideways way of testing some morphing column groups
+		// a rowSum operation is available in the LMM library
+		// here we side call it to enable testing all of CLALibLeftMult from this one file.
+		try{
+			CLALibLeftMultBy.rowSum(new MatrixBlock(), 0, 0, 0, 0);
+		}
+		catch(Exception e){
+			// presumably the unallocated block above is rejected; the valid call below then runs as well.
+			CLALibLeftMultBy.rowSum(new MatrixBlock(10,10,1.0), 0, 1, 0, 1);
+		}
+
+		try {
+			mb = TestUtils.generateTestMatrixBlock(200, 50, -10, 10, 1.0, 32);
+			mb = TestUtils.round(mb);
+			cmb = (CompressedMatrixBlock) CompressedMatrixBlockFactory.compress(mb, 1).getLeft();
+			genTests(tests, mb, cmb, "Normal");
+
+			cmb = (CompressedMatrixBlock) cmb.scalarOperations(new RightScalarOperator(Plus.getPlusFnObject(), 2), null);
+			mb = cmb.decompress();
+			genTests(tests, mb, cmb, "NormalP2");
+
+			cmb = new CompressedMatrixBlock(mb.getNumRows(), mb.getNumColumns());
+			cmb.allocateColGroup(ColGroupUncompressed.create(mb));
+			genTests(tests, mb, cmb, "UncompressAbleGroup");
+
+			mb = TestUtils.generateTestMatrixBlock(200, 10, -5, 5, 1.0, 32);
+			mb = TestUtils.round(mb);
+			gs = new ArrayList<>();
+			gs.add(ColGroupUncompressed.create(mb));
+			gs.add(ColGroupUncompressed.create(mb));
+			cmb = new CompressedMatrixBlock(mb.getNumRows(), mb.getNumColumns(), 1000, true, gs);
+			mb = cmb.decompress();
+			genTests(tests, mb, cmb, "UncompressAbleGroup2");
+
+			MatrixBlock mb2 = new MatrixBlock(10, 10, 132.0);
+			cmb = CompressedMatrixBlockFactory.createConstant(10, 10, 132.0);
+			genTests(tests, mb2, cmb, "Const");
+
+			gs = new ArrayList<>();
+			gs.add(ColGroupConst.create(ColIndexFactory.create(10), 100.0));
+			gs.add(ColGroupConst.create(ColIndexFactory.create(10), 32.0));
+			cmb = new CompressedMatrixBlock(10, 10, 100, true, gs);
+			genTests(tests, cmb.getUncompressed(), cmb, "OverlappingConst");
+
+			gs = new ArrayList<>();
+			gs.add(ColGroupConst.create(ColIndexFactory.create(10), new double[] {13.0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
+			gs.add(ColGroupConst.create(ColIndexFactory.create(10), new double[] {13.0, 32., 0, 0, 0, 0, 0, 0, 0, 0}));
+			cmb = new CompressedMatrixBlock(10, 10, 100, true, gs);
+			genTests(tests, cmb.getUncompressed(), cmb, "OverlappingSparseConst");
+
+			mb = TestUtils.generateTestMatrixBlock(200, 16, -10, 10, 0.04, 32);
+			mb = TestUtils.round(mb);
+			cmb = (CompressedMatrixBlock) CompressedMatrixBlockFactory.compress(mb, 1).getLeft();
+			genTests(tests, mb, cmb, "Sparse");
+
+			mb = TestUtils.generateTestMatrixBlock(200, 16, -10, 10, 0.2, 32);
+			mb = TestUtils.round(mb);
+			cmb = (CompressedMatrixBlock) CompressedMatrixBlockFactory.compress(mb, 1).getLeft();
+			genTests(tests, mb, cmb, "Sparse2");
+
+			IdentityDictionary id = new IdentityDictionary(10);
+			AMapToData d = MappingTestUtil.createRandomMap(100, 10, new Random(23));
+			AColGroup idg = ColGroupDDC.create(ColIndexFactory.create(10), id, d, null);
+			cmb = new CompressedMatrixBlock(100, 10);
+			cmb.allocateColGroup(idg);
+			mb = CompressedMatrixBlock.getUncompressed(cmb);
+			genTests(tests, mb, cmb, "Identity");
+
+			AColGroup empty = new ColGroupEmpty(ColIndexFactory.create(10));
+			cmb = new CompressedMatrixBlock(100, 10);
+			cmb.allocateColGroup(empty);
+			genTests(tests, new MatrixBlock(100, 10, true), cmb, "Empty"); // reference sized like the 100 x 10 compressed block
+
+		}
+		catch(Exception e) {
+			e.printStackTrace();
+			fail("failed constructing tests");
+		}
+
+		return tests;
+	}
+	/** Adds one parameter set per left-hand variant (dense, sparse, empty, self-transposed, selection) for the input. */
+	private static void genTests(List<Object[]> tests, MatrixBlock mb, MatrixBlock cmb, String version) {
+
+		MatrixBlock tmp;
+		MatrixBlock tcmb;
+
+		final int nRow = cmb.getNumRows();
+		final int nCol = cmb.getNumColumns();
+
+		tmp = TestUtils.generateTestMatrixBlock(nCol, nRow, -10, 10, 0.9, 132);
+		tmp = TestUtils.round(tmp);
+		tcmb = CompressedMatrixBlockFactory.compress(tmp.transpose(), 1).getLeft();
+		tests.add(new Object[] {version + "_dense_full", mb, cmb, tmp, tcmb});
+
+		tmp = TestUtils.generateTestMatrixBlock(1, nRow, -10, 10, 0.9, 32);
+		tmp = TestUtils.round(tmp);
+		tcmb = CompressedMatrixBlockFactory.compress(tmp.transpose(), 1).getLeft();
+		tests.add(new Object[] {version + "_dense_vector", mb, cmb, tmp, tcmb});
+
+		tmp = TestUtils.generateTestMatrixBlock(2, nRow, -10, 10, 0.9, 32);
+		tmp = TestUtils.round(tmp);
+		tcmb = CompressedMatrixBlockFactory.compress(tmp.transpose(), 1).getLeft();
+		tests.add(new Object[] {version + "_dense_two_rows", mb, cmb, tmp, tcmb});
+
+		tmp = TestUtils.generateTestMatrixBlock(3, nRow, -10, 10, 0.1, 314);
+		tmp = TestUtils.round(tmp);
+		tcmb = CompressedMatrixBlockFactory.compress(tmp.transpose(), 1).getLeft();
+		tests.add(new Object[] {version + "_sparse_full", mb, cmb, tmp, tcmb});
+
+		tmp = TestUtils.generateTestMatrixBlock(1, nRow, -10, 10, 0.1, 2323);
+		tmp = TestUtils.round(tmp);
+		tcmb = CompressedMatrixBlockFactory.compress(tmp.transpose(), 1).getLeft();
+		tests.add(new Object[] {version + "_sparse_vector", mb, cmb, tmp, tcmb});
+
+		tmp = TestUtils.generateTestMatrixBlock(2, nRow, -10, 10, 0.1, 2323);
+		tmp = TestUtils.round(tmp);
+		tcmb = CompressedMatrixBlockFactory.compress(tmp.transpose(), 1).getLeft();
+		tests.add(new Object[] {version + "_sparse_two_rows", mb, cmb, tmp, tcmb});
+
+		tmp = new MatrixBlock(1, nRow, 0.0).append(tmp, false);
+		tcmb = CompressedMatrixBlockFactory.compress(tmp.transpose(), 1).getLeft();
+		tests.add(new Object[] {version + "_sparse_empty_row", mb, cmb, tmp, tcmb});
+
+		tests.add(new Object[] {version + "_self_transposed", mb, cmb, mb.transpose(), null});
+
+		tests.add(new Object[] {version + "_empty", mb, cmb, new MatrixBlock(2, nRow, true),
+			new CompressedMatrixBlock(nRow, 2)});
+
+		tcmb = createSelectionMatrix(nRow, 30, false);
+		// decompressed for uncompressed operation
+		tmp = CompressedMatrixBlock.getUncompressed(tcmb);
+		// compress the transposed version of it.
+		tcmb = CompressedMatrixBlockFactory.compress(tmp.transpose(), 1).getLeft();
+		tests.add(new Object[] {version + "_selection", mb, cmb, tmp, tcmb});
+
+		tcmb = createSelectionMatrix(nRow, 30, true);
+		// decompressed for uncompressed operation
+		tmp = CompressedMatrixBlock.getUncompressed(tcmb);
+		// compress the transposed version of it.
+		tcmb = CompressedMatrixBlockFactory.compress(tmp.transpose(), 1).getLeft();
+		tests.add(new Object[] {version + "_selection_with_empty", mb, cmb, tmp, tcmb});
+
+	}
+	/** Creates an nRowLeft x nRow compressed block: one DDC group over an IdentityDictionary with a seeded random map. */
+	private static MatrixBlock createSelectionMatrix(final int nRow, final int nRowLeft, boolean emptyRows) {
+		MatrixBlock tcmb;
+		IdentityDictionary id = new IdentityDictionary(nRow, emptyRows);
+		AMapToData d = MappingTestUtil.createRandomMap(nRowLeft, nRow + (emptyRows ? 1 : 0), new Random(33));
+		AColGroup idg = ColGroupDDC.create(ColIndexFactory.create(nRow), id, d, null);
+		tcmb = new CompressedMatrixBlock(nRowLeft, nRow);
+		((CompressedMatrixBlock) tcmb).allocateColGroup(idg);
+		return tcmb;
+	}
+
+	@Test
+	public void testMultiplicationSingleThread() {
+		try {
+			exec(mb, cmb, mb2, 1);
+		}
+		catch(Exception e) {
+			if(!causeWasNotImplemented(e)) {
+				e.printStackTrace();
+				fail(e.getMessage());
+			}
+		}
+	}
+
+	@Test
+	public void testMultiplicationParallel() {
+		try {
+			exec(mb, cmb, mb2, 4);
+		}
+		catch(Exception e) {
+			if(!causeWasNotImplemented(e)) {
+				e.printStackTrace();
+				fail(e.getMessage());
+			}
+		}
+	}
+
+	@Test
+	public void testMultiplicationThatNonContinuousSingleThread() {
+		try {
+			if(!mb2.isEmpty() && !mb2.isInSparseFormat()) {
+				DenseBlock spy = spy(mb2.getDenseBlock());
+				when(spy.isContiguous()).thenReturn(false);
+				MatrixBlock mb2t = new MatrixBlock(mb2.getNumRows(), mb2.getNumColumns(), spy);
+				exec(mb, cmb, mb2t, 1);
+			}
+		}
+		catch(Exception e) {
+			if(!causeWasNotImplemented(e)) {
+				e.printStackTrace();
+				fail(e.getMessage());
+			}
+		}
+	}
+
+	@Test
+	public void testMultiplicationThatNonContinuousParallel() {
+		try {
+			if(!mb2.isEmpty() && !mb2.isInSparseFormat()) {
+				DenseBlock spy = spy(mb2.getDenseBlock());
+				when(spy.isContiguous()).thenReturn(false);
+				MatrixBlock mb2t = new MatrixBlock(mb2.getNumRows(), mb2.getNumColumns(), spy);
+				exec(mb, cmb, mb2t, 4);
+			}
+		}
+		catch(Exception e) {
+			if(!causeWasNotImplemented(e)) {
+				e.printStackTrace();
+				fail(e.getMessage());
+			}
+		}
+	}
+
+	@Test
+	public void testMultiplicationOverlappingSingleThread() {
+		try {
+			CompressedMatrixBlock spy = spy(cmb);
+			when(spy.isOverlapping()).thenReturn(true);
+			exec(mb, spy, mb2, 1);
+		}
+		catch(Exception e) {
+			if(!causeWasNotImplemented(e)) {
+				e.printStackTrace();
+				fail(e.getMessage());
+			}
+		}
+	}
+
+	@Test
+	public void testMultiplicationOverlappingParallel() {
+		try {
+			CompressedMatrixBlock spy = spy(cmb);
+			when(spy.isOverlapping()).thenReturn(true);
+			exec(mb, spy, mb2, 4);
+		}
+		catch(Exception e) {
+			if(!causeWasNotImplemented(e)) {
+				e.printStackTrace();
+				fail(e.getMessage());
+			}
+		}
+	}
+
+	@Test
+	public void testMultiplicationRetAllocatedParallel() {
+		try {
+			MatrixBlock ret = new MatrixBlock(mb2.getNumRows(), cmb.getNumColumns(), false);
+			ret.allocateDenseBlock();
+			DenseBlock spy = spy(ret.getDenseBlock());
+			when(spy.isContiguous()).thenReturn(false);
+			ret.setDenseBlock(spy);
+
+			execR(mb, cmb, mb2, ret, 4);
+		}
+		catch(Exception e) {
+			if(!causeWasNotImplemented(e)) {
+				e.printStackTrace();
+				fail(e.getMessage());
+			}
+		}
+	}
+
+	@Test
+	public void testMultiplicationTransposeLeftSingleThread() {
+		try {
+
+			execTl(mb, cmb, mb2, tcmb2, 1);
+		}
+		catch(Exception e) {
+			if(!causeWasNotImplemented(e)) {
+
+				e.printStackTrace();
+				fail(e.getMessage());
+			}
+
+		}
+	}
+
+	@Test
+	public void testMultiplicationTransposeLeftParallel() {
+		try {
+			execTl(mb, cmb, mb2, tcmb2, 4);
+		}
+		catch(Exception e) {
+			if(!causeWasNotImplemented(e)) {
+				e.printStackTrace();
+				fail(e.getMessage());
+			}
+		}
+	}
+
+	@Test
+	public void testMultiplicationTransposeLeftRetAllocatedParallel() {
+		try {
+			MatrixBlock ret = new MatrixBlock(mb2.getNumRows(), cmb.getNumColumns(), false);
+			ret.allocateDenseBlock();
+			DenseBlock spy = spy(ret.getDenseBlock());
+			when(spy.isContiguous()).thenReturn(false);
+			ret.setDenseBlock(spy);
+
+			execTlR(mb, cmb, mb2, tcmb2, ret, 4);
+		}
+		catch(Exception e) {
+			if(!causeWasNotImplemented(e)) {
+				e.printStackTrace();
+				fail(e.getMessage());
+			}
+		}
+	}
+	/** True if e, or any exception in its cause chain, is a NotImplementedException; such failures are tolerated. */
+	private boolean causeWasNotImplemented(Throwable e) {
+		return e instanceof NotImplementedException || //
+			(e.getCause() != null && causeWasNotImplemented(e.getCause()));
+	}
+
+	private static void execTl(MatrixBlock mb1, CompressedMatrixBlock cmb1, MatrixBlock mb2, MatrixBlock tmb2, int k) {
+		execTlR(mb1, cmb1, mb2, tmb2, null, k);
+	}
+	/** Multiplies tmb2 (cmb1 itself when null) with cmb1 using flags (true, false); compares with mb2 times mb1. */
+	private static void execTlR(MatrixBlock mb1, CompressedMatrixBlock cmb1, MatrixBlock mb2, MatrixBlock tmb2,
+		MatrixBlock ret, int k) {
+		if(tmb2 == null) // then it is the transpose self case
+			tmb2 = cmb1;
+		MatrixBlock cRet = CLALibMatrixMult.matrixMultiply(tmb2, cmb1, ret, k, true, false);
+		MatrixBlock uRet = LibMatrixMult.matrixMult(mb2, mb1, k);
+		compare(cRet, uRet);
+	}
+
+	private static void exec(MatrixBlock mb1, CompressedMatrixBlock cmb1, MatrixBlock mb2, int k) {
+		execR(mb1, cmb1, mb2, null, k);
+	}
+	/** Runs the compressed multiply of mb2 with cmb1 and compares it with the uncompressed product of mb2 and mb1. */
+	private static void execR(MatrixBlock mb1, CompressedMatrixBlock cmb1, MatrixBlock mb2, MatrixBlock ret, int k) {
+		MatrixBlock cRet = CLALibMatrixMult.matrixMultiply(mb2, cmb1, ret, k);
+		MatrixBlock uRet = LibMatrixMult.matrixMult(CompressedMatrixBlock.getUncompressed(mb2), mb1, k);
+		compare(cRet, uRet);
+	}
+	/** Compares both results through compareMatricesBitAvgDistance with both trailing arguments set to 0. */
+	private static void compare(MatrixBlock cRet, MatrixBlock uRet) {
+		TestUtils.compareMatricesBitAvgDistance(uRet, cRet, 0, 0);
+	}
+
+}
diff --git a/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibSelectionMultCustomTest.java b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibSelectionMultCustomTest.java
new file mode 100644
index 00000000000..2b5b52df949
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibSelectionMultCustomTest.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.compress.lib;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+import org.apache.sysds.runtime.compress.lib.CLALibSelectionMult;
+import org.apache.sysds.runtime.data.SparseBlockCSR;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.junit.Test;
+/** Hand-built cases for CLALibSelectionMult.isSelectionMatrix plus the error message of leftSelection. */
+public class CLALibSelectionMultCustomTest {
+
+	@Test
+	public void isSelectionEmpty() {
+		final MatrixBlock candidate = new MatrixBlock(10, 10, false);
+		assertFalse(CLALibSelectionMult.isSelectionMatrix(candidate));
+	}
+
+	@Test
+	public void isSelectionOneCell() {
+		final MatrixBlock candidate = new MatrixBlock(10, 10, true);
+		candidate.allocateSparseRowsBlock();
+		candidate.appendValue(1, 1, 1);
+		assertTrue(CLALibSelectionMult.isSelectionMatrix(candidate));
+	}
+
+	@Test
+	public void isSelectionOneCellIncorrectValue() {
+		final MatrixBlock candidate = new MatrixBlock(10, 10, true);
+		candidate.allocateSparseRowsBlock();
+		candidate.appendValue(1, 1, 2);
+		assertFalse(CLALibSelectionMult.isSelectionMatrix(candidate));
+	}
+
+	@Test
+	public void isSelectionOneCellCSR() {
+		final MatrixBlock candidate = new MatrixBlock(10, 10, true);
+		candidate.allocateSparseRowsBlock();
+		candidate.appendValue(1, 1, 1);
+		final SparseBlockCSR csr = new SparseBlockCSR(candidate.getSparseBlock());
+		candidate.setSparseBlock(csr);
+		assertTrue(CLALibSelectionMult.isSelectionMatrix(candidate));
+	}
+
+	@Test
+	public void isSelectionOneCellCSRIncorrectValue() {
+		final MatrixBlock candidate = new MatrixBlock(10, 10, true);
+		candidate.allocateSparseRowsBlock();
+		candidate.appendValue(1, 1, 2);
+		final SparseBlockCSR csr = new SparseBlockCSR(candidate.getSparseBlock());
+		candidate.setSparseBlock(csr);
+		assertFalse(CLALibSelectionMult.isSelectionMatrix(candidate));
+	}
+
+	@Test
+	public void isSelectionTwoCellsOneRow() {
+		final MatrixBlock candidate = new MatrixBlock(10, 10, true);
+		candidate.allocateSparseRowsBlock();
+		candidate.appendValue(1, 1, 1);
+		candidate.appendValue(1, 2, 1);
+		assertFalse(CLALibSelectionMult.isSelectionMatrix(candidate));
+	}
+
+	@Test
+	public void isSelectionTwoCellsTwoRows() {
+		final MatrixBlock candidate = new MatrixBlock(10, 10, true);
+		candidate.allocateSparseRowsBlock();
+		candidate.appendValue(1, 1, 1);
+		candidate.appendValue(0, 1, 1);
+		assertTrue(CLALibSelectionMult.isSelectionMatrix(candidate));
+	}
+
+	@Test
+	public void isSelectionTwoCellsTwoRowsInvalidValue() {
+		final MatrixBlock candidate = new MatrixBlock(10, 10, true);
+		candidate.allocateSparseRowsBlock();
+		candidate.appendValue(1, 1, 1);
+		candidate.appendValue(0, 1, 2);
+		assertFalse(CLALibSelectionMult.isSelectionMatrix(candidate));
+	}
+
+	@Test
+	public void isSelectionMorePointsThanRows() {
+		final MatrixBlock candidate = new MatrixBlock(2, 2, true);
+		candidate.allocateSparseRowsBlock();
+		candidate.appendValue(1, 1, 1);
+		candidate.appendValue(0, 1, 1);
+		candidate.appendValue(0, 0, 1);
+		assertFalse(CLALibSelectionMult.isSelectionMatrix(candidate));
+	}
+
+	@Test
+	public void isSelectionDenseBlock() {
+		final MatrixBlock candidate = new MatrixBlock(2, 2, false);
+		candidate.allocateDenseBlock();
+		// deliberately dense: no sparse row block is allocated in this case
+		candidate.appendValue(1, 1, 1);
+		candidate.appendValue(0, 1, 1);
+		// a single value of 1 in each row, yet the dense block is expected to be rejected
+		assertFalse(CLALibSelectionMult.isSelectionMatrix(candidate));
+	}
+
+	@Test
+	public void selectionError() {
+		final Exception thrown = assertThrows(Exception.class, () -> CLALibSelectionMult.leftSelection(null, null, null, 1));
+		assertTrue(thrown.getMessage().contains("Failed left selection Multiplication"));
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibSelectionMultTest.java
b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibSelectionMultTest.java
new file mode 100644
index 00000000000..980a1558506
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/compress/lib/CLALibSelectionMultTest.java
@@ -0,0 +1,285 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.compress.lib;
+
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Random;
+
+import org.apache.commons.lang3.NotImplementedException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
+import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
+import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary;
+import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
+import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
+import org.apache.sysds.runtime.compress.lib.CLALibSelectionMult;
+import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.test.component.compress.mapping.MappingTestUtil;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+/** Parameterized tests comparing CLALibSelectionMult.leftSelection with the uncompressed LibMatrixMult product. */
+@RunWith(value = Parameterized.class)
+public class CLALibSelectionMultTest {
+	protected static final Log LOG = LogFactory.getLog(CLALibSelectionMultTest.class.getName());
+
+	@Parameterized.Parameter(0)
+	public String s;
+	@Parameterized.Parameter(1)
+	public MatrixBlock mb;
+	@Parameterized.Parameter(2)
+	public CompressedMatrixBlock cmb;
+	@Parameterized.Parameter(3)
+	public MatrixBlock mb2;
+
+	@BeforeClass
+	public static void setup() {
+		Thread.currentThread().setName("main_test_" + Thread.currentThread().getId());
+	}
+	/** Builds every (input, selection matrix) parameter set; fails the run if construction throws. */
+	@Parameters(name = "{0}")
+	public static Collection<Object[]> data() {
+		List<Object[]> tests = new ArrayList<>();
+		MatrixBlock mb;
+		CompressedMatrixBlock cmb;
+		List<AColGroup> gs;
+
+		try {
+			mb = TestUtils.generateTestMatrixBlock(200, 50, -10, 10, 1.0, 32);
+			mb = TestUtils.round(mb);
+			cmb = (CompressedMatrixBlock) CompressedMatrixBlockFactory.compress(mb, 1).getLeft();
+			genTests(tests, mb, cmb, "Normal");
+
+			mb = TestUtils.generateTestMatrixBlock(1020, 50, -10, 10, 1.0, 32);
+			mb = TestUtils.round(mb);
+			cmb = (CompressedMatrixBlock) CompressedMatrixBlockFactory.compress(mb, 1).getLeft();
+			genTests(tests, mb, cmb, "NormalLarge");
+
+			cmb = (CompressedMatrixBlock) cmb.scalarOperations(new RightScalarOperator(Plus.getPlusFnObject(), 2), null);
+			mb = cmb.decompress();
+			genTests(tests, mb, cmb, "NormalP2");
+
+			cmb = new CompressedMatrixBlock(mb.getNumRows(), mb.getNumColumns());
+			cmb.allocateColGroup(ColGroupUncompressed.create(mb));
+			genTests(tests, mb, cmb, "UncompressAbleGroup");
+
+			mb = TestUtils.generateTestMatrixBlock(200, 10, -5, 5, 1.0, 32);
+			mb = TestUtils.round(mb);
+			gs = new ArrayList<>();
+			gs.add(ColGroupUncompressed.create(mb));
+			gs.add(ColGroupUncompressed.create(mb));
+			cmb = new CompressedMatrixBlock(mb.getNumRows(), mb.getNumColumns(), 1000, true, gs);
+			mb = cmb.decompress();
+			genTests(tests, mb, cmb, "UncompressAbleGroup2");
+
+			MatrixBlock mb2 = new MatrixBlock(10, 10, 132.0);
+			cmb = CompressedMatrixBlockFactory.createConstant(10, 10, 132.0);
+			genTests(tests, mb2, cmb, "Const");
+
+			gs = new ArrayList<>();
+			gs.add(ColGroupConst.create(ColIndexFactory.create(10), 100.0));
+			gs.add(ColGroupConst.create(ColIndexFactory.create(10), 32.0));
+			cmb = new CompressedMatrixBlock(10, 10, 100, true, gs);
+			genTests(tests, cmb.getUncompressed(), cmb, "OverlappingConst");
+
+			gs = new ArrayList<>();
+			gs.add(ColGroupConst.create(ColIndexFactory.create(10), new double[] {13.0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
+			gs.add(ColGroupConst.create(ColIndexFactory.create(10), new double[] {13.0, 32., 0, 0, 0, 0, 0, 0, 0, 0}));
+			cmb = new CompressedMatrixBlock(10, 10, 100, true, gs);
+			genTests(tests, cmb.getUncompressed(), cmb, "OverlappingSparseConst");
+
+			gs = new ArrayList<>();
+			gs.add(ColGroupConst.create(ColIndexFactory.create(10), new double[] {13.0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
+			IdentityDictionary id = new IdentityDictionary(10);
+			AMapToData d = MappingTestUtil.createRandomMap(100, 10, new Random(23));
+			AColGroup idg = ColGroupDDC.create(ColIndexFactory.create(10, 100), id, d, null);
+			gs.add(idg);
+			cmb = new CompressedMatrixBlock(100, 100, 100, true, gs);
+			genTests(tests, cmb.getUncompressed(), cmb, "OverlappingConstIdentity"); // distinct name: const group + identity DDC group
+
+			mb = TestUtils.generateTestMatrixBlock(200, 16, -10, 10, 0.04, 32);
+			mb = TestUtils.round(mb);
+			cmb = (CompressedMatrixBlock) CompressedMatrixBlockFactory.compress(mb, 1).getLeft();
+			genTests(tests, mb, cmb, "Sparse");
+
+			mb = TestUtils.generateTestMatrixBlock(200, 16, -10, 10, 0.2, 32);
+			mb = TestUtils.round(mb);
+			cmb = (CompressedMatrixBlock) CompressedMatrixBlockFactory.compress(mb, 1).getLeft();
+			genTests(tests, mb, cmb, "Sparse2");
+
+			mb = TestUtils.generateTestMatrixBlock(1023, 16, -10, 10, 0.01, 32);
+			mb = TestUtils.round(mb);
+			cmb = (CompressedMatrixBlock) CompressedMatrixBlockFactory.compress(mb, 1).getLeft();
+			genTests(tests, mb, cmb, "SparseLarge");
+
+			d = MappingTestUtil.createRandomMap(100, 10, new Random(23));
+			idg = ColGroupDDC.create(ColIndexFactory.create(10), id, d, null);
+			cmb = new CompressedMatrixBlock(100, 10);
+			cmb.allocateColGroup(idg);
+			mb = CompressedMatrixBlock.getUncompressed(cmb);
+			genTests(tests, mb, cmb, "Identity");
+
+			AColGroup empty = new ColGroupEmpty(ColIndexFactory.create(10));
+			cmb = new CompressedMatrixBlock(100, 10);
+			cmb.allocateColGroup(empty);
+			genTests(tests, new MatrixBlock(100, 10, true), cmb, "Empty"); // reference sized like the 100 x 10 compressed block
+
+		}
+		catch(Exception e) {
+			e.printStackTrace();
+			fail("failed constructing tests");
+		}
+
+		return tests;
+	}
+	/** Adds three selection-matrix parameter sets (30 rows, 30 rows with empties, 1002 rows with empties) for the input. */
+	private static void genTests(List<Object[]> tests, MatrixBlock mb, MatrixBlock cmb, String version) {
+
+		MatrixBlock tmp;
+		MatrixBlock tcmb;
+
+		final int nRow = cmb.getNumRows();
+		// only the row count is needed: it is the column count of every selection matrix built below
+
+		tcmb = createSelectionMatrix(nRow, 30, false);
+		// decompressed for uncompressed operation
+		tmp = CompressedMatrixBlock.getUncompressed(tcmb);
+		// NOTE(review): the compressed transpose below is never added to the parameters - candidate for removal
+		tcmb = CompressedMatrixBlockFactory.compress(tmp.transpose(), 1).getLeft();
+		tests.add(new Object[] {version + "_selection", mb, cmb, tmp});
+
+		tcmb = createSelectionMatrix(nRow, 30, true);
+		// decompressed for uncompressed operation
+		tmp = CompressedMatrixBlock.getUncompressed(tcmb);
+		// NOTE(review): the compressed transpose below is never added to the parameters - candidate for removal
+		tcmb = CompressedMatrixBlockFactory.compress(tmp.transpose(), 1).getLeft();
+		tests.add(new Object[] {version + "_selection_with_empty", mb, cmb, tmp});
+
+		tcmb = createSelectionMatrix(nRow, 1002, true);
+		// decompressed for uncompressed operation
+		tmp = CompressedMatrixBlock.getUncompressed(tcmb);
+		// NOTE(review): the compressed transpose below is never added to the parameters - candidate for removal
+		tcmb = CompressedMatrixBlockFactory.compress(tmp.transpose(), 1).getLeft();
+		tests.add(new Object[] {version + "_selection_with_empty_1002", mb, cmb, tmp});
+
+	}
+	/** Creates an nRowLeft x nRow compressed block: one DDC group over an IdentityDictionary with a seeded random map. */
+	public static MatrixBlock createSelectionMatrix(final int nRow, final int nRowLeft, boolean emptyRows) {
+		MatrixBlock tcmb;
+		IdentityDictionary id = new IdentityDictionary(nRow, emptyRows);
+		AMapToData d = MappingTestUtil.createRandomMap(nRowLeft, nRow + (emptyRows ? 1 : 0), new Random(33));
+		AColGroup idg = ColGroupDDC.create(ColIndexFactory.create(nRow), id, d, null);
+		tcmb = new CompressedMatrixBlock(nRowLeft, nRow);
+		((CompressedMatrixBlock) tcmb).allocateColGroup(idg);
+		return tcmb;
+	}
+
+	@Test
+	public void testMultiplicationSingleThread() {
+		try {
+			exec(mb, cmb, mb2, 1);
+		}
+		catch(Exception e) {
+			if(!causeWasNotImplemented(e)) {
+				e.printStackTrace();
+				fail(e.getMessage());
+			}
+		}
+	}
+
+	@Test
+	public void testMultiplicationUnknownNonZeros() {
+		try {
+			CompressedMatrixBlock spy = spy(cmb);
+			when(spy.getNonZeros()).thenReturn(-1L);
+			exec(mb, spy, mb2, 1);
+		}
+		catch(Exception e) {
+			if(!causeWasNotImplemented(e)) {
+				e.printStackTrace();
+				fail(e.getMessage());
+			}
+		}
+	}
+
+	@Test
+	public void testMultiplicationParallel() {
+		try {
+			exec(mb, cmb, mb2, 4);
+		}
+		catch(Exception e) {
+			if(!causeWasNotImplemented(e)) {
+				e.printStackTrace();
+				fail(e.getMessage());
+			}
+		}
+	}
+
+	@Test
+	public void testMultiplicationParallelAllocatedRet() {
+		try {
+			execR(mb, cmb, mb2, new MatrixBlock(), 4);
+		}
+		catch(Exception e) {
+			if(!causeWasNotImplemented(e)) {
+				e.printStackTrace();
+				fail(e.getMessage());
+			}
+		}
+	}
+	/** True if e, or any exception in its cause chain, is a NotImplementedException; such failures are tolerated. */
+	private boolean causeWasNotImplemented(Throwable e) {
+		return e instanceof NotImplementedException || //
+			(e.getCause() != null && causeWasNotImplemented(e.getCause()));
+	}
+
+	private static void exec(MatrixBlock mb1, CompressedMatrixBlock cmb1, MatrixBlock mb2, int k) {
+		execR(mb1, cmb1, mb2, null, k);
+	}
+	/** Runs leftSelection on cmb1 with mb2 and compares it with the uncompressed product of mb2 and mb1. */
+	private static void execR(MatrixBlock mb1, CompressedMatrixBlock cmb1, MatrixBlock mb2, MatrixBlock ret, int k) {
+		MatrixBlock cRet = CLALibSelectionMult.leftSelection(cmb1, mb2, ret, k);
+		MatrixBlock uRet = LibMatrixMult.matrixMult(CompressedMatrixBlock.getUncompressed(mb2), mb1, k);
+		compare(cRet, uRet);
+	}
+	/** Compares both results through compareMatricesBitAvgDistance with both trailing arguments set to 0. */
+	private static void compare(MatrixBlock cRet, MatrixBlock uRet) {
+		TestUtils.compareMatricesBitAvgDistance(uRet, cRet, 0, 0);
+	}
+
+}
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/mapping/MappingTestUtil.java b/src/test/java/org/apache/sysds/test/component/compress/mapping/MappingTestUtil.java index bcb9703a96f..2e04a642dc8 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/mapping/MappingTestUtil.java +++ b/src/test/java/org/apache/sysds/test/component/compress/mapping/MappingTestUtil.java @@ -78,7 +78,7 @@ protected static int getTypeSize(MAP_TYPE t) { } } - protected static AMapToData createRandomMap(int len, int nUnique, Random r) { + public static AMapToData createRandomMap(int len, int nUnique, Random r) { AMapToData m = MapToFactory.create(len, nUnique); for(int i = 0; i < len; i++) m.set(i, r.nextInt(nUnique));