Skip to content

Commit

Permalink
[MINOR] Uncompressed ColGroup Outer TSMM
Browse files Browse the repository at this point in the history
Add support for sparse outer TSMM for uncompressed column groups.
This was missing in 1c26e2d

Closes apache#1968
  • Loading branch information
Baunsgaard committed Dec 30, 2023
1 parent 3b48c4a commit 246eea9
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -532,14 +532,33 @@ public final void tsmm(MatrixBlock ret, int nRows) {
// tsmm but only upper triangle.
LibMatrixMult.matrixMultTransposeSelf(_data, tmp, true, false);

// copy that upper triangle part to ret
final int numColumns = ret.getNumColumns();
final double[] result = ret.getDenseBlockValues();
final double[] tmpV = tmp.getDenseBlockValues();
for(int row = 0, offTmp = 0; row < tCol; row++, offTmp += tCol) {
final int offRet = _colIndexes.get(row) * numColumns;
for(int col = row; col < tCol; col++)
result[offRet + _colIndexes.get(col)] += tmpV[offTmp + col];
if(tmp.isInSparseFormat()){
final int numColumns = ret.getNumColumns();
final double[] result = ret.getDenseBlockValues();
final SparseBlock sb = tmp.getSparseBlock();
for(int row = 0; row < tCol; row++) {
final int offRet = _colIndexes.get(row) * numColumns;
if(sb.isEmpty(row))
continue;
int apos = sb.pos(row);
int alen = sb.size(row) + apos;
int[] aix = sb.indexes(row);
double[] aval = sb.values(row);
for(int j = apos; j < alen; j++)
result[offRet + _colIndexes.get(aix[j])] += aval[j];

}
}
else{
// copy that upper triangle part to ret
final int numColumns = ret.getNumColumns();
final double[] result = ret.getDenseBlockValues();
final double[] tmpV = tmp.getDenseBlockValues();
for(int row = 0, offTmp = 0; row < tCol; row++, offTmp += tCol) {
final int offRet = _colIndexes.get(row) * numColumns;
for(int col = row; col < tCol; col++)
result[offRet + _colIndexes.get(col)] += tmpV[offTmp + col];
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1118,13 +1118,20 @@ public void getMax() {

@Test
public void tsmm() {
final MatrixBlock bt = new MatrixBlock(maxCol, maxCol, false);
final MatrixBlock ot = new MatrixBlock(maxCol, maxCol, false);
ot.allocateDenseBlock();
bt.allocateDenseBlock();
base.tsmm(bt, nRow);
other.tsmm(ot, nRow);
compare(ot, bt);
try{

final MatrixBlock bt = new MatrixBlock(maxCol, maxCol, false);
final MatrixBlock ot = new MatrixBlock(maxCol, maxCol, false);
ot.allocateDenseBlock();
bt.allocateDenseBlock();
base.tsmm(bt, nRow);
other.tsmm(ot, nRow);
compare(ot, bt);
}
catch(Exception e){
e.printStackTrace();
fail(e.getMessage());
}
}

@Test
Expand Down

0 comments on commit 246eea9

Please sign in to comment.