From 3537383147d2bb9e716ea8ec5b80738ecf4e8b14 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Thu, 4 Jul 2024 23:18:37 +0200 Subject: [PATCH] [MINOR] Add LibReplace for MatrixBlock Adds a Lib class for replace. The main benefit in this commit is when we perform replacement on a sparse matrix, and the output is dense. after: ``` java -jar target/systemds-3.3.0-SNAPSHOT-perf.jar 17 10000 10000 0.1 16 Profiling started replaceZero, 132.883+- 9.335 ms, replaceOne, 59.612+- 3.295 ms, replaceNaN, 10.651+- 0.470 ms, ``` before: ``` java -jar target/systemds-3.3.0-SNAPSHOT-perf.jar 17 10000 10000 0.1 16 Profiling started replaceZero, 228.727+- 11.965 ms, replaceOne, 163.212+- 4.993 ms, replaceNaN, 10.602+- 0.437 ms, ``` Closes #2043 --- .../apache/sysds/runtime/data/DenseBlock.java | 7 + .../sysds/runtime/data/DenseBlockBool.java | 7 + .../sysds/runtime/data/DenseBlockFP32.java | 7 + .../sysds/runtime/data/DenseBlockFP64.java | 8 + .../runtime/data/DenseBlockFP64DEDUP.java | 6 + .../sysds/runtime/data/DenseBlockInt32.java | 7 + .../sysds/runtime/data/DenseBlockInt64.java | 8 + .../sysds/runtime/data/DenseBlockLBool.java | 7 + .../sysds/runtime/data/DenseBlockLFP32.java | 7 + .../sysds/runtime/data/DenseBlockLFP64.java | 7 + .../runtime/data/DenseBlockLFP64DEDUP.java | 5 + .../sysds/runtime/data/DenseBlockLInt32.java | 7 + .../sysds/runtime/data/DenseBlockLInt64.java | 7 + .../sysds/runtime/data/DenseBlockLString.java | 7 + .../sysds/runtime/data/DenseBlockString.java | 8 + .../runtime/matrix/data/LibMatrixReplace.java | 211 ++++++++++++++++++ .../runtime/matrix/data/MatrixBlock.java | 88 +------- .../org/apache/sysds/performance/Main.java | 14 ++ .../performance/matrix/MatrixReplacePerf.java | 69 ++++++ 19 files changed, 400 insertions(+), 87 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReplace.java create mode 100644 src/test/java/org/apache/sysds/performance/matrix/MatrixReplacePerf.java diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlock.java b/src/main/java/org/apache/sysds/runtime/data/DenseBlock.java index 0baf8819368..87967174222 100644 --- a/src/main/java/org/apache/sysds/runtime/data/DenseBlock.java +++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlock.java @@ -418,6 +418,13 @@ public final long size() { */ public abstract void fillBlock(int bix, int fromIndex, int toIndex, double v); + /** + * Fill the DenseBlock row index with the value specified. + * @param r The row to fill + * @param v The value to fill it with. + */ + public abstract void fillRow(int r, double v); + /** * Set a value at a position given by block index and index in that block. * @param bix block index diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockBool.java b/src/main/java/org/apache/sysds/runtime/data/DenseBlockBool.java index 3d94fcf8af4..3244fb75b15 100644 --- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockBool.java +++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockBool.java @@ -146,6 +146,13 @@ public void fillBlock(int bix, int fromIndex, int toIndex, double v) { _data.set(fromIndex, toIndex, v != 0); } + @Override + public void fillRow(int r, double v){ + int start = pos(r); + int end = start + getDim(1); + _data.set(start, end, v != 0); + } + @Override protected void setInternal(int bix, int ix, double v) { _data.set(ix, v != 0); diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP32.java b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP32.java index 519c17d83d8..8d2b9e61d9e 100644 --- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP32.java +++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP32.java @@ -132,6 +132,13 @@ public void fillBlock(int bix, int fromIndex, int toIndex, double v) { Arrays.fill(_data, fromIndex, toIndex, (float)v); } + @Override + public void fillRow(int r, double v){ + int start = pos(r); + int end = start + getDim(1); + Arrays.fill(_data, start, end, (float)v); + } + @Override protected void setInternal(int bix, int ix, double v) { _data[ix] = (float)v; diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java index eb93777fa4d..94909444198 100644 --- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java +++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java @@ -139,6 +139,14 @@ public void fillBlock(int bix, int fromIndex, int toIndex, double v) { Arrays.fill(_data, fromIndex, toIndex, v); } + @Override + public void fillRow(int r, double v){ + int start = pos(r); + int end = start + getDim(1); + Arrays.fill(_data, start, end, v); + } + + @Override protected void setInternal(int bix, int ix, double v) { _data[ix] = v; diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java index 2f3008c727e..49d591a01e8 100644 --- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java +++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java @@ -317,6 +317,12 @@ public void fillBlock(int bix, int fromIndex, int toIndex, double v) { } } + @Override + public void fillRow(int r, double v){ + throw new NotImplementedException(); + } + + @Override protected void setInternal(int bix, int ix, double v) { set(bix, ix, v); diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockInt32.java b/src/main/java/org/apache/sysds/runtime/data/DenseBlockInt32.java index 6f3c2a66228..c10072ec172 100644 --- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockInt32.java +++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockInt32.java @@ -132,6 +132,13 @@ public void fillBlock(int bix, int fromIndex, int toIndex, double v) { Arrays.fill(_data, fromIndex, toIndex, UtilFunctions.toInt(v)); } + @Override + public void fillRow(int r, double v){ + int start = pos(r); + int end = start + getDim(1); + Arrays.fill(_data, start, end, UtilFunctions.toInt(v)); + } + @Override protected void setInternal(int bix, int ix, double v) { _data[ix] = UtilFunctions.toInt(v); diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockInt64.java b/src/main/java/org/apache/sysds/runtime/data/DenseBlockInt64.java index ffe790e81a9..23930926a91 100644 --- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockInt64.java +++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockInt64.java @@ -133,6 +133,14 @@ public void fillBlock(int bix, int fromIndex, int toIndex, double v) { Arrays.fill(_data, fromIndex, toIndex, UtilFunctions.toLong(v)); } + @Override + public void fillRow(int r, double v){ + int start = pos(r); + int end = start + getDim(1); + Arrays.fill(_data, start, end, UtilFunctions.toLong(v)); + } + + @Override protected void setInternal(int bix, int ix, double v) { _data[ix] = UtilFunctions.toLong(v); diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLBool.java b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLBool.java index ab3bd98b9c2..73a93d9a87f 100644 --- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLBool.java +++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLBool.java @@ -150,6 +150,13 @@ public void fillBlock(int bix, int fromIndex, int toIndex, double v) { _blocks[bix].set(fromIndex, toIndex, v != 0); } + @Override + public void fillRow(int r, double v){ + int start = pos(r); + int end = start + getDim(1); + _blocks[index(r)].set(start, end, v != 0); + } + @Override public DenseBlock set(String s) { boolean b = Boolean.parseBoolean(s); diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP32.java b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP32.java index 4122db2faec..b9e9d602b51 100644 --- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP32.java +++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP32.java @@ -113,6 +113,13 @@ public void fillBlock(int bix, int fromIndex, int toIndex, double v) { Arrays.fill(_blocks[bix], fromIndex, toIndex, (float)v); } + @Override + public void fillRow(int r, double v){ + int start = pos(r); + int end = start + getDim(1); + Arrays.fill(_blocks[index(r)],start, end, (float)v); + } + @Override public DenseBlock set(int r, int c, double v) { _blocks[index(r)][pos(r, c)] = (float)v; diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP64.java b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP64.java index 1d0bc3ccfbf..3a8091938b9 100644 --- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP64.java +++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP64.java @@ -105,6 +105,13 @@ public void fillBlock(int bix, int fromIndex, int toIndex, double v) { Arrays.fill(_blocks[bix], fromIndex,toIndex, v); } + @Override + public void fillRow(int r, double v){ + int start = pos(r); + int end = start + getDim(1); + Arrays.fill(_blocks[index(r)],start, end, v); + } + @Override public DenseBlock set(int r, int c, double v) { _blocks[index(r)][pos(r, c)] = v; diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP64DEDUP.java b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP64DEDUP.java index 79c02b7ac64..782c87f7869 100644 --- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP64DEDUP.java +++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP64DEDUP.java @@ -178,6 +178,11 @@ public void fillBlock(int bix, int fromIndex, int toIndex, double v) { throw new NotImplementedException(); } + @Override + public void fillRow(int r, double v){ + throw new NotImplementedException(); + } + @Override protected void setInternal(int bix, int ix, double v) { throw new NotImplementedException(); diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLInt32.java b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLInt32.java index 0880440e46c..6c63f31d200 100644 --- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLInt32.java +++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLInt32.java @@ -109,6 +109,13 @@ public void fillBlock(int bix, int fromIndex, int toIndex, double v) { Arrays.fill(_blocks[bix], fromIndex, toIndex, UtilFunctions.toInt(v)); } + @Override + public void fillRow(int r, double v){ + int start = pos(r); + int end = start + getDim(1); + Arrays.fill(_blocks[index(r)], start, end, UtilFunctions.toInt(v)); + } + @Override public DenseBlock set(int r, int c, double v) { _blocks[index(r)][pos(r, c)] = UtilFunctions.toInt(v); diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLInt64.java b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLInt64.java index d9ffc261766..79b79e17607 100644 --- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLInt64.java +++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLInt64.java @@ -109,6 +109,13 @@ public void fillBlock(int bix, int fromIndex, int toIndex, double v) { Arrays.fill(_blocks[bix], fromIndex, toIndex, UtilFunctions.toLong(v)); } + @Override + public void fillRow(int r, double v){ + int start = pos(r); + int end = start + getDim(1); + Arrays.fill(_blocks[index(r)], start, end, UtilFunctions.toLong(v)); + } + @Override public DenseBlock set(int r, int c, double v) { _blocks[index(r)][pos(r, c)] = UtilFunctions.toLong(v); diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLString.java b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLString.java index 0ab267abecb..013ae366bb4 100644 --- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLString.java +++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLString.java @@ -109,6 +109,13 @@ public void fillBlock(int bix, int fromIndex, int toIndex, double v) { Arrays.fill(_blocks[bix], fromIndex, toIndex, String.valueOf(v)); } + @Override + public void fillRow(int r, double v){ + int start = pos(r); + int end = start + getDim(1); + Arrays.fill(_blocks[index(r)], start, end, String.valueOf(v)); + } + @Override public DenseBlock set(String s) { for (int i = 0; i < numBlocks() - 1; i++) { diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockString.java b/src/main/java/org/apache/sysds/runtime/data/DenseBlockString.java index 657a0f65595..396c3496bef 100644 --- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockString.java +++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockString.java @@ -118,6 +118,14 @@ public void fillBlock(int bix, int fromIndex, int toIndex, double v) { Arrays.fill(_data, fromIndex, toIndex, String.valueOf(v)); } + @Override + public void fillRow(int r, double v){ + int start = pos(r); + int end = start + getDim(1); + Arrays.fill(_data, start, end, String.valueOf(v)); + } + + @Override protected void setInternal(int bix, int ix, double v) { _data[ix] = String.valueOf(v); diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReplace.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReplace.java new file mode 100644 index 00000000000..b40cea9c4bd --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReplace.java @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.matrix.data; + +import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; +import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.data.SparseBlock; + +public class LibMatrixReplace { + + private LibMatrixReplace() { + + } + + public static MatrixBlock replaceOperations(MatrixBlock in, MatrixBlock ret, double pattern, double replacement) { + return replaceOperations(in, ret, pattern, replacement, InfrastructureAnalyzer.getLocalParallelism()); + } + + public static MatrixBlock replaceOperations(MatrixBlock in, MatrixBlock ret, double pattern, double replacement, + int k) { + + // ensure input its in the right format + in.examSparsity(k); + + final int rlen = in.getNumRows(); + final int clen = in.getNumColumns(); + final long nonZeros = in.getNonZeros(); + final boolean sparse = in.isInSparseFormat(); + + if(ret != null) + ret.reset(rlen, clen, sparse); + else + ret = new MatrixBlock(rlen, clen, sparse); + + // probe early abort conditions + if(nonZeros == 0 && pattern != 0) + return ret; + if(!in.containsValue(pattern)) + return in; // avoid allocation + copy + if(in.isEmpty() && pattern == 0) { + ret.reset(rlen, clen, replacement); + return ret; + } + + final boolean replaceNaN = Double.isNaN(pattern); + + final long nnz; + if(sparse) // SPARSE + nnz = replaceSparse(in, ret, pattern, replacement, replaceNaN); + else if(replaceNaN) + nnz = replaceDenseNaN(in, ret, replacement); + else + nnz = replaceDense(in, ret, pattern, replacement); + + ret.setNonZeros(nnz); + ret.examSparsity(k); + return ret; + } + + private static long replaceSparse(MatrixBlock in, MatrixBlock ret, double pattern, double replacement, + boolean replaceNaN) { + if(replaceNaN) + return replaceSparseInSparseOutReplaceNaN(in, ret, replacement); + else if(pattern != 0d) // sparse safe. + return replaceSparseInSparseOut(in, ret, pattern, replacement); + else // sparse unsafe + return replace0InSparse(in, ret, replacement); + + } + + private static long replaceSparseInSparseOutReplaceNaN(MatrixBlock in, MatrixBlock ret, double replacement) { + ret.allocateSparseRowsBlock(); + SparseBlock a = in.sparseBlock; + SparseBlock c = ret.sparseBlock; + long nnz = 0; + for(int i = 0; i < in.rlen; i++) { + if(!a.isEmpty(i)) { + int apos = a.pos(i); + int alen = a.size(i); + c.allocate(i, alen); + int[] aix = a.indexes(i); + double[] avals = a.values(i); + for(int j = apos; j < apos + alen; j++) { + double val = avals[j]; + if(Double.isNaN(val)) + c.append(i, aix[j], replacement); + else + c.append(i, aix[j], val); + } + c.compact(i); + nnz += c.size(i); + } + } + return nnz; + } + + private static long replaceSparseInSparseOut(MatrixBlock in, MatrixBlock ret, double pattern, double replacement) { + ret.allocateSparseRowsBlock(); + final SparseBlock a = in.sparseBlock; + final SparseBlock c = ret.sparseBlock; + + return replaceSparseInSparseOut(a, c, pattern, replacement, 0, in.rlen); + + } + + private static long replaceSparseInSparseOut(SparseBlock a, SparseBlock c, double pattern, double replacement, int s, + int e) { + long nnz = 0; + for(int i = s; i < e; i++) { + if(!a.isEmpty(i)) { + final int apos = a.pos(i); + final int alen = a.size(i); + final int[] aix = a.indexes(i); + final double[] avals = a.values(i); + c.allocate(i, alen); + for(int j = apos; j < apos + alen; j++) { + double val = avals[j]; + if(val == pattern) + c.append(i, aix[j], replacement); + else + c.append(i, aix[j], val); + } + c.compact(i); + nnz += c.size(i); + } + } + return nnz; + } + + private static long replace0InSparse(MatrixBlock in, MatrixBlock ret, double replacement) { + ret.sparse = false; + ret.allocateDenseBlock(); + SparseBlock a = in.sparseBlock; + DenseBlock c = ret.getDenseBlock(); + + // initialize with replacement (since all 0 values, see SPARSITY_TURN_POINT) + // c.reset(in.rlen, in.clen, replacement); + + if(a == null)// check for empty matrix + return ((long) in.rlen) * in.clen; + + // overwrite with existing values (via scatter) + for(int i = 0; i < in.rlen; i++) { + c.fillRow(i, replacement); + if(!a.isEmpty(i)) { + int apos = a.pos(i); + int cpos = c.pos(i); + int alen = a.size(i); + int[] aix = a.indexes(i); + double[] avals = a.values(i); + double[] cvals = c.values(i); + for(int j = apos; j < apos + alen; j++) + if(avals[j] != 0) + cvals[cpos + aix[j]] = avals[j]; + } + } + return ((long) in.rlen) * in.clen; + + } + + private static long replaceDense(MatrixBlock in, MatrixBlock ret, double pattern, double replacement) { + DenseBlock a = in.getDenseBlock(); + DenseBlock c = ret.allocateDenseBlock().getDenseBlock(); + long nnz = 0; + for(int bi = 0; bi < a.numBlocks(); bi++) { + int len = a.size(bi); + double[] avals = a.valuesAt(bi); + double[] cvals = c.valuesAt(bi); + for(int i = 0; i < len; i++) { + cvals[i] = avals[i] == pattern ? replacement : avals[i]; + nnz += cvals[i] != 0 ? 1 : 0; + } + } + return nnz; + } + + private static long replaceDenseNaN(MatrixBlock in, MatrixBlock ret, double replacement) { + DenseBlock a = in.getDenseBlock(); + DenseBlock c = ret.allocateDenseBlock().getDenseBlock(); + long nnz = 0; + for(int bi = 0; bi < a.numBlocks(); bi++) { + int len = a.size(bi); + double[] avals = a.valuesAt(bi); + double[] cvals = c.valuesAt(bi); + for(int i = 0; i < len; i++) { + cvals[i] = Double.isNaN(avals[i]) ? replacement : avals[i]; + nnz += cvals[i] != 0 ? 1 : 0; + } + } + return nnz; + + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java index 668906ee2ea..054edf06a21 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java @@ -5207,93 +5207,7 @@ public MatrixBlock rexpandOperations( MatrixBlock ret, double max, boolean rows, @Override public MatrixBlock replaceOperations(MatrixValue result, double pattern, double replacement) { MatrixBlock ret = checkType(result); - examSparsity(); //ensure its in the right format - if(ret != null) - ret.reset(rlen, clen, sparse); - else - ret = new MatrixBlock(rlen, clen, sparse); - - //probe early abort conditions - if( nonZeros == 0 && pattern != 0 ) - return ret; - if( !containsValue(pattern) ) - return this; //avoid allocation + copy - if( isEmpty() && pattern==0 ) { - ret.reset(rlen, clen, replacement); - return ret; - } - - boolean NaNpattern = Double.isNaN(pattern); - if( sparse ) //SPARSE - { - if( pattern != 0d ) //SPARSE <- SPARSE (sparse-safe) - { - ret.allocateSparseRowsBlock(); - SparseBlock a = sparseBlock; - SparseBlock c = ret.sparseBlock; - - for( int i=0; i g = new ConstMatrix(mb); + new MatrixReplacePerf(100, g, k).run(); + } + private static void run1000(String[] args) { MatrixMulPerformance perf; if (args.length < 3) { diff --git a/src/test/java/org/apache/sysds/performance/matrix/MatrixReplacePerf.java b/src/test/java/org/apache/sysds/performance/matrix/MatrixReplacePerf.java new file mode 100644 index 00000000000..17de8d53fe3 --- /dev/null +++ b/src/test/java/org/apache/sysds/performance/matrix/MatrixReplacePerf.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.performance.matrix; + +import org.apache.sysds.performance.compression.APerfTest; +import org.apache.sysds.performance.generators.IGenerate; +import org.apache.sysds.runtime.matrix.data.LibMatrixReplace; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; + +public class MatrixReplacePerf extends APerfTest { + + private final int k; + + public MatrixReplacePerf(int N, IGenerate gen, int k) { + super(N, gen); + this.k = k; + } + + public void run() throws Exception { + + warmup(() -> replaceZeroTask(k), 10); + execute(() -> replaceZeroTask(k), "replaceZero"); + execute(() -> replaceOneTask(k), "replaceOne"); + execute(() -> replaceNaNTask(k), "replaceNaN"); + } + + private void replaceZeroTask(int k) { + MatrixBlock mb = gen.take(); + LibMatrixReplace.replaceOperations(mb, null, 0, 1, k); + ret.add(null); + } + + + private void replaceOneTask(int k) { + MatrixBlock mb = gen.take(); + LibMatrixReplace.replaceOperations(mb, null, 1, 2, k); + ret.add(null); + } + + + private void replaceNaNTask(int k) { + MatrixBlock mb = gen.take(); + LibMatrixReplace.replaceOperations(mb, null, Double.NaN, 2, k); + ret.add(null); + } + + @Override + protected String makeResString() { + return ""; + } + +}