Skip to content

Commit

Permalink
more identity tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Baunsgaard committed Sep 23, 2024
1 parent 6bd85f4 commit 9c606ca
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ else if(shouldPreAggregateLeft(lhs)) {// left preAgg
DictLibMatrixMult.MMDicts(lDict, lhsPA, leftIdx, rightIdx, result);
}
else {// right preAgg
final IDictionary rhsPA = preAggregateThatIndexStructure(lhs);
final IDictionary rhsPA = this.preAggregateThatIndexStructure(lhs);
if(rhsPA != null)
DictLibMatrixMult.MMDicts(rhsPA, rDict, leftIdx, rightIdx, result);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ private void leftMMIdentityPreAggregateDenseSingleRow(double[] values, int pos,
for(int rc = cl; rc < cu; rc++, pos++) {
final int idx = _data.getIndex(rc);
if(idx != nVal)
values2[_colIndexes.get(_data.getIndex(rc))] += values[pos];
values2[_colIndexes.get(idx)] += values[pos];
}
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -651,11 +651,13 @@ public void MMDictScaling(IDictionary right, IColIndex rowsLeft, IColIndex colsR
@Override
public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) {
// similar to fused transpose left into right locations.

final int leftSide = rowsLeft.size();
final int colsOut = result.getNumColumns();
final int commonDim = Math.min(left.length / leftSide, nRowCol);
final double[] resV = result.getDenseBlockValues();
for(int i = 0; i < leftSide; i++) { // rows in left side
final int offOut = rowsLeft.get(i) * commonDim;
final int offOut = rowsLeft.get(i) * colsOut;
final int leftOff = i;
for(int j = 0; j < commonDim; j++) { // cols in left side skipping empty from identity
resV[offOut + colsRight.get(j)] += left[leftOff + j * leftSide];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ private static MatrixBlock leftMultByCompressedTransposedMatrixSingleThread(Comp
// Force dense output
ret.setNonZeros((long) ret.getNumRows() * ret.getNumColumns());
ret.allocateDenseBlock();

for(int j = 0; j < fLeft.size(); j++)
for(int i = 0; i < fRight.size(); i++)
fRight.get(i).leftMultByAColGroup(fLeft.get(j), ret, sd);
Expand Down Expand Up @@ -314,8 +315,6 @@ private static MatrixBlock LMM(List<AColGroup> colGroups, MatrixBlock that, Matr
else
rowSums = that.rowSum(k).getDenseBlockValues();

// final double multTime = t.stop();

// add the correction layer for the subtracted common values.
if(rowSums != null) {
if(ret.isEmpty())
Expand All @@ -324,11 +323,6 @@ private static MatrixBlock LMM(List<AColGroup> colGroups, MatrixBlock that, Matr
ret.sparseToDense();
outerProduct(rowSums, constV, ret, k);
}

// final double outerProd = t.stop();
// if(LOG.isDebugEnabled()) {
// LOG.debug(String.format("LLM: filter: %10f Mult: %10f outer: %10f", filterGroupsTime, multTime, outerProd));
// }
}
else {
CLALibUtils.splitPreAgg(colGroups, noPreAggGroups, preAggGroups);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,18 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCSingleZeros;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.PlaceHolderDict;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory;
import org.junit.Test;

Expand Down Expand Up @@ -65,4 +69,15 @@ public void appendEmptyToSDCZero2() {
assertEquals(((ColGroupSDCSingleZeros) r).getNumRows(), 7 * 20);

}

@Test(expected = NotImplementedException.class)
public void preAggSparseError() {

AColGroup g = ColGroupDDC.create(ColIndexFactory.create(3),
Dictionary.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9}),
MapToFactory.create(new int[] {0, 0, 0, 1, 1, 1, 2, 2, 2}, 3), null);

((ColGroupDDC)g).preAggregateSparse(null, null, 0, 3, 1, 2);

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,32 @@ public static Collection<Object[]> data() {
mb = CompressedMatrixBlock.getUncompressed(cmb);
genTests(tests, mb, cmb, "Identity");


d = MappingTestUtil.createRandomMap(100, 10, new Random(23));
idg = ColGroupDDC.create(ColIndexFactory.createI(0,1,2,3,4,6,7,8,9,10), id, d, null);
cmb = new CompressedMatrixBlock(100, 11);
cmb.allocateColGroup(idg);
mb = CompressedMatrixBlock.getUncompressed(cmb);
genTests(tests, mb, cmb, "Identity2");

id = new IdentityDictionary(10, true);

// continuous index range
d = MappingTestUtil.createRandomMap(100, 11, new Random(33));
idg = ColGroupDDC.create(ColIndexFactory.create(10), id, d, null);
cmb = new CompressedMatrixBlock(100, 10);
cmb.allocateColGroup(idg);
mb = CompressedMatrixBlock.getUncompressed(cmb);
genTests(tests, mb, cmb, "Identity_empty");

// not continuous.
d = MappingTestUtil.createRandomMap(100, 11, new Random(33));
idg = ColGroupDDC.create(ColIndexFactory.createI(0,1,2,3,4,6,7,8,9,10), id, d, null);
cmb = new CompressedMatrixBlock(100, 11);
cmb.allocateColGroup(idg);
mb = CompressedMatrixBlock.getUncompressed(cmb);
genTests(tests, mb, cmb, "Identity_empty2");

AColGroup empty = new ColGroupEmpty(ColIndexFactory.create(10));
cmb = new CompressedMatrixBlock(100, 10);
cmb.allocateColGroup(empty);
Expand Down Expand Up @@ -415,6 +441,7 @@ private static void execTlR(MatrixBlock mb1, CompressedMatrixBlock cmb1, MatrixB
tmb2 = cmb1;
MatrixBlock cRet = CLALibMatrixMult.matrixMultiply(tmb2, cmb1, ret, k, true, false);
MatrixBlock uRet = LibMatrixMult.matrixMult(mb2, mb1, k);

compare(cRet, uRet);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,22 @@ public static Collection<Object[]> data() {
mb = CompressedMatrixBlock.getUncompressed(cmb);
genTests(tests, mb, cmb, "Identity");

d = MappingTestUtil.createRandomMap(100, 10, new Random(23));
idg = ColGroupDDC.create(ColIndexFactory.createI(0,1,2,3,4,6,7,8,9,10), id, d, null);
cmb = new CompressedMatrixBlock(100, 11);
cmb.allocateColGroup(idg);
mb = CompressedMatrixBlock.getUncompressed(cmb);
genTests(tests, mb, cmb, "Identity2");

id = new IdentityDictionary(10, true);

d = MappingTestUtil.createRandomMap(100, 11, new Random(33));
idg = ColGroupDDC.create(ColIndexFactory.createI(0,1,2,3,4,6,7,8,9,10), id, d, null);
cmb = new CompressedMatrixBlock(100, 11);
cmb.allocateColGroup(idg);
mb = CompressedMatrixBlock.getUncompressed(cmb);
genTests(tests, mb, cmb, "Identity_empty");

AColGroup empty = new ColGroupEmpty(ColIndexFactory.create(10));
cmb = new CompressedMatrixBlock(100, 10);
cmb.allocateColGroup(empty);
Expand Down

0 comments on commit 9c606ca

Please sign in to comment.