From 2b8de1629b935d0b75caf38e4295c706980f0ce7 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Thu, 26 Oct 2023 18:24:54 +0200 Subject: [PATCH] [MINOR] JIT optimize LibMatrixBinCell This commit move some of the code inside LibMatrixBincell around to encourage jit compilation of some methods. In specific folloing methods have been introduced. - safeBinaryMvSparseRowVector - fillZeroValuesEmpty - fillZeroValuesDense - fillZeroValuesSparse - safeBinaryMMDenseDenseDensePM_Vec (Plus Multiply kernel vectorized) - safeBinaryMMDenseDenseDensePM (Plus Multiply kernel small input) - safeBinaryMMDenseDenseDenseContiguous (This one makes a big difference) - safeBinaryMMDenseDenseDenseGeneric In specific the safeBinaryMMDenseDenseDenseContiguous, safeBinaryMMDenseDenseDensePMm and safeBinaryMMDenseDenseDensePM_Vec improve the performance by much. In LM_cg the performance: Stats output: +* 3.123 3000 (Before) +* 1.991 3000 (After) + 1.125 2021 (Before) + 0.703 2015 (After) This is training on Criteo 100k rows. --- .../runtime/matrix/data/LibMatrixBincell.java | 430 +++++++++++------- .../runtime/matrix/data/LibMatrixMult.java | 2 +- 2 files changed, 269 insertions(+), 163 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java index e53f09a7f43..e5ec7a00209 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java @@ -851,85 +851,93 @@ private static void safeBinaryMVSparseDenseRow(MatrixBlock m1, MatrixBlock m2, M private static void safeBinaryMVSparse(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op) { boolean isMultiply = (op.fn instanceof Multiply); boolean skipEmpty = (isMultiply || isSparseSafeDivide(op, m2)); - - int rlen = m1.rlen; - int clen = m1.clen; - SparseBlock a = m1.sparseBlock; BinaryAccessType atype = getBinaryAccessType(m1, m2); - - //early abort on skip and empty - if( skipEmpty && (m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) ) + + // early abort on skip and empty + if(skipEmpty && (m1.isEmptyBlock(false) || m2.isEmptyBlock(false))) return; // skip entire empty block - - //allocate once in order to prevent repeated reallocation - if( ret.sparse ) + + // allocate once in order to prevent repeated reallocation + if(ret.sparse) ret.allocateSparseRowsBlock(); - - if( atype == BinaryAccessType.MATRIX_COL_VECTOR ) - { - for( int i=0; i aix[apos]){ - apos++; - } - // for each point in the sparse range - for(; apos < alen && aix[apos] < len; apos++){ - if(!zeroIsZero){ - while(cpos < len && cpos < aix[apos]){ - ret.appendValue(rpos, cpos++, zero); - } - } - cpos = aix[apos]; - final double v = op.fn.execute(0, vals[apos]); - ret.appendValue(rpos, aix[apos], v); - // cpos++; - } - // process tail. + } + else { + // def + for(int k = cpos; k < len; k++) { + ret.appendValue(rpos, k, op.fn.execute(0, vals[k])); + } + } + } + + private static void fillZeroValuesSparse(BinaryOperator op, MatrixBlock m2, MatrixBlock ret, boolean skipEmpty, + int rpos, int cpos, int len) { + + final double zero = op.fn.execute(0.0, 0.0); + final boolean zeroIsZero = zero == 0.0; + final SparseBlock sb = m2.getSparseBlock(); + if(sb.isEmpty(0)) { + if(!zeroIsZero) { + while(cpos < len) + ret.appendValue(rpos, cpos++, zero); + } + } + else { + int apos = sb.pos(0); + final int alen = sb.size(0) + apos; + final int[] aix = sb.indexes(0); + final double[] vals = sb.values(0); + // skip aix pos until inside range of cpos and len + while(apos < alen && aix[apos] < len && cpos > aix[apos]) { + apos++; + } + // for each point in the sparse range + for(; apos < alen && aix[apos] < len; apos++) { if(!zeroIsZero) { - while(cpos < len) { + while(cpos < len && cpos < aix[apos]) { ret.appendValue(rpos, cpos++, zero); } } - } - } - else { - final DenseBlock db = m2.getDenseBlock(); - final double[] vals = db.values(0); - for(int k = cpos; k < len; k++){ - ret.appendValue(rpos, k, op.fn.execute(0, vals[k])); + cpos = aix[apos]; + final double v = op.fn.execute(0, vals[apos]); + ret.appendValue(rpos, aix[apos], v); + // cpos++; + } + // process tail. + if(!zeroIsZero) { + while(cpos < len) { + ret.appendValue(rpos, cpos++, zero); + } } } } @@ -1313,40 +1347,86 @@ else if(op.fn instanceof Multiply) } private static long safeBinaryMMDenseDenseDense(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, - BinaryOperator op, int rl, int ru) - { - boolean isPM = m1.clen >= 512 & (op.fn instanceof PlusMultiply | op.fn instanceof MinusMultiply); - double cntPM = !isPM ? Double.NaN : (op.fn instanceof PlusMultiply ? - ((PlusMultiply)op.fn).getConstant() : -1d * ((MinusMultiply)op.fn).getConstant()); + BinaryOperator op, int rl, int ru){ + final int clen = m1.clen; + final boolean isPM = (op.fn instanceof PlusMultiply || op.fn instanceof MinusMultiply); //guard for postponed allocation in single-threaded exec - if( !ret.isAllocated() ) + if(!ret.isAllocated()) ret.allocateDenseBlock(); - DenseBlock da = m1.getDenseBlock(); - DenseBlock db = m2.getDenseBlock(); - DenseBlock dc = ret.getDenseBlock(); - ValueFunction fn = op.fn; - int clen = m1.clen; + final DenseBlock da = m1.getDenseBlock(); + final DenseBlock db = m2.getDenseBlock(); + final DenseBlock dc = ret.getDenseBlock(); - //compute dense-dense binary, maintain nnz on-the-fly + if(isPM && clen >= 64) + return safeBinaryMMDenseDenseDensePM_Vec(da, db, dc, op, rl, ru, clen); + else if(da.isContiguous() && db.isContiguous() && dc.isContiguous()) { + if(op.fn instanceof PlusMultiply) + return safeBinaryMMDenseDenseDensePM(da, db, dc, op, rl, ru, clen); + else + return safeBinaryMMDenseDenseDenseContiguous(da, db, dc, op, rl, ru, clen); + } + else + return safeBinaryMMDenseDenseDenseGeneric(da, db, dc, op, rl, ru, clen); + } + + private static final long safeBinaryMMDenseDenseDensePM_Vec(DenseBlock da, DenseBlock db, DenseBlock dc, BinaryOperator op, + int rl, int ru, int clen) { + final double cntPM = (op.fn instanceof PlusMultiply ? ((PlusMultiply) op.fn).getConstant() : -1d * + ((MinusMultiply) op.fn).getConstant()); long lnnz = 0; - for(int i=rl; i