From 65a1f6d9d66d1f1dddc247770637ffdba0664062 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Thu, 19 Oct 2023 15:01:40 +0200 Subject: [PATCH 01/28] [SYSTEMDS-3635] Cast as frame column names This commit overloads the cast as frame operation. ```R col_names = ["id", "target"] y_file = as.frame(y_file, col_names) ``` and other improvements to tests in frames. Closes #1927 --- src/assembly/bin.xml | 1 + .../java/org/apache/sysds/common/Types.java | 1 + .../sysds/hops/recompile/Recompiler.java | 4 +- .../sysds/hops/rewrite/HopRewriteUtils.java | 5 + .../rewrite/RewriteBlockSizeAndReblock.java | 6 +- .../org/apache/sysds/lops/FunctionCallCP.java | 6 +- .../parser/BuiltinFunctionExpression.java | 40 +- .../apache/sysds/parser/DMLTranslator.java | 591 +++++++++--------- .../apache/sysds/parser/ExpressionList.java | 4 + .../apache/sysds/parser/ListIdentifier.java | 52 ++ .../sysds/runtime/frame/data/FrameBlock.java | 7 +- .../runtime/frame/data/columns/Array.java | 65 +- .../frame/data/columns/ArrayFactory.java | 78 +-- .../frame/data/columns/BitSetArray.java | 28 +- .../frame/data/columns/BooleanArray.java | 20 +- .../runtime/frame/data/columns/CharArray.java | 9 +- .../runtime/frame/data/columns/DDCArray.java | 13 +- .../frame/data/columns/DoubleArray.java | 43 +- .../frame/data/columns/FloatArray.java | 27 +- .../frame/data/columns/IntegerArray.java | 9 +- .../runtime/frame/data/columns/LongArray.java | 9 +- .../frame/data/columns/RaggedArray.java | 35 +- .../frame/data/columns/StringArray.java | 187 +++--- .../compress/CompressedFrameBlockFactory.java | 77 ++- .../runtime/frame/data/lib/FrameUtil.java | 85 ++- .../frame/data/lib/MatrixBlockFromFrame.java | 71 ++- .../instructions/CPInstructionParser.java | 3 +- .../instructions/cp/CPInstruction.java | 10 +- .../cp/VariableCPInstruction.java | 81 ++- .../sysds/runtime/util/DataConverter.java | 2 +- .../test/component/frame/FrameUtilTest.java | 2 +- .../frame/array/CustomArrayTests.java | 34 + 
.../frame/array/FrameArrayTests.java | 376 +++++++---- .../frame/array/NegativeArrayTests.java | 48 +- 34 files changed, 1263 insertions(+), 766 deletions(-) create mode 100644 src/main/java/org/apache/sysds/parser/ListIdentifier.java diff --git a/src/assembly/bin.xml b/src/assembly/bin.xml index 24cf9f93b38..8f3558bbfca 100644 --- a/src/assembly/bin.xml +++ b/src/assembly/bin.xml @@ -97,6 +97,7 @@ *:commons-logging* *:commons-math3* *:commons-text* + *:fastdoubleparser* *:guava* *:hadoop-auth* *:hadoop-client* diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index 0a3748a9f86..4b8f1c3a006 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -408,6 +408,7 @@ public enum OpOp2 { MINUS(true), MODULUS(true), MOMENT(false), MULT(true), NOTEQUAL(true), OR(true), PLUS(true), POW(true), PRINT(false), QUANTILE(false), SOLVE(false), RBIND(false), VALUE_SWAP(false), XOR(true), + CAST_AS_FRAME(false), // cast as frame with column names //fused ML-specific operators for performance MINUS_NZ(false), //sparse-safe minus: X-(mean*ppred(X,0,!=)) LOG_NZ(false), //sparse-safe log; ppred(X,0,"!=")*log(X,0.5) diff --git a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java index 811b7593e96..dfa69881223 100644 --- a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java +++ b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java @@ -29,6 +29,8 @@ import java.util.List; import java.util.Map.Entry; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.sysds.api.DMLScript; @@ -120,7 +122,7 @@ * */ public class Recompiler { - // private static final Log LOG = LogFactory.getLog(Recompiler.class.getName()); + protected static final Log LOG = 
LogFactory.getLog(Recompiler.class.getName()); //Max threshold for in-memory reblock of text input [in bytes] //reason: single-threaded text read at 20MB/s, 1GB input -> 50s (should exploit parallelism) diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java index a5ab766b17b..b2460e7697c 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java @@ -1136,6 +1136,11 @@ public static boolean isData(Hop hop, OpOpData type, DataType dt) { return isData(hop, type) && hop.getDataType()==dt; } + public static boolean isTransformEncode(Hop hop){ + return hop instanceof FunctionOp + && (((FunctionOp)hop).getFunctionName().equalsIgnoreCase("transformencode")); + } + public static boolean isBinaryMatrixColVectorOperation(Hop hop) { return hop instanceof BinaryOp && hop.getInput().get(0).getDataType().isMatrix() && hop.getInput().get(1).getDataType().isMatrix() diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteBlockSizeAndReblock.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteBlockSizeAndReblock.java index 1a2ec286bc6..4e03e02f626 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteBlockSizeAndReblock.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteBlockSizeAndReblock.java @@ -85,7 +85,7 @@ private void rule_BlockSizeAndReblock(Hop hop, final int blocksize) ||(dop.getDataType() == DataType.FRAME && OptimizerUtils.isSparkExecutionMode() && (dop.getFileFormat()==FileFormat.TEXT || dop.getFileFormat()==FileFormat.CSV)) ) { - if( dop.getOp() == OpOpData.PERSISTENTREAD) + if( dop.getOp() == OpOpData.PERSISTENTREAD || dop.getOp() == OpOpData.FEDERATED) { // insert reblock after the hop dop.setRequiresReblock(true); @@ -111,10 +111,6 @@ else if (dop.getOp().isTransient()) { // by default, all transient reads and writes are in blocked format 
dop.setBlocksize(blocksize); } - else if (dop.getOp() == OpOpData.FEDERATED) { - dop.setRequiresReblock(true); - dop.setBlocksize(blocksize); - } else { throw new HopsException(hop.printErrorLocation() + "unexpected non-scalar Data HOP in reblock.\n"); } diff --git a/src/main/java/org/apache/sysds/lops/FunctionCallCP.java b/src/main/java/org/apache/sysds/lops/FunctionCallCP.java index 5f106c3ae64..b0694e2c364 100644 --- a/src/main/java/org/apache/sysds/lops/FunctionCallCP.java +++ b/src/main/java/org/apache/sysds/lops/FunctionCallCP.java @@ -86,7 +86,11 @@ public ArrayList getFunctionOutputs() { public String getFnamespace() { return _fnamespace; } - + + public String getFunctionName(){ + return _fname; + } + public boolean requiresOutputCreateVar() { return !_fname.equalsIgnoreCase(Builtins.REMOVE.getName()); } diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index 8f1a496c1ec..fa0354843b8 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -27,6 +27,8 @@ import org.antlr.v4.runtime.ParserRuleContext; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.NotImplementedException; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Builtins; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; @@ -37,8 +39,8 @@ import org.apache.sysds.runtime.util.DnnUtils; import org.apache.sysds.runtime.util.UtilFunctions; -public class BuiltinFunctionExpression extends DataIdentifier -{ +public class BuiltinFunctionExpression extends DataIdentifier { + protected static final Log LOG = LogFactory.getLog(BuiltinFunctionExpression.class.getName()); protected Expression[] _args = null; private Builtins _opcode; @@ -409,7 +411,7 @@ public void 
validateExpression(MultiAssignmentStatement stmt, HashMap()); - + //STEP2: Actual Validate if( inclFuns ) { // handle functions in namespaces (current program has default namespace) @@ -125,7 +125,7 @@ public void validateParseTree(DMLProgram dmlp, boolean inclFuns) } } } - + // handle regular blocks -- "main" program VariableSet vs = new VariableSet(); HashMap constVars = new HashMap<>(); @@ -140,7 +140,7 @@ public void validateParseTree(DMLProgram dmlp, boolean inclFuns) { //propagate size and datatypes into read prepareReadAfterWrite(dmlp, new HashMap<>()); - + //re-validate main program for datatype propagation vs = new VariableSet(); constVars = new HashMap<>(); @@ -151,15 +151,15 @@ public void validateParseTree(DMLProgram dmlp, boolean inclFuns) } } } - + public void validateFunction(DMLProgram dmlp, FunctionStatementBlock fsb) { validateFunction(dmlp, fsb, false); } - + public void validateFunction(DMLProgram dmlp, FunctionStatementBlock fsb, boolean conditional) { HashMap constVars = new HashMap<>(); VariableSet vs = new VariableSet(); - + // add the input variables for the function to input variable list FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); for (DataIdentifier currVar : fstmt.getInputParams()) { @@ -173,9 +173,9 @@ public void validateFunction(DMLProgram dmlp, FunctionStatementBlock fsb, boolea public void liveVariableAnalysis(DMLProgram dmlp) { liveVariableAnalysis(dmlp, true); } - + public void liveVariableAnalysis(DMLProgram dmlp, boolean inclFuns) { - + // for each namespace, handle function statement blocks if( inclFuns ) { for (String namespaceKey : dmlp.getNamespaces().keySet()) { @@ -185,14 +185,14 @@ public void liveVariableAnalysis(DMLProgram dmlp, boolean inclFuns) { } } } - + // handle regular program blocks VariableSet currentLiveOut = new VariableSet(); VariableSet activeIn = new VariableSet(); - + // handle function inlining dmlp.setStatementBlocks(StatementBlock.mergeFunctionCalls(dmlp.getStatementBlocks(), 
dmlp)); - + for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) { StatementBlock sb = dmlp.getStatementBlock(i); activeIn = sb.initializeforwardLV(activeIn); @@ -206,41 +206,41 @@ public void liveVariableAnalysis(DMLProgram dmlp, boolean inclFuns) { currentLiveOut = sb.analyze(currentLiveOut); } } - + cleanupLiveOutVariables(dmlp.getStatementBlocks(), new VariableSet()); } - + public void liveVariableAnalysisFunction(DMLProgram dmlp, FunctionStatementBlock fsb) { //STEP 1: forward direction FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); - + // perform function inlining fstmt.setBody(StatementBlock.mergeFunctionCalls(fstmt.getBody(), dmlp)); - + VariableSet activeIn = new VariableSet(); for (DataIdentifier id : fstmt.getInputParams()){ activeIn.addVariable(id.getName(), id); } fsb.initializeforwardLV(activeIn); - + //STEP 2: backward direction VariableSet currentLiveOut = new VariableSet(); VariableSet currentLiveIn = new VariableSet(); VariableSet unionLiveIn = new VariableSet(); - + for (DataIdentifier id : fstmt.getInputParams()) currentLiveIn.addVariable(id.getName(), id); - + for (DataIdentifier id : fstmt.getOutputParams()) { currentLiveOut.addVariable(id.getName(), id); unionLiveIn.addVariable(id.getName(), id); } - + fsb._liveOut = currentLiveOut; fsb.analyze(currentLiveIn, currentLiveOut); cleanupLiveOutVariables(fstmt.getBody(), unionLiveIn); } - + public void cleanupLiveOutVariables(List sbs, VariableSet unionLiveIn) { //backwards pass to collect union of livein variables of all successors //and cleanup unnecessary liveout variables @@ -257,7 +257,7 @@ public void cleanupLiveOutVariables(List sbs, VariableSet unionL public void constructHops(DMLProgram dmlp) { constructHops(dmlp, true); } - + public void constructHops(DMLProgram dmlp, boolean inclFuns) { // Step 1: construct hops for all functions if( inclFuns ) { @@ -266,7 +266,7 @@ public void constructHops(DMLProgram dmlp, boolean inclFuns) { for( FunctionStatementBlock fsb : 
fdict.getFunctions().values() ) constructHops(fsb); } - + // Step 2: construct hops for main program // handle regular program blocks for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) { @@ -283,7 +283,7 @@ public void rewriteHopsDAG(DMLProgram dmlp) resetHopsDAGVisitStatus(dmlp); rewriter.rewriteProgramHopDAGs(dmlp, true); //rewrite and split resetHopsDAGVisitStatus(dmlp); - + //propagate size information from main into functions (but conservatively) if( OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS ) { InterProceduralAnalysis ipa = new InterProceduralAnalysis(dmlp); @@ -295,12 +295,12 @@ public void rewriteHopsDAG(DMLProgram dmlp) ProgramRewriter rewriter2 = new ProgramRewriter(false, true); rewriter2.rewriteProgramHopDAGs(dmlp); resetHopsDAGVisitStatus(dmlp); - + //compute memory estimates for all the hops. These estimates are used //subsequently in various optimizations, e.g. CP vs. MR scheduling and parfor. refreshMemEstimates(dmlp); resetHopsDAGVisitStatus(dmlp); - + //enhance HOP DAGs by automatic operator fusion DMLConfig dmlconf = ConfigurationManager.getDMLConfig(); if( ConfigurationManager.isCodegenEnabled() ){ @@ -318,19 +318,19 @@ public void rewriteLopDAG(DMLProgram dmlp) { LopRewriter rewriter = new LopRewriter(); rewriter.rewriteProgramLopDAGs(dmlp); } - + public void codgenHopsDAG(DMLProgram dmlp) { SpoofCompiler.generateCode(dmlp); } - + public void codgenHopsDAG(Program rtprog) { SpoofCompiler.generateCode(rtprog); } - + public void codgenHopsDAG(ProgramBlock pb) { SpoofCompiler.generateCodeFromProgramBlock(pb); } - + public void constructLops(DMLProgram dmlp) { // for each namespace, handle function program blocks for( FunctionDictionary fdict : dmlp.getNamespaces().values() ) { @@ -342,7 +342,7 @@ public void constructLops(DMLProgram dmlp) { for( FunctionStatementBlock fsb : fdict.getFunctions(false).values() ) constructLops(fsb); } - + // handle regular program blocks for( StatementBlock sb : dmlp.getStatementBlocks() ) constructLops(sb); 
@@ -351,54 +351,54 @@ public void constructLops(DMLProgram dmlp) { public boolean constructLops(StatementBlock sb) { boolean ret = false; - + if (sb instanceof WhileStatementBlock) { WhileStatementBlock wsb = (WhileStatementBlock)sb; WhileStatement whileStmt = (WhileStatement)wsb.getStatement(0); ArrayList body = whileStmt.getBody(); - + // step through stmt blocks in while stmt body for (StatementBlock stmtBlock : body) ret |= constructLops(stmtBlock); - + // handle while stmt predicate Lop l = wsb.getPredicateHops().constructLops(); - wsb.setPredicateLops(l); + wsb.setPredicateLops(l); ret |= wsb.updatePredicateRecompilationFlag(); } - + else if (sb instanceof IfStatementBlock) { IfStatementBlock isb = (IfStatementBlock) sb; IfStatement ifStmt = (IfStatement)isb.getStatement(0); ArrayList ifBody = ifStmt.getIfBody(); ArrayList elseBody = ifStmt.getElseBody(); - + // step through stmt blocks in if stmt ifBody for (StatementBlock stmtBlock : ifBody) ret |= constructLops(stmtBlock); - + // step through stmt blocks in if stmt elseBody for (StatementBlock stmtBlock : elseBody) ret |= constructLops(stmtBlock); - + // handle if stmt predicate Lop l = isb.getPredicateHops().constructLops(); isb.setPredicateLops(l); ret |= isb.updatePredicateRecompilationFlag(); } - + else if (sb instanceof ForStatementBlock) //NOTE: applies to ForStatementBlock and ParForStatementBlock { ForStatementBlock fsb = (ForStatementBlock) sb; ForStatement fs = (ForStatement)sb.getStatement(0); ArrayList body = fs.getBody(); - + // step through stmt blocks in FOR stmt body for (StatementBlock stmtBlock : body) ret |= constructLops(stmtBlock); - + // handle for stmt predicate if (fsb.getFromHops() != null){ Lop llobs = fsb.getFromHops().constructLops(); @@ -418,14 +418,14 @@ else if (sb instanceof FunctionStatementBlock) { FunctionStatementBlock fsb = (FunctionStatementBlock) sb; FunctionStatement functStmt = (FunctionStatement)sb.getStatement(0); ArrayList body = functStmt.getBody(); - + // step 
through stmt blocks in while stmt body for( StatementBlock stmtBlock : body ) ret |= constructLops(stmtBlock); if( fsb.isRecompileOnce() ) fsb.setRecompileOnce(ret); } - + // handle default case for regular StatementBlock else { if (sb.getHops() == null) @@ -436,20 +436,19 @@ else if (sb instanceof FunctionStatementBlock) { sb.setLops(lops); ret |= sb.updateRecompilationFlag(); } - + return ret; } - - + public Program getRuntimeProgram(DMLProgram prog, DMLConfig config) throws LanguageException, DMLRuntimeException, LopsException, HopsException { // constructor resets the set of registered functions Program rtprog = new Program(prog); - + // for all namespaces, translate function statement blocks into function program blocks for (String namespace : prog.getNamespaces().keySet()){ - + for (String fname : prog.getFunctionStatementBlocks(namespace).keySet()){ // add program block to program FunctionStatementBlock fsb = prog.getFunctionStatementBlocks(namespace).get(fname); @@ -461,23 +460,23 @@ public Program getRuntimeProgram(DMLProgram prog, DMLConfig config) } } } - + // translate all top-level statement blocks to program blocks for (StatementBlock sb : prog.getStatementBlocks() ) { // add program block to program ProgramBlock rtpb = createRuntimeProgramBlock(rtprog, sb, config); rtprog.addProgramBlock(rtpb); } - + //enhance runtime program by automatic operator fusion if( ConfigurationManager.isCodegenEnabled() && SpoofCompiler.INTEGRATION==IntegrationType.RUNTIME ){ codgenHopsDAG(rtprog); } - + return rtprog ; } - + private void prepareAndAddFunctionProgramBlock(Program rtprog, DMLConfig config, String fnamespace, String fname, FunctionStatementBlock fsb, boolean opt) { @@ -486,107 +485,107 @@ private void prepareAndAddFunctionProgramBlock(Program rtprog, DMLConfig config, rtpb.setRecompileOnce(fsb.isRecompileOnce()); rtpb.setNondeterministic(fsb.isNondeterministic()); } - + public ProgramBlock createRuntimeProgramBlock(Program prog, StatementBlock sb, DMLConfig 
config) { Dag dag = null; Dag pred_dag = null; ArrayList instruct; ArrayList pred_instruct = null; - + ProgramBlock retPB = null; - + // process While Statement - add runtime program blocks to program if (sb instanceof WhileStatementBlock){ - + // create DAG for loop predicates pred_dag = new Dag<>(); ((WhileStatementBlock) sb).getPredicateLops().addToDag(pred_dag); - + // create instructions for loop predicates pred_instruct = new ArrayList<>(); ArrayList pInst = pred_dag.getJobs(null, config); for (Instruction i : pInst ) { pred_instruct.add(i); } - + // create while program block WhileProgramBlock rtpb = new WhileProgramBlock(prog, pred_instruct); - + //// process the body of the while statement block //// - + WhileStatementBlock wsb = (WhileStatementBlock)sb; WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); for (StatementBlock sblock : wstmt.getBody()){ - + // process the body ProgramBlock childBlock = createRuntimeProgramBlock(prog, sblock, config); rtpb.addProgramBlock(childBlock); } - + retPB = rtpb; - + //post processing for generating missing instructions retPB.setExitInstruction(deriveExitInstruction(sb)); - + // add statement block retPB.setStatementBlock(sb); - + // add location information retPB.setParseInfo(sb); } - + // process If Statement - add runtime program blocks to program else if (sb instanceof IfStatementBlock){ - + // create DAG for loop predicates pred_dag = new Dag<>(); ((IfStatementBlock) sb).getPredicateLops().addToDag(pred_dag); - + // create instructions for loop predicates pred_instruct = new ArrayList<>(); ArrayList pInst = pred_dag.getJobs(null, config); for (Instruction i : pInst ) { pred_instruct.add(i); } - + // create if program block IfProgramBlock rtpb = new IfProgramBlock(prog, pred_instruct); - + // process the body of the if statement block IfStatementBlock isb = (IfStatementBlock)sb; IfStatement istmt = (IfStatement)isb.getStatement(0); - + // process the if body for (StatementBlock sblock : istmt.getIfBody()){ 
ProgramBlock childBlock = createRuntimeProgramBlock(prog, sblock, config); rtpb.addProgramBlockIfBody(childBlock); } - + // process the else body for (StatementBlock sblock : istmt.getElseBody()){ ProgramBlock childBlock = createRuntimeProgramBlock(prog, sblock, config); rtpb.addProgramBlockElseBody(childBlock); } - + retPB = rtpb; - + //post processing for generating missing instructions retPB.setExitInstruction(deriveExitInstruction(sb)); - + // add statement block retPB.setStatementBlock(sb); - + // add location information retPB.setParseInfo(sb); } - + // process For Statement - add runtime program blocks to program // NOTE: applies to ForStatementBlock and ParForStatementBlock else if (sb instanceof ForStatementBlock) { ForStatementBlock fsb = (ForStatementBlock) sb; - + // create DAGs for loop predicates Dag fromDag = new Dag<>(); Dag toDag = new Dag<>(); @@ -597,7 +596,7 @@ else if (sb instanceof ForStatementBlock) fsb.getToLops().addToDag(toDag); if( fsb.getIncrementHops()!=null ) fsb.getIncrementLops().addToDag(incrementDag); - + // create instructions for loop predicates ArrayList fromInstructions = fromDag.getJobs(null, config); ArrayList toInstructions = toDag.getJobs(null, config); @@ -606,7 +605,7 @@ else if (sb instanceof ForStatementBlock) // create for program block ForProgramBlock rtpb = null; IterablePredicate iterPred = fsb.getIterPredicate(); - + if( sb instanceof ParForStatementBlock && ConfigurationManager.isParallelParFor() ) { rtpb = new ParForProgramBlock(prog, iterPred.getIterVar().getName(), iterPred.getParForParams(), ((ParForStatementBlock)sb).getResultVariables()); @@ -616,93 +615,92 @@ else if (sb instanceof ForStatementBlock) else {//ForStatementBlock rtpb = new ForProgramBlock(prog, iterPred.getIterVar().getName()); } - + rtpb.setFromInstructions(fromInstructions); rtpb.setToInstructions(toInstructions); rtpb.setIncrementInstructions(incrementInstructions); - + // process the body of the for statement block ForStatement fs = 
(ForStatement)fsb.getStatement(0); for (StatementBlock sblock : fs.getBody()){ ProgramBlock childBlock = createRuntimeProgramBlock(prog, sblock, config); rtpb.addProgramBlock(childBlock); } - + retPB = rtpb; - + //post processing for generating missing instructions retPB.setExitInstruction(deriveExitInstruction(sb)); - + // add statement block retPB.setStatementBlock(sb); - + // add location information retPB.setParseInfo(sb); } - + // process function statement block - add runtime program blocks to program else if (sb instanceof FunctionStatementBlock){ - + FunctionStatementBlock fsb = (FunctionStatementBlock)sb; FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); FunctionProgramBlock rtpb = null; - - + // create function program block rtpb = new FunctionProgramBlock(prog, fstmt.getInputParams(), fstmt.getOutputParams()); - + // process the function statement body for (StatementBlock sblock : fstmt.getBody()){ // process the body ProgramBlock childBlock = createRuntimeProgramBlock(prog, sblock, config); rtpb.addProgramBlock(childBlock); } - + // check there are actually Lops in to process (loop stmt body will not have any) if (fsb.getLops() != null && !fsb.getLops().isEmpty()){ throw new LopsException(fsb.printBlockErrorLocation() + "FunctionStatementBlock should have no Lops"); } - + retPB = rtpb; - + // add statement block retPB.setStatementBlock(sb); - + // add location information retPB.setParseInfo(sb); } else { - + // handle general case BasicProgramBlock rtpb = new BasicProgramBlock(prog); - + // DAGs for Lops dag = new Dag<>(); // check there are actually Lops in to process (loop stmt body will not have any) if (sb.getLops() != null && !sb.getLops().isEmpty()){ - + for (Lop l : sb.getLops()) { l.addToDag(dag); } - + // Instructions for Lops DAGs instruct = dag.getJobs(sb, config); rtpb.addInstructions(instruct); } - + retPB = rtpb; - + //post processing for generating missing instructions //retPB.setExitInstruction(deriveExitInstruction(sb)); 
- + // add statement block retPB.setStatementBlock(sb); - + // add location information retPB.setParseInfo(sb); } - + return retPB; } @@ -715,14 +713,14 @@ public static void refreshMemEstimates(DMLProgram dmlp) { refreshMemEstimates(fsblock); } } - + // handle statement blocks in "main" method for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) { StatementBlock current = dmlp.getStatementBlock(i); refreshMemEstimates(current); } } - + private static Instruction deriveExitInstruction(StatementBlock sb) { Set rmVars = VariableSet.union( VariableSet.minus(sb.liveIn(), sb.liveOut()), @@ -730,20 +728,20 @@ private static Instruction deriveExitInstruction(StatementBlock sb) { return rmVars.isEmpty() ? null : VariableCPInstruction.prepareRemoveInstruction(rmVars.toArray(new String[0])); } - + public static void refreshMemEstimates(StatementBlock current) { - + MemoTable memo = new MemoTable(); - + if( HopRewriteUtils.isLastLevelStatementBlock(current) ) { ArrayList hopsDAG = current.getHops(); if (hopsDAG != null && !hopsDAG.isEmpty()) for( Hop hop : hopsDAG ) hop.refreshMemEstimates(memo); } - + if (current instanceof FunctionStatementBlock) { - + FunctionStatement fstmt = (FunctionStatement)current.getStatement(0); for (StatementBlock sb : fstmt.getBody()){ refreshMemEstimates(sb); @@ -753,11 +751,11 @@ else if (current instanceof WhileStatementBlock) { // handle predicate WhileStatementBlock wstb = (WhileStatementBlock) current; wstb.getPredicateHops().refreshMemEstimates(new MemoTable()); - + if (wstb.getNumStatements() > 1) LOG.debug("While statement block has more than 1 stmt"); WhileStatement ws = (WhileStatement)wstb.getStatement(0); - + for (StatementBlock sb : ws.getBody()){ refreshMemEstimates(sb); } @@ -766,11 +764,11 @@ else if (current instanceof IfStatementBlock) { // handle predicate IfStatementBlock istb = (IfStatementBlock) current; istb.getPredicateHops().refreshMemEstimates(new MemoTable()); - + if (istb.getNumStatements() > 1) LOG.debug("If 
statement block has more than 1 stmt"); IfStatement is = (IfStatement)istb.getStatement(0); - + for (StatementBlock sb : is.getIfBody()){ refreshMemEstimates(sb); } @@ -787,17 +785,17 @@ else if (current instanceof ForStatementBlock) { fsb.getToHops().refreshMemEstimates(new MemoTable()); if (fsb.getIncrementHops() != null) fsb.getIncrementHops().refreshMemEstimates(new MemoTable()); - + if (fsb.getNumStatements() > 1) LOG.debug("For statement block has more than 1 stmt"); ForStatement ws = (ForStatement)fsb.getStatement(0); - + for (StatementBlock sb : ws.getBody()){ refreshMemEstimates(sb); } } } - + public static void resetHopsDAGVisitStatus(DMLProgram dmlp) { // for each namespace, handle function program blocks -- forward direction @@ -807,23 +805,23 @@ public static void resetHopsDAGVisitStatus(DMLProgram dmlp) { resetHopsDAGVisitStatus(fsblock); } } - + // handle statement blocks in "main" method for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) { StatementBlock current = dmlp.getStatementBlock(i); resetHopsDAGVisitStatus(current); } } - + public static void resetHopsDAGVisitStatus(StatementBlock current) { - + if( HopRewriteUtils.isLastLevelStatementBlock(current) ) { ArrayList hopsDAG = current.getHops(); if (hopsDAG != null && !hopsDAG.isEmpty() ) { Hop.resetVisitStatus(hopsDAG); } } - + if (current instanceof FunctionStatementBlock) { FunctionStatement fstmt = (FunctionStatement)current.getStatement(0); for (StatementBlock sb : fstmt.getBody()){ @@ -834,7 +832,7 @@ else if (current instanceof WhileStatementBlock) { // handle predicate WhileStatementBlock wstb = (WhileStatementBlock) current; wstb.getPredicateHops().resetVisitStatus(); - + WhileStatement ws = (WhileStatement)wstb.getStatement(0); for (StatementBlock sb : ws.getBody()) resetHopsDAGVisitStatus(sb); @@ -843,7 +841,7 @@ else if (current instanceof IfStatementBlock) { // handle predicate IfStatementBlock istb = (IfStatementBlock) current; istb.getPredicateHops().resetVisitStatus(); - + 
IfStatement is = (IfStatement)istb.getStatement(0); for (StatementBlock sb : is.getIfBody()) resetHopsDAGVisitStatus(sb); @@ -859,19 +857,19 @@ else if (current instanceof ForStatementBlock) { fsb.getToHops().resetVisitStatus(); if (fsb.getIncrementHops() != null) fsb.getIncrementHops().resetVisitStatus(); - + if (fsb.getNumStatements() > 1) LOG.debug("For statment block has more than 1 stmt"); ForStatement ws = (ForStatement)fsb.getStatement(0); - + for (StatementBlock sb : ws.getBody()){ resetHopsDAGVisitStatus(sb); } } } - + public void resetLopsDAGVisitStatus(DMLProgram dmlp) { - + // for each namespace, handle function program blocks for (String namespaceKey : dmlp.getNamespaces().keySet()){ for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()){ @@ -879,15 +877,15 @@ public void resetLopsDAGVisitStatus(DMLProgram dmlp) { resetLopsDAGVisitStatus(fsblock); } } - + for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) { StatementBlock current = dmlp.getStatementBlock(i); resetLopsDAGVisitStatus(current); } } - + public void resetLopsDAGVisitStatus(StatementBlock current) { - + ArrayList hopsDAG = current.getHops(); if (hopsDAG != null && !hopsDAG.isEmpty() ) { @@ -897,66 +895,64 @@ public void resetLopsDAGVisitStatus(StatementBlock current) { currentHop.getLops().resetVisitStatus(); } } - + if (current instanceof FunctionStatementBlock) { FunctionStatementBlock fsb = (FunctionStatementBlock) current; FunctionStatement fs = (FunctionStatement)fsb.getStatement(0); - + for (StatementBlock sb : fs.getBody()){ resetLopsDAGVisitStatus(sb); } } - - + if (current instanceof WhileStatementBlock) { WhileStatementBlock wstb = (WhileStatementBlock) current; wstb.getPredicateLops().resetVisitStatus(); if (wstb.getNumStatements() > 1) LOG.debug("While statement block has more than 1 stmt"); WhileStatement ws = (WhileStatement)wstb.getStatement(0); - + for (StatementBlock sb : ws.getBody()){ resetLopsDAGVisitStatus(sb); } } - + if (current instanceof 
IfStatementBlock) { IfStatementBlock istb = (IfStatementBlock) current; istb.getPredicateLops().resetVisitStatus(); if (istb.getNumStatements() > 1) LOG.debug("If statement block has more than 1 stmt"); IfStatement is = (IfStatement)istb.getStatement(0); - + for (StatementBlock sb : is.getIfBody()){ resetLopsDAGVisitStatus(sb); } - + for (StatementBlock sb : is.getElseBody()){ resetLopsDAGVisitStatus(sb); } } - + if (current instanceof ForStatementBlock) { ForStatementBlock fsb = (ForStatementBlock) current; - + if (fsb.getFromLops() != null) fsb.getFromLops().resetVisitStatus(); if (fsb.getToLops() != null) fsb.getToLops().resetVisitStatus(); if (fsb.getIncrementLops() != null) fsb.getIncrementLops().resetVisitStatus(); - + if (fsb.getNumStatements() > 1) LOG.debug("For statement block has more than 1 stmt"); ForStatement ws = (ForStatement)fsb.getStatement(0); - + for (StatementBlock sb : ws.getBody()){ resetLopsDAGVisitStatus(sb); } } } - public void constructHops(StatementBlock sb) { if (sb instanceof WhileStatementBlock) { constructHopsForWhileControlBlock((WhileStatementBlock) sb); @@ -967,17 +963,17 @@ public void constructHops(StatementBlock sb) { constructHopsForIfControlBlock((IfStatementBlock) sb); return; } - + if (sb instanceof ForStatementBlock) { //incl ParForStatementBlock constructHopsForForControlBlock((ForStatementBlock) sb); return; } - + if (sb instanceof FunctionStatementBlock) { constructHopsForFunctionControlBlock((FunctionStatementBlock) sb); return; } - + HashMap ids = new HashMap<>(); ArrayList output = new ArrayList<>(); @@ -991,7 +987,7 @@ public void constructHops(StatementBlock sb) { HashMap liveOutToTemp = new HashMap<>(); for (int i = 0; i < sb.getNumStatements(); i++) { Statement current = sb.getStatement(i); - + if (current instanceof AssignmentStatement) { AssignmentStatement as = (AssignmentStatement) current; DataIdentifier target = as.getTarget(); @@ -1003,7 +999,7 @@ public void constructHops(StatementBlock sb) { } if 
(current instanceof MultiAssignmentStatement) { MultiAssignmentStatement mas = (MultiAssignmentStatement) current; - + for (DataIdentifier target : mas.getTargetList()){ if (liveOut.containsVariable(target.getName())) { liveOutToTemp.put(target.getName(), Integer.valueOf(i)); @@ -1015,11 +1011,11 @@ public void constructHops(StatementBlock sb) { // only create transient read operations for variables either updated or read-before-update // (i.e., from LV analysis, updated and gen sets) if ( !liveIn.getVariables().values().isEmpty() ) { - + for (String varName : liveIn.getVariables().keySet()) { if (updated.containsVariable(varName) || gen.containsVariable(varName)){ - + DataIdentifier var = liveIn.getVariables().get(varName); long actualDim1 = (var instanceof IndexedIdentifier) ? ((IndexedIdentifier)var).getOrigDim1() : var.getDim1(); long actualDim2 = (var instanceof IndexedIdentifier) ? ((IndexedIdentifier)var).getOrigDim2() : var.getDim2(); @@ -1030,7 +1026,6 @@ public void constructHops(StatementBlock sb) { } } - for( int i = 0; i < sb.getNumStatements(); i++ ) { Statement current = sb.getStatement(i); @@ -1143,8 +1138,7 @@ else if (ptype == PRINTTYPE.STOP) { DataIdentifier target = as.getTarget(); Expression source = as.getSource(); - - // CASE: regular assignment statement -- source is DML expression that is NOT user-defined or external function + // CASE: regular assignment statement -- source is DML expression that is NOT user-defined or external function if (!(source instanceof FunctionCallIdentifier)){ // CASE: target is regular data identifier @@ -1176,7 +1170,7 @@ else if (ptype == PRINTTYPE.STOP) { updatedLiveOut.addVariable(target.getName(), target); output.add(transientwrite); } - } + } // CASE: target is indexed identifier (left-hand side indexed expression) else { Hop ae = processLeftIndexedExpression(source, (IndexedIdentifier)target, ids); @@ -1224,8 +1218,8 @@ else if (ptype == PRINTTYPE.STOP) { FunctionStatementBlock fsb = 
this._dmlProg.getFunctionStatementBlock(fci.getNamespace(),fci.getName()); //error handling missing function - if (fsb == null) { - throw new LanguageException(source.printErrorLocation() + "function " + if (fsb == null) { + throw new LanguageException(source.printErrorLocation() + "function " + fci.getName() + " is undefined in namespace " + fci.getNamespace()); } @@ -1272,7 +1266,7 @@ else if (current instanceof MultiAssignmentStatement) { FunctionCallIdentifier fci = (FunctionCallIdentifier) source; FunctionStatementBlock fsb = this._dmlProg.getFunctionStatementBlock(fci.getNamespace(),fci.getName()); if (fsb == null){ - throw new LanguageException(source.printErrorLocation() + "function " + throw new LanguageException(source.printErrorLocation() + "function " + fci.getName() + " is undefined in namespace " + fci.getNamespace()); } @@ -1319,7 +1313,7 @@ else if ( source instanceof ParameterizedBuiltinFunctionExpression && ((Paramete sb.updateLiveVariablesOut(updatedLiveOut); sb.setHops(output); } - + private static DataIdentifier getAccumulatorData(VariableSet liveIn, String varname) { DataIdentifier accum = liveIn.getVariable(varname); if( accum == null ) @@ -1327,13 +1321,13 @@ private static DataIdentifier getAccumulatorData(VariableSet liveIn, String varn + "to non-existing variable "+varname+"."); return accum; } - + private void appendDefaultArguments(FunctionStatement fstmt, List inputNames, List inputs, HashMap ids) { //NOTE: For default expressions of unspecified function arguments, we have two choices: //either (a) compile ifelse(exist(argName),default, argName) into the function, or //simply (b) add the default to the argument list of function calls when needed. //We decided for (b) because it simplifies IPA and dynamic recompilation. 
- + if( fstmt.getInputParams().size() == inputs.size() ) return; HashSet probeNames = new HashSet<>(inputNames); @@ -1349,26 +1343,26 @@ private void appendDefaultArguments(FunctionStatement fstmt, List inputN inputs.add(processExpression(exp, null, ids)); } } - + public void constructHopsForIfControlBlock(IfStatementBlock sb) { IfStatement ifsb = (IfStatement) sb.getStatement(0); ArrayList ifBody = ifsb.getIfBody(); ArrayList elseBody = ifsb.getElseBody(); - + // construct hops for predicate in if statement constructHopsForConditionalPredicate(sb); - + // handle if statement body for( StatementBlock current : ifBody ) { constructHops(current); } - + // handle else stmt body for( StatementBlock current : elseBody ) { constructHops(current); } } - + /** * Constructs Hops for a given ForStatementBlock or ParForStatementBlock, respectively. * @@ -1381,28 +1375,27 @@ public void constructHopsForForControlBlock(ForStatementBlock sb) { for( StatementBlock current : body ) constructHops(current); } - + public void constructHopsForFunctionControlBlock(FunctionStatementBlock fsb) { ArrayList body = ((FunctionStatement)fsb.getStatement(0)).getBody(); for( StatementBlock current : body ) constructHops(current); } - + public void constructHopsForWhileControlBlock(WhileStatementBlock sb) { ArrayList body = ((WhileStatement)sb.getStatement(0)).getBody(); constructHopsForConditionalPredicate(sb); for( StatementBlock current : body ) constructHops(current); } - - + public void constructHopsForConditionalPredicate(StatementBlock passedSB) { HashMap _ids = new HashMap<>(); - + // set conditional predicate ConditionalPredicate cp = null; - + if (passedSB instanceof WhileStatementBlock){ WhileStatement ws = (WhileStatement) ((WhileStatementBlock)passedSB).getStatement(0); cp = ws.getConditionalPredicate(); @@ -1414,36 +1407,36 @@ else if (passedSB instanceof IfStatementBlock) { else { throw new ParseException("ConditionalPredicate expected only for while or if statements."); } - + 
VariableSet varsRead = cp.variablesRead(); for (String varName : varsRead.getVariables().keySet()) { - + // creating transient read for live in variables DataIdentifier var = passedSB.liveIn().getVariables().get(varName); - + DataOp read = null; - + if (var == null) { throw new ParseException("variable " + varName + " not live variable for conditional predicate"); } else { long actualDim1 = (var instanceof IndexedIdentifier) ? ((IndexedIdentifier)var).getOrigDim1() : var.getDim1(); long actualDim2 = (var instanceof IndexedIdentifier) ? ((IndexedIdentifier)var).getOrigDim2() : var.getDim2(); - + read = new DataOp(var.getName(), var.getDataType(), var.getValueType(), OpOpData.TRANSIENTREAD, null, actualDim1, actualDim2, var.getNnz(), var.getBlocksize()); read.setParseInfo(var); } _ids.put(varName, read); } - + DataIdentifier target = new DataIdentifier(Expression.getTempName()); target.setDataType(DataType.SCALAR); target.setValueType(ValueType.BOOLEAN); target.setParseInfo(passedSB); Hop predicateHops = null; Expression predicate = cp.getPredicate(); - + if (predicate instanceof RelationalExpression) { predicateHops = processRelationalExpression((RelationalExpression) cp.getPredicate(), target, _ids); } else if (predicate instanceof BooleanExpression) { @@ -1471,19 +1464,18 @@ else if (passedSB instanceof IfStatementBlock) { } predicateHops = processExpression(cp.getPredicate(), null, _ids); } - + //create transient write to internal variable name on top of expression //in order to ensure proper instruction generation predicateHops = HopRewriteUtils.createDataOp( ProgramBlock.PRED_VAR, predicateHops, OpOpData.TRANSIENTWRITE); - + if (passedSB instanceof WhileStatementBlock) ((WhileStatementBlock)passedSB).setPredicateHops(predicateHops); else if (passedSB instanceof IfStatementBlock) ((IfStatementBlock)passedSB).setPredicateHops(predicateHops); } - /** * Constructs all predicate Hops (for FROM, TO, INCREMENT) of an iterable predicate * and assigns these Hops to the 
passed statement block. @@ -1495,19 +1487,19 @@ else if (passedSB instanceof IfStatementBlock) public void constructHopsForIterablePredicate(ForStatementBlock fsb) { HashMap _ids = new HashMap<>(); - + // set iterable predicate ForStatement fs = (ForStatement) fsb.getStatement(0); IterablePredicate ip = fs.getIterablePredicate(); - + for(int i=0; i < 3; i++) { Expression expr = (i == 0) ? ip.getFromExpr() : (i == 1) ? ip.getToExpr() : ( ip.getIncrementExpr() != null ) ? ip.getIncrementExpr() : null; VariableSet varsRead = (expr != null) ? expr.variablesRead() : null; - + if(varsRead != null) { for (String varName : varsRead.getVariables().keySet()) { - + DataIdentifier var = fsb.liveIn().getVariable(varName); DataOp read = null; if (var == null) { @@ -1523,15 +1515,15 @@ public void constructHopsForIterablePredicate(ForStatementBlock fsb) _ids.put(varName, read); } } - + //create transient write to internal variable name on top of expression //in order to ensure proper instruction generation Hop predicateHops = processTempIntExpression(expr, _ids); if( predicateHops != null ) predicateHops = HopRewriteUtils.createDataOp( ProgramBlock.PRED_VAR, predicateHops, OpOpData.TRANSIENTWRITE); - - //construct hops for from, to, and increment expressions + + //construct hops for from, to, and increment expressions if( i == 0 ) fsb.setFromHops( predicateHops ); else if( i == 1 ) @@ -1540,7 +1532,7 @@ else if( ip.getIncrementExpr() != null ) fsb.setIncrementHops( predicateHops ); } } - + /** * Construct Hops from parse tree : Process Expression in an assignment * statement @@ -1641,8 +1633,7 @@ private static DataIdentifier createTarget(Expression source) { private static DataIdentifier createTarget() { return new DataIdentifier(Expression.getTempName()); } - - + /** * Constructs the Hops for arbitrary expressions that eventually evaluate to an INT scalar. 
* @@ -1659,63 +1650,62 @@ private Hop processTempIntExpression( Expression source, HashMap h source.setOutput(tmpOut); return processExpression(source, tmpOut, hops ); } - + private Hop processLeftIndexedExpression(Expression source, IndexedIdentifier target, HashMap hops) { // process target indexed expressions Hop[] ixRange = getIndexingBounds(target, hops, true); - + // process the source expression to get source Hops Hop sourceOp = processExpression(source, target, hops); - + // process the target to get targetHops Hop targetOp = hops.get(target.getName()); if (targetOp == null){ throw new ParseException(target.printErrorLocation() + " must define matrix " + target.getName() + " before indexing operations are allowed "); } - + if( sourceOp.getDataType().isMatrix() && source.getOutput().getDataType().isScalar() ) sourceOp.setDataType(DataType.SCALAR); - + Hop leftIndexOp = new LeftIndexingOp(target.getName(), target.getDataType(), ValueType.FP64, targetOp, sourceOp, ixRange[0], ixRange[1], ixRange[2], ixRange[3], target.getRowLowerEqualsUpper(), target.getColLowerEqualsUpper()); - + setIdentifierParams(leftIndexOp, target); leftIndexOp.setParseInfo(target); leftIndexOp.setDim1(target.getOrigDim1()); leftIndexOp.setDim2(target.getOrigDim2()); - + return leftIndexOp; } - - + private Hop processIndexingExpression(IndexedIdentifier source, DataIdentifier target, HashMap hops) { // process Hops for indexes (for source) Hop[] ixRange = getIndexingBounds(source, hops, false); - + if (target == null) { target = createTarget(source); } //unknown nnz after range indexing (applies to indexing op but also //data dependent operations) target.setNnz(-1); - + Hop indexOp = new IndexingOp(target.getName(), target.getDataType(), target.getValueType(), hops.get(source.getName()), ixRange[0], ixRange[1], ixRange[2], ixRange[3], source.getRowLowerEqualsUpper(), source.getColLowerEqualsUpper()); - + indexOp.setParseInfo(target); setIdentifierParams(indexOp, target); - + return 
indexOp; } - + private Hop[] getIndexingBounds(IndexedIdentifier ix, HashMap hops, boolean lix) { Hop rowLowerHops = (ix.getRowLowerBound() != null) ? processExpression(ix.getRowLowerBound(),null, hops) : new LiteralOp(1); Hop colLowerHops = (ix.getColLowerBound() != null) ? processExpression(ix.getColLowerBound(),null, hops) : new LiteralOp(1); - + Hop rowUpperHops = null, colUpperHops = null; if (ix.getRowUpperBound() != null) rowUpperHops = processExpression(ix.getRowUpperBound(),null,hops); @@ -1725,7 +1715,7 @@ private Hop[] getIndexingBounds(IndexedIdentifier ix, HashMap hops, new UnaryOp(ix.getName(), DataType.SCALAR, ValueType.INT64, OpOp1.NROW, hops.get(ix.getName())); rowUpperHops.setParseInfo(ix); } - + if (ix.getColUpperBound() != null) colUpperHops = processExpression(ix.getColUpperBound(),null,hops); else { @@ -1734,11 +1724,10 @@ private Hop[] getIndexingBounds(IndexedIdentifier ix, HashMap hops, new UnaryOp(ix.getName(), DataType.SCALAR, ValueType.INT64, OpOp1.NCOL, hops.get(ix.getName())); colUpperHops.setParseInfo(ix); } - + return new Hop[] {rowLowerHops, rowUpperHops, colLowerHops, colUpperHops}; } - - + /** * Construct Hops from parse tree : Process Binary Expression in an * assignment statement @@ -1757,14 +1746,14 @@ private Hop processBinaryExpression(BinaryExpression source, DataIdentifier targ throw new ParseException("Missing input in binary expressions (" + source.toString()+"): " + ((left==null)?source.getLeft():source.getRight())+", line="+source.getBeginLine()); } - + //prepare target identifier and ensure that output type is of inferred type //(type should not be determined by target (e.g., string for print) if (target == null) { target = createTarget(source); } target.setValueType(source.getOutput().getValueType()); - + Hop currBop = null; switch( source.getOpCode() ) { case PLUS: @@ -1783,7 +1772,7 @@ private Hop processBinaryExpression(BinaryExpression source, DataIdentifier targ default: throw new ParseException("Unsupported 
parsing of binary expression: "+source.getOpCode()); } - + setIdentifierParams(currBop, source.getOutput()); currBop.setParseInfo(source); return currBop; @@ -1814,7 +1803,7 @@ else if(left.getDataType() == DataType.FRAME || right.getDataType() == DataType. target.setValueType(ValueType.BOOLEAN); } } - + OpOp2 op = null; if (source.getOpCode() == Expression.RelationalOp.LESS) { @@ -1860,7 +1849,7 @@ private Hop processBooleanExpression(BooleanExpression source, DataIdentifier ta target = createTarget(source); if( target.getDataType().isScalar() ) target.setValueType(ValueType.BOOLEAN); - + if (source.getRight() == null) { Hop currUop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp1.NOT, left); currUop.setParseInfo(source); @@ -1885,7 +1874,7 @@ private Hop processBooleanExpression(BooleanExpression source, DataIdentifier ta } private static Hop constructDfHop(String name, DataType dt, ValueType vt, Builtins op, LinkedHashMap paramHops) { - + // Add a hop to paramHops to store distribution information. // Distribution parameter hops would have been already present in paramHops. Hop distLop = null; @@ -1910,30 +1899,30 @@ private static Hop constructDfHop(String name, DataType dt, ValueType vt, Builti case PEXP: distLop = new LiteralOp("exp"); break; - + case CDF: case INVCDF: break; - + default: throw new HopsException("Invalid operation: " + op); } if (distLop != null) paramHops.put("dist", distLop); - + return new ParameterizedBuiltinOp(name, dt, vt, ParameterizedBuiltinFunctionExpression.pbHopMap.get(op), paramHops); } - + private Hop processMultipleReturnParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFunctionExpression source, ArrayList targetList, HashMap hops) { FunctionType ftype = FunctionType.MULTIRETURN_BUILTIN; String nameSpace = DMLProgram.INTERNAL_NAMESPACE; - + // Create an array list to hold the outputs of this lop. // Exact list of outputs are added based on opcode. 
ArrayList outputs = new ArrayList<>(); - + // Construct Hop for current builtin function expression based on its type Hop currBuiltinOp = null; switch (source.getOpCode()) { @@ -1946,10 +1935,10 @@ private Hop processMultipleReturnParameterizedBuiltinFunctionExpression(Paramete outputNames[1] = targetList.get(1).getName(); outputs.add(new DataOp(outputNames[0], DataType.MATRIX, ValueType.FP64, inputs.get(0), OpOpData.FUNCTIONOUTPUT, inputs.get(0).getFilename())); outputs.add(new DataOp(outputNames[1], DataType.FRAME, ValueType.STRING, inputs.get(0), OpOpData.FUNCTIONOUTPUT, inputs.get(0).getFilename())); - + currBuiltinOp = new FunctionOp(ftype, nameSpace, source.getOpCode().toString(), null, inputs, outputNames, outputs); break; - + default: throw new ParseException("Invaid Opcode in DMLTranslator:processMultipleReturnParameterizedBuiltinFunctionExpression(): " + source.getOpCode()); } @@ -1963,7 +1952,7 @@ private Hop processMultipleReturnParameterizedBuiltinFunctionExpression(Paramete return currBuiltinOp; } - + /** * Construct Hops from parse tree : Process ParameterizedBuiltinFunction Expression in an * assignment statement @@ -1975,10 +1964,10 @@ private Hop processMultipleReturnParameterizedBuiltinFunctionExpression(Paramete */ private Hop processParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFunctionExpression source, DataIdentifier target, HashMap hops) { - + // this expression has multiple "named" parameters LinkedHashMap paramHops = new LinkedHashMap<>(); - + // -- construct hops for all input parameters // -- store them in hashmap so that their "name"s are maintained Hop pHop = null; @@ -1986,13 +1975,13 @@ private Hop processParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFu pHop = processExpression(source.getVarParam(paramName), null, hops); paramHops.put(paramName, pHop); } - + Hop currBuiltinOp = null; if (target == null) { target = createTarget(source); } - + // construct hop based on opcode switch(source.getOpCode()) { case 
CDF: @@ -2010,7 +1999,6 @@ private Hop processParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFu currBuiltinOp = constructDfHop(target.getName(), target.getDataType(), target.getValueType(), source.getOpCode(), paramHops); break; - case CONTAINS: case GROUPEDAGG: case RMEMPTY: @@ -2027,7 +2015,7 @@ private Hop processParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFu currBuiltinOp = new ParameterizedBuiltinOp(target.getName(), target.getDataType(), target.getValueType(), ParamBuiltinOp.valueOf(source.getOpCode().name()), paramHops); break; - + case ORDER: ArrayList inputs = new ArrayList<>(); inputs.add(paramHops.get("target")); @@ -2036,7 +2024,7 @@ private Hop processParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFu inputs.add(paramHops.get("index.return")); currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), ReOrgOp.SORT, inputs); break; - + case TOSTRING: //check for input data type and only compile toString Hop for matrices/frames, //for scalars, we compile (s + "") to ensure consistent string output value types @@ -2110,12 +2098,12 @@ private Hop processParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFu throw new ParseException(source.printErrorLocation() + "processParameterizedBuiltinFunctionExpression() -- Unknown operation: " + source.getOpCode()); } - + setIdentifierParams(currBuiltinOp, source.getOutput()); currBuiltinOp.setParseInfo(source); return currBuiltinOp; } - + /** * Construct Hops from parse tree : Process ParameterizedExpression in a * read/write/rand statement @@ -2127,10 +2115,10 @@ private Hop processParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFu */ private Hop processDataExpression(DataExpression source, DataIdentifier target, HashMap hops) { - + // this expression has multiple "named" parameters HashMap paramHops = new HashMap<>(); - + // -- construct hops for all input parameters // -- store them in hashmap so that their "name"s are 
maintained Hop pHop = null; @@ -2138,25 +2126,25 @@ private Hop processDataExpression(DataExpression source, DataIdentifier target, pHop = processExpression(source.getVarParam(paramName), null, hops); paramHops.put(paramName, pHop); } - + Hop currBuiltinOp = null; if (target == null) { target = createTarget(source); } - + // construct hop based on opcode switch(source.getOpCode()) { case READ: currBuiltinOp = new DataOp(target.getName(), target.getDataType(), target.getValueType(), OpOpData.PERSISTENTREAD, paramHops); ((DataOp)currBuiltinOp).setFileName(((StringIdentifier)source.getVarParam(DataExpression.IO_FILENAME)).getValue()); break; - + case WRITE: currBuiltinOp = new DataOp(target.getName(), target.getDataType(), target.getValueType(), OpOpData.PERSISTENTWRITE, hops.get(target.getName()), paramHops); break; - + case RAND: // We limit RAND_MIN, RAND_MAX, RAND_SPARSITY, RAND_SEED, and RAND_PDF to be constants OpOpDG method = (paramHops.get(DataExpression.RAND_MIN).getValueType()==ValueType.STRING && @@ -2182,12 +2170,12 @@ private Hop processDataExpression(DataExpression source, DataIdentifier target, currBuiltinOp = new ReorgOp(target.getName(), target.getDataType(), target.getValueType(), ReOrgOp.RESHAPE, tmpMatrix); break; - + case SQL: currBuiltinOp = new DataOp(target.getName(), target.getDataType(), target.getValueType(), OpOpData.SQLREAD, paramHops); break; - + case FEDERATED: currBuiltinOp = new DataOp(target.getName(), target.getDataType(), target.getValueType(), OpOpData.FEDERATED, paramHops); @@ -2197,7 +2185,7 @@ private Hop processDataExpression(DataExpression source, DataIdentifier target, throw new ParseException(source.printErrorLocation() + "processDataExpression():: Unknown operation: " + source.getOpCode()); } - + //set identifier meta data (incl dimensions and blocksizes) setIdentifierParams(currBuiltinOp, source.getOutput()); if( source.getOpCode()==DataExpression.DataOp.READ ) @@ -2209,7 +2197,7 @@ else if ( source.getOpCode() == 
DataExpression.DataOp.WRITE ) { source.getVarParam(DataExpression.ROWBLOCKCOUNTPARAM).toString())); } currBuiltinOp.setParseInfo(source); - + return currBuiltinOp; } @@ -2225,7 +2213,7 @@ else if ( source.getOpCode() == DataExpression.DataOp.WRITE ) { */ private Hop processMultipleReturnBuiltinFunctionExpression(BuiltinFunctionExpression source, ArrayList targetList, HashMap hops) { - + // Construct Hops for all inputs ArrayList inputs = new ArrayList<>(); inputs.add( processExpression(source.getFirstExpr(), null, hops) ); @@ -2235,14 +2223,14 @@ private Hop processMultipleReturnBuiltinFunctionExpression(BuiltinFunctionExpres inputs.add( processExpression(expr[i], null, hops) ); } } - + FunctionType ftype = FunctionType.MULTIRETURN_BUILTIN; String nameSpace = DMLProgram.INTERNAL_NAMESPACE; - + // Create an array list to hold the outputs of this lop. // Exact list of outputs are added based on opcode. ArrayList outputs = new ArrayList<>(); - + // Construct Hop for current builtin function expression based on its type Hop currBuiltinOp = null; switch (source.getOpCode()) { @@ -2255,7 +2243,7 @@ private Hop processMultipleReturnBuiltinFunctionExpression(BuiltinFunctionExpres case BATCH_NORM2D_BACKWARD: case REMOVE: case SVD: - + // Number of outputs = size of targetList = #of identifiers in source.getOutputs String[] outputNames = new String[targetList.size()]; for ( int i=0; i < targetList.size(); i++ ) { @@ -2263,11 +2251,10 @@ private Hop processMultipleReturnBuiltinFunctionExpression(BuiltinFunctionExpres Hop output = new DataOp(outputNames[i], DataType.MATRIX, ValueType.FP64, inputs.get(0), OpOpData.FUNCTIONOUTPUT, inputs.get(0).getFilename()); outputs.add(output); } - + // Create the hop for current function call currBuiltinOp = new FunctionOp(ftype, nameSpace, source.getOpCode().toString(), null, inputs, outputNames, outputs); break; - case COMPRESS: // Number of outputs = size of targetList = #of identifiers in source.getOutputs String[] outputNamesCompress = 
new String[targetList.size()]; @@ -2279,7 +2266,6 @@ private Hop processMultipleReturnBuiltinFunctionExpression(BuiltinFunctionExpres // Create the hop for current function call currBuiltinOp = new FunctionOp(ftype, nameSpace, source.getOpCode().toString(), null, inputs, outputNamesCompress, outputs); break; - default: throw new ParseException("Invaid Opcode in DMLTranslator:processMultipleReturnBuiltinFunctionExpression(): " + source.getOpCode()); } @@ -2293,7 +2279,7 @@ private Hop processMultipleReturnBuiltinFunctionExpression(BuiltinFunctionExpres return currBuiltinOp; } - + /** * Construct Hops from parse tree : Process BuiltinFunction Expression in an * assignment statement @@ -2317,10 +2303,10 @@ private Hop processBuiltinFunctionExpression(BuiltinFunctionExpression source, D if (source.getThirdExpr() != null) { expr3 = processExpression(source.getThirdExpr(), null, hops); } - + Hop currBuiltinOp = null; target = (target == null) ? createTarget(source) : target; - + // Construct the hop based on the type of Builtin function switch (source.getOpCode()) { @@ -2362,12 +2348,12 @@ private Hop processBuiltinFunctionExpression(BuiltinFunctionExpression source, D currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), AggOp.MAXINDEX, Direction.Row, expr); break; - + case ROWINDEXMIN: currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), AggOp.MININDEX, Direction.Row, expr); break; - + case ROWSD: // rowStdDevs = sqrt(rowVariances) currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, @@ -2401,17 +2387,17 @@ private Hop processBuiltinFunctionExpression(BuiltinFunctionExpression source, D target.getValueType(), OpOp1.LINEAGE, expr); DMLScript.LINEAGE = true; break; - + case LIST: currBuiltinOp = new NaryOp(target.getName(), DataType.LIST, ValueType.UNKNOWN, OpOpN.LIST, processAllExpressions(source.getAllExpr(), hops)); break; - + case EXISTS: currBuiltinOp = new UnaryOp(target.getName(), 
DataType.SCALAR, target.getValueType(), OpOp1.EXISTS, expr); break; - + case SUM: case PROD: case VAR: @@ -2455,7 +2441,7 @@ private Hop processBuiltinFunctionExpression(BuiltinFunctionExpression source, D new NaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOpN.valueOf(source.getOpCode().name()), processAllExpressions(source.getAllExpr(), hops)); break; - + case PPRED: String sop = ((StringIdentifier)source.getThirdExpr()).getValue(); sop = sop.replace("\"", ""); @@ -2477,7 +2463,7 @@ else if ( sop.equalsIgnoreCase("!=") ) } currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), operation, expr, expr2); break; - + case TRACE: currBuiltinOp = new AggUnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), AggOp.TRACE, Direction.RowCol, expr); @@ -2489,7 +2475,7 @@ else if ( sop.equalsIgnoreCase("!=") ) currBuiltinOp = new ReorgOp(target.getName(), DataType.MATRIX, target.getValueType(), ReOrgOp.valueOf(source.getOpCode().name()), expr); break; - + case CBIND: case RBIND: OpOp2 appendOp2 = (source.getOpCode()==Builtins.CBIND) ? OpOp2.CBIND : OpOp2.RBIND; @@ -2499,13 +2485,13 @@ else if ( sop.equalsIgnoreCase("!=") ) new NaryOp(target.getName(), target.getDataType(), target.getValueType(), appendOpN, processAllExpressions(source.getAllExpr(), hops)); break; - + case TABLE: - + // Always a TertiaryOp is created for table(). // - create a hop for weights, if not provided in the function call. 
int numTableArgs = source._args.length; - + switch(numTableArgs) { case 2: case 4: @@ -2517,7 +2503,7 @@ else if ( sop.equalsIgnoreCase("!=") ) weightHop.setDim2(0); weightHop.setNnz(-1); weightHop.setBlocksize(0); - + if ( numTableArgs == 2 ) currBuiltinOp = new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.CTABLE, expr, expr2, weightHop); else { @@ -2527,7 +2513,7 @@ else if ( sop.equalsIgnoreCase("!=") ) OpOp3.CTABLE, expr, expr2, weightHop, outDim1, outDim2, new LiteralOp(true)); } break; - + case 3: case 5: case 6: @@ -2543,7 +2529,7 @@ else if ( sop.equalsIgnoreCase("!=") ) OpOp3.CTABLE, expr, expr2, expr3, outDim1, outDim2, outputEmptyBlocks); } break; - + default: throw new ParseException("Invalid number of arguments "+ numTableArgs + " to table() function."); } @@ -2557,7 +2543,10 @@ else if ( sop.equalsIgnoreCase("!=") ) currBuiltinOp = new UnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), OpOp1.CAST_AS_MATRIX, expr); break; case CAST_AS_FRAME: - currBuiltinOp = new UnaryOp(target.getName(), DataType.FRAME, target.getValueType(), OpOp1.CAST_AS_FRAME, expr); + if(expr2 != null) + currBuiltinOp = new BinaryOp(target.getName(), DataType.FRAME, target.getValueType(), OpOp2.CAST_AS_FRAME, expr, expr2); + else + currBuiltinOp = new UnaryOp(target.getName(), DataType.FRAME, target.getValueType(), OpOp1.CAST_AS_FRAME, expr); break; case CAST_AS_LIST: currBuiltinOp = new UnaryOp(target.getName(), DataType.LIST, target.getValueType(), OpOp1.CAST_AS_LIST, expr); @@ -2633,7 +2622,7 @@ else if ( sop.equalsIgnoreCase("!=") ) target.getValueType(), OpOp3.valueOf(source.getOpCode().name()), expr, expr2, (expr3==null) ? 
new LiteralOp(0L) : expr3); break; - + case LOG: if (expr2 == null) { OpOp1 mathOp2; @@ -2663,7 +2652,7 @@ else if ( sop.equalsIgnoreCase("!=") ) expr, expr2); } break; - + case MOMENT: case COV: case QUANTILE: @@ -2672,19 +2661,19 @@ else if ( sop.equalsIgnoreCase("!=") ) OpOp2.valueOf(source.getOpCode().name()), expr, expr2) : new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.valueOf(source.getOpCode().name()), expr, expr2,expr3); break; - + case IQM: case MEDIAN: currBuiltinOp = (expr2 == null) ? new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp1.valueOf(source.getOpCode().name()), expr) : new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.valueOf(source.getOpCode().name()), expr, expr2); break; - + case IFELSE: currBuiltinOp=new TernaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp3.IFELSE, expr, expr2, expr3); break; - + case SEQ: HashMap randParams = new HashMap<>(); randParams.put(Statement.SEQ_FROM, expr); @@ -2697,18 +2686,18 @@ else if ( sop.equalsIgnoreCase("!=") ) case TIME: currBuiltinOp = new DataGenOp(OpOpDG.TIME, target); break; - + case SAMPLE: { Expression[] in = source.getAllExpr(); - + // arguments: range/size/replace/seed; defaults: replace=FALSE - + HashMap tmpparams = new HashMap<>(); tmpparams.put(DataExpression.RAND_MAX, expr); //range tmpparams.put(DataExpression.RAND_ROWS, expr2); tmpparams.put(DataExpression.RAND_COLS, new LiteralOp(1)); - + if ( in.length == 4 ) { tmpparams.put(DataExpression.RAND_PDF, expr3); @@ -2730,22 +2719,22 @@ else if ( expr3.getValueType() == ValueType.INT64 ) } else throw new HopsException("Invalid input type " + expr3.getValueType() + " in sample()."); - + } else if ( in.length == 2 ) { tmpparams.put(DataExpression.RAND_PDF, new LiteralOp(false)); tmpparams.put(DataExpression.RAND_SEED, new LiteralOp(DataGenOp.UNSPECIFIED_SEED) ); } - + currBuiltinOp = new DataGenOp(OpOpDG.SAMPLE, target, 
tmpparams); break; } - + case SOLVE: currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.SOLVE, expr, expr2); break; - + case INVERSE: case CHOLESKY: case TYPEOF: @@ -2754,19 +2743,19 @@ else if ( in.length == 2 ) currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp1.valueOf(source.getOpCode().name()), expr); break; - + case OUTER: if( !(expr3 instanceof LiteralOp) ) throw new HopsException("Operator for outer builtin function must be a constant: "+expr3); OpOp2 op = OpOp2.valueOfByOpcode(((LiteralOp)expr3).getStringValue()); if( op == null ) throw new HopsException("Unsupported outer vector binary operation: "+((LiteralOp)expr3).getStringValue()); - + currBuiltinOp = new BinaryOp(target.getName(), DataType.MATRIX, target.getValueType(), op, expr, expr2); ((BinaryOp)currBuiltinOp).setOuterVectorOperation(true); //flag op as specific outer vector operation currBuiltinOp.refreshSizeInformation(); //force size reevaluation according to 'outer' flag otherwise danger of incorrect dims break; - + case BIASADD: case BIASMULT: { ArrayList inHops1 = new ArrayList<>(); @@ -2813,7 +2802,7 @@ else if ( in.length == 2 ) default: throw new ParseException("Unsupported builtin function type: "+source.getOpCode()); } - + boolean isConvolution = source.getOpCode() == Builtins.CONV2D || source.getOpCode() == Builtins.CONV2D_BACKWARD_DATA || source.getOpCode() == Builtins.CONV2D_BACKWARD_FILTER || source.getOpCode() == Builtins.MAX_POOL || source.getOpCode() == Builtins.MAX_POOL_BACKWARD || @@ -2825,14 +2814,14 @@ else if ( in.length == 2 ) currBuiltinOp.setParseInfo(source); return currBuiltinOp; } - + private Hop[] processAllExpressions(Expression[] expr, HashMap hops) { Hop[] ret = new Hop[expr.length]; for(int i=0; i getALHopsForConvOp(Hop first, BuiltinFunctionExpression s } return ret; } - + public void setIdentifierParams(Hop h, Identifier id) { if( id.getDim1()>= 0 ) 
h.setDim1(id.getDim1()); @@ -2929,24 +2918,24 @@ public void setIdentifierParams(Hop h, Identifier id) { private boolean prepareReadAfterWrite( DMLProgram prog, HashMap pWrites ) { boolean ret = false; - + //process functions /*MB: for the moment we only support read-after-write in the main program for( FunctionStatementBlock fsb : prog.getFunctionStatementBlocks() ) ret |= prepareReadAfterWrite(fsb, pWrites); */ - + //process main program for( StatementBlock sb : prog.getStatementBlocks() ) ret |= prepareReadAfterWrite(sb, pWrites); - + return ret; } - + private boolean prepareReadAfterWrite( StatementBlock sb, HashMap pWrites ) { boolean ret = false; - + if(sb instanceof FunctionStatementBlock) { FunctionStatementBlock fsb = (FunctionStatementBlock) sb; @@ -2960,7 +2949,7 @@ else if(sb instanceof WhileStatementBlock) WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); for (StatementBlock csb : wstmt.getBody()) ret |= prepareReadAfterWrite(csb, pWrites); - } + } else if(sb instanceof IfStatementBlock) { IfStatementBlock isb = (IfStatementBlock) sb; @@ -3018,7 +3007,7 @@ else if( s instanceof AssignmentStatement } } } - + return ret; } } diff --git a/src/main/java/org/apache/sysds/parser/ExpressionList.java b/src/main/java/org/apache/sysds/parser/ExpressionList.java index 90d50c6a96b..55788fa2c35 100644 --- a/src/main/java/org/apache/sysds/parser/ExpressionList.java +++ b/src/main/java/org/apache/sysds/parser/ExpressionList.java @@ -48,6 +48,10 @@ public void setValue(ArrayList _value) { this._value = _value; } + public Identifier getOutput() { + return new ListIdentifier(); + } + @Override public void validateExpression(HashMap ids, HashMap currConstVars, boolean conditional) { diff --git a/src/main/java/org/apache/sysds/parser/ListIdentifier.java b/src/main/java/org/apache/sysds/parser/ListIdentifier.java new file mode 100644 index 00000000000..6b314c0c3e9 --- /dev/null +++ b/src/main/java/org/apache/sysds/parser/ListIdentifier.java @@ -0,0 +1,52 @@ +/* + 
* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sysds.parser; + +import org.apache.sysds.common.Types.DataType; +import org.apache.sysds.common.Types.ValueType; + +public class ListIdentifier extends Identifier { + + public ListIdentifier(){ + _dim1 = -1; + _dim2 = -1; + _dataType = DataType.LIST; + _valueType = ValueType.UNKNOWN; + _blocksize = -1; + _nnz = -1; + setOutput(this); + _format = null; + } + + @Override + public Expression rewriteExpression(String prefix) { + throw new UnsupportedOperationException("Unimplemented method 'rewriteExpression'"); + } + + @Override + public VariableSet variablesRead() { + throw new UnsupportedOperationException("Unimplemented method 'variablesRead'"); + } + + @Override + public VariableSet variablesUpdated() { + throw new UnsupportedOperationException("Unimplemented method 'variablesUpdated'"); + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java index 737922397e5..e1dbb538630 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java @@ -737,15 +737,16 @@ 
public ValueType getColumnType(int c) { } public Array getColumn(int c) { - return _coldata[c]; + return _coldata != null ? _coldata[c] : null; } public void setColumn(int c, Array column) { if(_coldata == null) { _coldata = new Array[getNumColumns()]; - _nRow = column.size(); + if(column != null) + _nRow = column.size(); } - if(column.size() != _nRow) + else if(column != null && column.size() != _nRow) throw new DMLRuntimeException("Invalid number of rows in set column"); _coldata[c] = column; _msize = -1; diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java index 7f6698ef18c..68ef739e65c 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java @@ -42,8 +42,6 @@ */ public abstract class Array implements Writable { protected static final Log LOG = LogFactory.getLog(Array.class.getName()); - /** internal configuration */ - private static final boolean REUSE_RECODE_MAPS = true; /** A soft reference to a memorization of this arrays mapping, used in transformEncode */ protected SoftReference> _rcdMapCache = null; @@ -79,29 +77,34 @@ public final void setCache(SoftReference> m) { _rcdMapCache = m; } - public Map getRecodeMap() { + /** + * Get a recode map that maps each unique value in the array, to a long ID. Null values are ignored, and not included + * in the mapping. The resulting recode map in stored in a soft reference to speed up repeated calls to the same column. + * + * @return A recode map + */ + public synchronized final Map getRecodeMap() { // probe cache for existing map - if(REUSE_RECODE_MAPS) { - SoftReference> tmp = getCache(); - Map map = (tmp != null) ? tmp.get() : null; - if(map != null) - return map; - } + Map map; + SoftReference> tmp = getCache(); + map = (tmp != null) ? 
tmp.get() : null; + if(map != null) + return map; // construct recode map - Map map = createRecodeMap(); + map = createRecodeMap(); // put created map into cache - if(REUSE_RECODE_MAPS) - setCache(new SoftReference<>(map)); + setCache(new SoftReference<>(map)); return map; } /** - * Recreate the recode map from what is already there. + * Recreate the recode map from what is inside array. This is an internal method for arrays, and the result is cached + * in the main class of the arrays. * - * @return + * @return The recode map */ protected Map createRecodeMap() { Map map = new HashMap<>(); @@ -123,13 +126,13 @@ protected Map createRecodeMap() { * @return a dictionary containing all unique values. */ protected Map getDictionary() { - Map dict = new HashMap<>(); - int id = 0; + final Map dict = new HashMap<>(); + Integer id = 0; for(int i = 0; i < size(); i++) { - T val = get(i); - Integer v = dict.putIfAbsent(val, id); + final T val = get(i); + final Integer v = dict.get(val); if(v == null) - id++; + dict.put(val, id++); } return dict; @@ -147,8 +150,8 @@ public final int size() { /** * Get the value at a given index. * - * This method returns objects that have a high overhead in allocation. Therefore it is not as efficient as using - * the vectorized operations specified in the object. + * This method returns objects that have a high overhead in allocation. Therefore it is not as efficient as using the + * vectorized operations specified in the object. 
* * @param index The index to query * @return The value returned as an object @@ -230,7 +233,10 @@ public double getAsNaNDouble(int i) { * @param ru row upper (inclusive) * @param value value array to take values from (same type) */ - public abstract void set(int rl, int ru, Array value); + public void set(int rl, int ru, Array value){ + for(int i = rl; i <= ru; i++) + set(i, value.get(i)); + } /** * Set range to given arrays value with an offset into other array @@ -240,7 +246,10 @@ public double getAsNaNDouble(int i) { * @param value value array to take values from * @param rlSrc the offset into the value array to take values from */ - public abstract void set(int rl, int ru, Array value, int rlSrc); + public void set(int rl, int ru, Array value, int rlSrc){ + for(int i = rl, off = rlSrc; i <= ru; i++, off++) + set(i, value.get(off)); + } /** * Set non default values from the value array given @@ -627,12 +636,10 @@ public ArrayIterator getIterator() { @Override @SuppressWarnings("unchecked") public boolean equals(Object other) { - try { - return other instanceof Array && this.equals((Array) other); - } - catch(ClassCastException e) { - return false; - } + return other instanceof Array && // + ((Array) other).getValueType() == this.getValueType() && // + this.equals((Array) other); + } public abstract boolean equals(Array other); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java index b0de2460d9f..12ca401c6b9 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java @@ -161,10 +161,6 @@ public static Array allocateOptional(ValueType v, int nRow) { } } - public static DDCArray allocateDDC(DDCArray start, int nRow) { - return start.allocateLarger(nRow); - } - public static ABooleanArray allocateBoolean(int nRow) { if(nRow > 
bitSetSwitchPoint) return new BitSetArray(nRow); @@ -239,7 +235,7 @@ public static Array read(DataInput in, int nRow) throws IOException { * * @param The type to return, java automatically make this Object, and this is fine. * @param a The first array to append to (potentially modifying this a if applicable) - * @param b THe array to append to a, (not getting modified). + * @param b The array to append to a, (not getting modified). * @return A array containing the concatenation of the two. */ @SuppressWarnings("unchecked") @@ -260,53 +256,57 @@ public static Array append(Array a, Array b) { * Set the target array in the range of rl to ru with the src array. The type returned is the common or highest * common type of array. * - * @param THe highest common type to return. - * @param target The target to pout the values into + * @param The highest common type to return. + * @param target The target to put the values into * @param src The source to take the values from * @param rl The index to start on - * @param ru The index to end on + * @param ru The index to end on (inclusive) * @param rlen The length of the target (a parameter in case target is null) * @return A new or modified array. 
*/ @SuppressWarnings("unchecked") public static Array set(Array target, Array src, int rl, int ru, int rlen) { - try { - - if(target == null) { - - if(src.getFrameArrayType() == FrameArrayType.OPTIONAL) + + if(rlen <= ru) + throw new DMLRuntimeException("Invalid range ru: " + ru + " should be less than rlen: " + rlen); + else if(rl < 0) + throw new DMLRuntimeException("Invalid rl is less than zero"); + else if(src == null) + throw new NullPointerException("Invalid src, cannot be null"); + else if(ru - rl > src.size()) + throw new DMLRuntimeException("Invalid range length to big: " + src.size() + " vs range: " + (ru - rl)); + else if(target != null && target.size() < rlen) + throw new DMLRuntimeException("Invalid allocated target is not large enough"); + + if(target == null) { // if target is not specified. allocate one. + if(src.getFrameArrayType() == FrameArrayType.OPTIONAL) + target = allocateOptional(src.getValueType(), rlen); + else if(src.getFrameArrayType() == FrameArrayType.DDC) { + Array ddcDict = ((DDCArray) src).getDict(); + if(ddcDict.getFrameArrayType() == FrameArrayType.OPTIONAL) { target = allocateOptional(src.getValueType(), rlen); - else if(src.getFrameArrayType() == FrameArrayType.DDC) - target = allocateDDC((DDCArray) src, rlen); - else + } + else { target = allocate(src.getValueType(), rlen); - - if(rlen == ru) - throw new DMLRuntimeException("Invalid length to set"); - } - else if(target.getFrameArrayType() != FrameArrayType.OPTIONAL // - && src.getFrameArrayType() == FrameArrayType.OPTIONAL) { - target = new OptionalArray<>(target, false); + } } - - if(target.size() < rlen) { - throw new DMLRuntimeException("Invalid allocated target is not large enough"); - } - - final ValueType ta = target.getValueType(); - final ValueType tb = src.getValueType(); - final ValueType tc = ValueType.getHighestCommonType(ta, tb); - - Array targetC = (Array) (ta != tc ? target.changeType(tc) : target); - Array srcC = (Array) (tb != tc ? 
src.changeType(tc) : src); - targetC.set(rl, ru, srcC); - return targetC; + else + target = allocate(src.getValueType(), rlen); } - catch(Exception e) { - throw new DMLRuntimeException( - "Failed to set subpart with: \n\n" + target + "\n\n" + src + " \n\n " + rl + " " + ru + " " + rlen, e); + else if(target.getFrameArrayType() != FrameArrayType.OPTIONAL // + && src.getFrameArrayType() == FrameArrayType.OPTIONAL) { + target = new OptionalArray<>(target, false); } + final ValueType ta = target.getValueType(); + final ValueType tb = src.getValueType(); + final ValueType tc = ValueType.getHighestCommonType(ta, tb); + + Array targetC = (Array) (ta != tc ? target.changeType(tc) : target); + Array srcC = (Array) (tb != tc ? src.changeType(tc) : src); + targetC.set(rl, ru, srcC); + return targetC; + } public static Object parseString(String s, ValueType v) { diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java index 3adf9d2fa11..dbd5d7328c3 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java @@ -42,6 +42,8 @@ public class BitSetArray extends ABooleanArray { /** Vectorized "words" containing all the bits set */ protected long[] _data; + private volatile int allTrue = -1; + protected BitSetArray(int size) { this(new long[longSize(size)], size); } @@ -104,7 +106,7 @@ public synchronized void set(int index, boolean value) { @Override public void set(int index, double value) { - set(index, value == 1.0); + set(index, Math.round(value) == 1.0); } @Override @@ -135,11 +137,19 @@ private static long[] toLongArrayPadded(BitSet data, int minLength) { @Override public void set(int rl, int ru, Array value, int rlSrc) { - if(useVectorizedKernel && value instanceof BitSetArray && (ru - rl >= 64)) - setVectorized(rl, ru, (BitSetArray) value, rlSrc); + 
if(useVectorizedKernel && value instanceof BitSetArray && (ru - rl >= 64)){ + try { + // try system array copy. + // but if it does not work, default to get. + setVectorized(rl, ru, (BitSetArray) value, rlSrc); + return; + } + catch(Exception e) { + // do nothing + } + } else // default - for(int i = rl, off = rlSrc; i <= ru; i++, off++) - set(i, value.get(off)); + super.set(rl,ru,value, rlSrc); } private void setVectorized(int rl, int ru, BitSetArray value, int rlSrc) { @@ -502,9 +512,15 @@ public boolean isEmpty() { @Override public boolean isAllTrue() { + if(allTrue != -1) + return allTrue ==1; + for(int i = 0; i < _data.length; i++) - if(_data[i] != -1L) + if(_data[i] != -1L){ + allTrue = 0; return false; + } + allTrue = 1; return true; } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java index f34d95233fc..da874555d33 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java @@ -57,7 +57,7 @@ public void set(int index, Boolean value) { @Override public void set(int index, double value) { - _data[index] = value == 1.0; + _data[index] = (Math.round(value) == 1.0); } @Override @@ -79,11 +79,19 @@ public void setFromOtherType(int rl, int ru, Array value) { @Override public void set(int rl, int ru, Array value, int rlSrc) { - if(value instanceof BooleanArray) - System.arraycopy(value.get(), rlSrc, _data, rl, ru - rl + 1); - else - for(int i = rl, off = rlSrc; i <= ru; i++, off++) - _data[i] = value.get(off); + if(value instanceof BooleanArray){ + try { + // try system array copy. + // but if it does not work, default to get. 
+ System.arraycopy(value.get(), rlSrc, _data, rl, ru - rl + 1); + return; + } + catch(Exception e) { + // go default + } + } + + super.set(rl, ru, value, rlSrc); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java index 47ec83c7884..9862974ad77 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java @@ -98,7 +98,14 @@ public void set(int rl, int ru, Array value) { @Override public void set(int rl, int ru, Array value, int rlSrc) { - System.arraycopy(((CharArray) value)._data, rlSrc, _data, rl, ru - rl + 1); + try { + // try system array copy. + // but if it does not work, default to get. + System.arraycopy(value.get(), rlSrc, _data, rl, ru - rl + 1); + } + catch(Exception e) { + super.set(rl, ru, value, rlSrc); + } } @Override diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java index 9b49839721a..4ddc3e4367c 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java @@ -55,6 +55,10 @@ public DDCArray(Array dict, AMapToData map) { } } + protected Array getDict(){ + return dict; + } + /** * Try to compress array into DDC format. * @@ -64,10 +68,11 @@ public DDCArray(Array dict, AMapToData map) { */ @SuppressWarnings("unchecked") public static Array compressToDDC(Array arr) { + final int s = arr.size(); // Early aborts // if the size is small do not consider // or if the instance if RaggedArray where all values typically are unique. 
- if(arr.size() <= 10 || arr instanceof RaggedArray) + if(s <= 10 || arr instanceof RaggedArray) return arr; // Two pass algorithm @@ -75,7 +80,7 @@ public static Array compressToDDC(Array arr) { Map rcd = arr.getDictionary(); // Abort if there are to many unique values. - if(rcd.size() > arr.size() / 2) + if(rcd.size() > s / 2) return arr; // Allocate the correct dictionary output @@ -90,8 +95,8 @@ public static Array compressToDDC(Array arr) { ar.set(e.getValue(), e.getKey()); // 2. full iteration: Make map - AMapToData m = MapToFactory.create(arr.size(), rcd.size()); - for(int i = 0; i < arr.size(); i++) + final AMapToData m = MapToFactory.create(s, rcd.size()); + for(int i = 0; i < s; i++) m.set(i, rcd.get(arr.get(i))); return new DDCArray<>(ar, m); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java index 7abee26fdc7..754748a28b3 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java @@ -35,6 +35,8 @@ import org.apache.sysds.runtime.util.UtilFunctions; import org.apache.sysds.utils.MemoryEstimates; +import ch.randelshofer.fastdoubleparser.JavaDoubleParser; + public class DoubleArray extends Array { private double[] _data; @@ -81,7 +83,14 @@ public void setFromOtherType(int rl, int ru, Array value) { @Override public void set(int rl, int ru, Array value, int rlSrc) { - System.arraycopy(((DoubleArray) value)._data, rlSrc, _data, rl, ru - rl + 1); + try { + // try system array copy. + // but if it does not work, default to get. 
+ System.arraycopy(value.get(), rlSrc, _data, rl, ru - rl + 1); + } + catch(Exception e) { + super.set(rl, ru, value, rlSrc); + } } @Override @@ -186,7 +195,8 @@ public Pair analyzeValueType() { case FP32: switch(c) { case FP64: - state = c; break; + state = c; + break; default: } break; @@ -194,7 +204,8 @@ public Pair analyzeValueType() { switch(c) { case FP64: case FP32: - state = c; break; + state = c; + break; default: } break; @@ -203,7 +214,8 @@ public Pair analyzeValueType() { case FP64: case FP32: case INT64: - state = c; break; + state = c; + break; default: } break; @@ -214,7 +226,8 @@ public Pair analyzeValueType() { case FP32: case INT64: case INT32: - state = c; break; + state = c; + break; default: } break; @@ -332,10 +345,20 @@ public double getAsDouble(int i) { } public static double parseDouble(String value) { - if(value == null || value.isEmpty()) - return 0.0; - else - return Double.parseDouble(value); + try { + if(value == null || value.isEmpty()) + return 0.0; + return JavaDoubleParser.parseDouble(value); + } + catch(NumberFormatException e) { + final int len = value.length(); + // check for common extra cases. 
+ if(len == 3 && value.compareToIgnoreCase("Inf") == 0) + return Double.POSITIVE_INFINITY; + else if(len == 4 && value.compareToIgnoreCase("-Inf") == 0) + return Double.NEGATIVE_INFINITY; + throw new DMLRuntimeException(e); + } } @Override @@ -388,7 +411,7 @@ public boolean equals(Array other) { } @Override - public boolean possiblyContainsNaN(){ + public boolean possiblyContainsNaN() { return true; } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java index 6eb7885a4d5..51d29b167db 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java @@ -80,7 +80,14 @@ public void setFromOtherType(int rl, int ru, Array value) { @Override public void set(int rl, int ru, Array value, int rlSrc) { - System.arraycopy(((FloatArray) value)._data, rlSrc, _data, rl, ru - rl + 1); + try { + // try system array copy. + // but if it does not work, default to get. + System.arraycopy(value.get(), rlSrc, _data, rl, ru - rl + 1); + } + catch(Exception e) { + super.set(rl, ru, value, rlSrc); + } } @Override @@ -284,10 +291,20 @@ public double getAsDouble(int i) { } public static float parseFloat(String value) { - if(value == null || value.isEmpty()) - return 0.0f; - else + try { + if(value == null || value.isEmpty()) + return 0.0f; return Float.parseFloat(value); + } + catch(NumberFormatException e) { + final int len = value.length(); + // check for common extra cases. 
+ if(len == 3 && value.compareToIgnoreCase("Inf") == 0) + return Float.POSITIVE_INFINITY; + else if(len == 4 && value.compareToIgnoreCase("-Inf") == 0) + return Float.NEGATIVE_INFINITY; + throw new DMLRuntimeException(e); + } } @Override @@ -340,7 +357,7 @@ public boolean equals(Array other) { } @Override - public boolean possiblyContainsNaN(){ + public boolean possiblyContainsNaN() { return true; } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java index 6ebd3d9a844..df60803ddad 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java @@ -81,7 +81,14 @@ public void setFromOtherType(int rl, int ru, Array value) { @Override public void set(int rl, int ru, Array value, int rlSrc) { - System.arraycopy(((IntegerArray) value)._data, rlSrc, _data, rl, ru - rl + 1); + try { + // try system array copy. + // but if it does not work, default to get. + System.arraycopy(value.get(), rlSrc, _data, rl, ru - rl + 1); + } + catch(Exception e) { + super.set(rl, ru, value, rlSrc); + } } @Override diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java index b46d86da1e6..c1e0fe06c9b 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java @@ -80,7 +80,14 @@ public void setFromOtherType(int rl, int ru, Array value) { @Override public void set(int rl, int ru, Array value, int rlSrc) { - System.arraycopy(value.get(), rlSrc, _data, rl, ru - rl + 1); + try { + // try system array copy. + // but if it does not work, default to get. 
+ System.arraycopy(value.get(), rlSrc, _data, rl, ru - rl + 1); + } + catch(Exception e) { + super.set(rl, ru, value, rlSrc); + } } @Override diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java index 450a9efa456..a63026b1484 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java @@ -25,6 +25,7 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; import org.apache.sysds.runtime.matrix.data.Pair; @@ -75,14 +76,11 @@ public void write(DataOutput out) throws IOException { } @Override - @SuppressWarnings("unchecked") public void readFields(DataInput in) throws IOException { - _size = in.readInt(); - _a = (Array) ArrayFactory.read(in, in.readInt()); + throw new DMLRuntimeException("Should not be called"); } protected static RaggedArray readRagged(DataInput in, int nRow) throws IOException { - int m = in.readInt(); final Array a = ArrayFactory.read(in, in.readInt()); return new RaggedArray<>(a, m); @@ -154,19 +152,7 @@ public void setFromOtherType(int rl, int ru, Array value) { @Override public void set(int rl, int ru, Array value) { - if(rl >= 0 && rl < _a._size && ru < _a._size) - if(value instanceof RaggedArray) - _a.set(rl, ru, ((RaggedArray) value).getInnerArray()); - else if(_a.getClass() == value.getClass()) - _a.set(rl, ru, value); - else - throw new RuntimeException( - "RaggedArray set: value type should be same to RaggedArray type " + _a.getClass()); - else if(rl >= 0 && rl < super.size() && ru < super.size()) { - _a.reset(rl + 1); - _a.set(rl, ru, value); - LOG.warn("Reallocated ragged array"); - } + set(rl, ru, value, 0); } @Override @@ -177,7 +163,7 
@@ public void set(int rl, int ru, Array value, int rlSrc) { else if(_a.getClass() == value.getClass()) _a.set(rl, ru, value, rlSrc); else - throw new RuntimeException( + throw new DMLRuntimeException( "RaggedArray set: value type should be same to RaggedArray type " + _a.getClass()); } @@ -382,7 +368,16 @@ public double hashDouble(int idx) { @Override public boolean equals(Array other) { - throw new NotImplementedException("Unimplemented method 'equals'"); + if(other._size == this._size && // + other.getValueType() == this.getValueType() && // + other instanceof RaggedArray) { + if(other == this){// same pointer + return true; + } + RaggedArray ot = (RaggedArray) other; + return ot._a.equals(this._a); + } + return false; } @Override @@ -396,7 +391,7 @@ public boolean containsNull() { } @Override - public boolean possiblyContainsNaN(){ + public boolean possiblyContainsNaN() { return true; } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java index 9f0a3596440..ef66a046f74 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java @@ -26,6 +26,7 @@ import java.util.BitSet; import java.util.HashMap; import java.util.Map; +import java.util.regex.Pattern; import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.common.Types.ValueType; @@ -37,9 +38,6 @@ import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode; import org.apache.sysds.utils.MemoryEstimates; -import ch.randelshofer.fastdoubleparser.JavaDoubleParser; -import ch.randelshofer.fastdoubleparser.JavaFloatParser; - public class StringArray extends Array { private String[] _data; @@ -79,7 +77,6 @@ public void set(int index, double value) { @Override public void set(int rl, int ru, Array value) { set(rl, ru, value, 0); - materializedSize = -1; } @Override @@ 
-96,8 +93,17 @@ public void setFromOtherType(int rl, int ru, Array value) { @Override public void set(int rl, int ru, Array value, int rlSrc) { - System.arraycopy(((StringArray) value)._data, rlSrc, _data, rl, ru - rl + 1); - materializedSize = -1; + try { + // try system array copy. + // but if it does not work, default to get. + System.arraycopy(value.get(), rlSrc, _data, rl, ru - rl + 1); + } + catch(Exception e) { + super.set(rl, ru, value, rlSrc); + } + finally{ + materializedSize = -1; + } } @Override @@ -249,8 +255,9 @@ public Pair analyzeValueType() { boolean nulls = false; for(int i = 0; i < _size; i++) { final ValueType c = FrameUtil.isType(_data[i], state); - if(c == ValueType.STRING) // early termination + if(c == ValueType.STRING) { return new Pair<>(ValueType.STRING, false); + } else if(c == ValueType.UNKNOWN) nulls = true; else @@ -311,10 +318,10 @@ protected Array changeTypeBoolean() { return ArrayFactory.allocateBoolean(size()); else if(firstNN.toLowerCase().equals("true") || firstNN.toLowerCase().equals("false")) return changeTypeBooleanStandard(); - else if(firstNN.equals("0") || firstNN.equals("1")) + else if(firstNN.equals("0") || firstNN.equals("1") || firstNN.equals("1.0") || firstNN.equals("0.0")) return changeTypeBooleanNumeric(); - else if(firstNN.equals("0.0") || firstNN.equals("1.0")) - return changeTypeBooleanFloat(); + // else if(firstNN.equals("0.0") || firstNN.equals("1.0")) + // return changeTypeBooleanFloat(); else if(firstNN.toLowerCase().equals("t") || firstNN.toLowerCase().equals("f")) return changeTypeBooleanCharacter(); else @@ -392,13 +399,23 @@ protected Array changeTypeBooleanNumericBitSet() { for(int i = 0; i < size(); i++) { final String s = _data[i]; if(s != null) { + if(s.length() > 1) { + final boolean zero = _data[i].equals("0.0"); + final boolean one = _data[i].equals("1.0"); + if(zero | one) + ret.set(i, one); + else + throw new DMLRuntimeException("Unable to change to Boolean from String array, value:" + _data[i]); 
- final boolean zero = _data[i].equals("0"); - final boolean one = _data[i].equals("1"); - if(zero | one) - ret.set(i, one); - else - throw new DMLRuntimeException("Unable to change to Boolean from String array, value:" + _data[i]); + } + else { + final boolean zero = _data[i].charAt(0) == '0'; + final boolean one = _data[i].charAt(0) == '1'; + if(zero | one) + ret.set(i, one); + else + throw new DMLRuntimeException("Unable to change to Boolean from String array, value:" + _data[i]); + } } } return new BitSetArray(ret, size()); @@ -409,105 +426,91 @@ protected Array changeTypeBooleanNumericArray() { for(int i = 0; i < size(); i++) { final String s = _data[i]; if(s != null) { - final boolean zero = _data[i].equals("0"); - final boolean one = _data[i].equals("1"); - if(zero | one) - ret[i] = one; - else - throw new DMLRuntimeException("Unable to change to Boolean from String array, value:" + _data[i]); + if(s.length() > 1) { + final boolean zero = _data[i].equals("0.0"); + final boolean one = _data[i].equals("1.0"); + if(zero | one) + ret[i] = one; + else + throw new DMLRuntimeException("Unable to change to Boolean from String array, value:" + _data[i]); + + } + else { + final boolean zero = _data[i].charAt(0) == '0'; + final boolean one = _data[i].charAt(0) == '1'; + if(zero | one) + ret[i] = one; + else + throw new DMLRuntimeException("Unable to change to Boolean from String array, value:" + _data[i]); + } } } return new BooleanArray(ret); } - protected Array changeTypeBooleanFloat() { - if(size() > ArrayFactory.bitSetSwitchPoint) - return changeTypeBooleanFloatBitSet(); - else - return changeTypeBooleanFloatArray(); + @Override + protected Array changeTypeDouble() { + double[] ret = new double[size()]; + for(int i = 0; i < size(); i++) + ret[i] = DoubleArray.parseDouble(_data[i]); + return new DoubleArray(ret); } - protected Array changeTypeBooleanFloatBitSet() { - BitSet ret = new BitSet(size()); - for(int i = 0; i < size(); i++) { - final String s = _data[i]; - 
if(s != null) { + @Override + protected Array changeTypeFloat() { + float[] ret = new float[size()]; + for(int i = 0; i < size(); i++) + ret[i] = FloatArray.parseFloat(_data[i]); + return new FloatArray(ret); + } - final boolean zero = _data[i].equals("0.0"); - final boolean one = _data[i].equals("1.0"); - if(zero | one) - ret.set(i, one); - else - throw new DMLRuntimeException("Unable to change to Boolean from String array, value:" + _data[i]); - } + @Override + protected Array changeTypeInteger() { + String firstNN = _data[0]; + int i = 1; + while(firstNN == null && i < size()) { + firstNN = _data[i++]; } - return new BitSetArray(ret, size()); + if(firstNN == null) + throw new DMLRuntimeException("Invalid change to int on all null"); + else if(firstNN.contains(".")) + return changeTypeIntegerFloatString(); + else + return changeTypeIntegerNormal(); } - protected Array changeTypeBooleanFloatArray() { - boolean[] ret = new boolean[size()]; + protected Array changeTypeIntegerFloatString() { + int[] ret = new int[size()]; + Pattern p = Pattern.compile("\\."); for(int i = 0; i < size(); i++) { final String s = _data[i]; - if(s != null) { - final boolean zero = _data[i].equals("0.0"); - final boolean one = _data[i].equals("1.0"); - if(zero | one) - ret[i] = one; - else - throw new DMLRuntimeException("Unable to change to Boolean from String array, value:" + _data[i]); - } - - } - return new BooleanArray(ret); - } - - @Override - protected Array changeTypeDouble() { - try { - double[] ret = new double[size()]; - for(int i = 0; i < size(); i++) { - final String s = _data[i]; + try { if(s != null) - ret[i] = JavaDoubleParser.parseDouble(s); + ret[i] = Integer.parseInt(p.split(s, 2)[0]); } - return new DoubleArray(ret); - } - catch(NumberFormatException e) { - throw new DMLRuntimeException("Unable to change to Double from String array", e); - } - } + catch(NumberFormatException e) { + + throw new DMLRuntimeException("Unable to change to Integer from String array", e); - 
@Override - protected Array changeTypeFloat() { - try { - float[] ret = new float[size()]; - for(int i = 0; i < size(); i++) { - final String s = _data[i]; - if(s != null) - ret[i] = JavaFloatParser.parseFloat(s); } - return new FloatArray(ret); - } - catch(NumberFormatException e) { - throw new DMLRuntimeException("Unable to change to Float from String array", e); } + return new IntegerArray(ret); } - @Override - protected Array changeTypeInteger() { - try { - int[] ret = new int[size()]; - for(int i = 0; i < size(); i++) { - final String s = _data[i]; + protected Array changeTypeIntegerNormal() { + int[] ret = new int[size()]; + for(int i = 0; i < size(); i++) { + final String s = _data[i]; + try { if(s != null) ret[i] = Integer.parseInt(s); } - return new IntegerArray(ret); - } - catch(NumberFormatException e) { - throw new DMLRuntimeException("Unable to change to Integer from String array", e); + catch(NumberFormatException e) { + throw new DMLRuntimeException("Unable to change to Integer from String array", e); + } } + return new IntegerArray(ret); } @Override @@ -655,8 +658,6 @@ protected Map createRecodeMap() { String[] tmp = ColumnEncoderRecode.splitRecodeMapEntry(val.toString()); map.put(tmp[0], Long.parseLong(tmp[1])); } - else // once we hit null return. 
- break; } return map; } @@ -682,7 +683,7 @@ public boolean equals(Array other) { } @Override - public boolean possiblyContainsNaN(){ + public boolean possiblyContainsNaN() { return true; } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java b/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java index 90f24b527cb..482e6a129e3 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java @@ -19,12 +19,18 @@ package org.apache.sysds.runtime.frame.data.compress; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.workload.WTreeRoot; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.Array; import org.apache.sysds.runtime.frame.data.columns.DDCArray; +import org.apache.sysds.runtime.util.CommonThreadPool; public class CompressedFrameBlockFactory { @@ -35,11 +41,16 @@ public class CompressedFrameBlockFactory { private final ArrayCompressionStatistics[] stats; private final Array[] compressedColumns; + private final int nSamples; + private CompressedFrameBlockFactory(FrameBlock fb, FrameCompressionSettings cs) { this.in = fb; this.cs = cs; this.stats = new ArrayCompressionStatistics[in.getNumColumns()]; this.compressedColumns = new Array[in.getNumColumns()]; + + this.nSamples = Math.min(in.getNumRows(), (int) Math.ceil(in.getNumRows() * cs.sampleRatio)); + } public static FrameBlock compress(FrameBlock fb) { @@ -62,37 +73,61 @@ public static FrameBlock compress(FrameBlock fb, FrameCompressionSettings cs) { } private FrameBlock compressFrame() { - extractStatistics(); - 
logStatistics(); encodeColumns(); final FrameBlock ret = new FrameBlock(compressedColumns, in.getColumnNames(false)); + logStatistics(); logRet(ret); return ret; } - private void extractStatistics() { - final int nSamples = Math.min(in.getNumRows(), (int) Math.ceil(in.getNumRows() * cs.sampleRatio)); - for(int i = 0; i < stats.length; i++) { - stats[i] = in.getColumn(i).statistics(nSamples); - } + private void encodeColumns() { + if(cs.k > 1) + encodeParallel(); + else + encodeSingleThread(); } - private void encodeColumns() { - for(int i = 0; i < compressedColumns.length; i++) { - if(stats[i] != null) { - // commented out because no other encodings are supported yet - // switch(stats[i].bestType) { - // case DDC: - compressedColumns[i] = DDCArray.compressToDDC(in.getColumn(i)); - // break; - // default: - // compressedColumns[i] = in.getColumn(i); - // break; - // } + private void encodeSingleThread() { + for(int i = 0; i < compressedColumns.length; i++) + compressCol(i); + } + + private void encodeParallel() { + ExecutorService pool = CommonThreadPool.get(cs.k); + try { + List> tasks = new ArrayList<>(); + for(int i = 0; i < compressedColumns.length; i++) { + final int l = i; + tasks.add(pool.submit(() -> compressCol(l))); } - else - compressedColumns[i] = in.getColumn(i); + + for(Future t : tasks) + t.get(); + + } + catch(Exception e) { + throw new RuntimeException(e); + } + finally { + pool.shutdown(); + } + } + + private void compressCol(int i) { + stats[i] = in.getColumn(i).statistics(nSamples); + if(stats[i] != null) { + // commented out because no other encodings are supported yet + // switch(stats[i].bestType) { + // case DDC: + compressedColumns[i] = DDCArray.compressToDDC(in.getColumn(i)); + // break; + // default: + // compressedColumns[i] = in.getColumn(i); + // break; + // } } + else + compressedColumns[i] = in.getColumn(i); } private void logStatistics() { diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java 
b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java index c2d53a650b4..705aeb24c37 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java @@ -27,7 +27,6 @@ import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.Array; import org.apache.sysds.runtime.frame.data.columns.BooleanArray; @@ -60,8 +59,24 @@ public static Array[] add(Array[] ar, Array e) { return ret; } - private static ValueType isBooleanType(final String val, int len) { - if(val.length() <= 16 && booleanPattern.matcher(val).matches()) + private static boolean isBooleanType(final char c) { + switch(c) { + case '0': + case '1': + case 't': + case 'T': + case 'f': + case 'F': + return true; + default: + return false; + } + } + + private static ValueType isBooleanType(final String val, final int len) { + if(len == 1 && isBooleanType(val.charAt(0))) + return ValueType.BOOLEAN; + else if(len <= 16 && isBooleanType(val.charAt(0)) && booleanPattern.matcher(val).matches()) return ValueType.BOOLEAN; return null; } @@ -69,12 +84,21 @@ private static ValueType isBooleanType(final String val, int len) { private static boolean simpleIntMatch(final String val, final int len) { for(int i = 0; i < len; i++) { final char c = val.charAt(i); + if(c == '.' 
&& i < 0) + return restIsZero(val, i + 1, len); if(c < '0' || c > '9') return false; } return true; } + private static boolean restIsZero(final String val, int i, final int len) { + for(; i < len; i++) + if(val.charAt(i) != '0') + return false; + return true; + } + private static ValueType intType(final long value) { if(value >= Integer.MIN_VALUE && value <= Integer.MAX_VALUE) return ValueType.INT32; @@ -99,7 +123,7 @@ else if(integerFloatPattern.matcher(val).matches()) { } public static ValueType isFloatType(final String val, final int len) { - if(len <= 25 && (simpleFloatMatch(val, len) || floatPattern.matcher(val).matches())) { + if(len <= 30 && (simpleFloatMatch(val, len) || floatPattern.matcher(val).matches())) { if(len <= 7 || (len == 8 && val.charAt(0) == '-')) return ValueType.FP32; else if(len >= 13) @@ -113,8 +137,27 @@ else if(same(d, (float) d)) else return ValueType.FP64; } - else if(val.equals("infinity") || val.equals("-infinity") || val.equals("nan")) - return ValueType.FP32; + final char first = val.charAt(0); + // char sec = val.charAt(1); + + if(len >= 3 && (first == 'i' || first == 'I')) { + String val2 = val.toLowerCase(); + if((len == 3 && val2.equals("inf")) || (len == 8 && val2.equals("infinity"))) + return ValueType.FP32; + } + else if(len == 3 & (first == 'n' || first == 'N')) { + final String val2 = val.toLowerCase(); + if(val2.equals("nan")) + return ValueType.FP32; + } + else if(len > 1 && first == '-') { + final char sec = val.charAt(1); + if(sec == 'i' || sec == 'I') { + String val2 = val.toLowerCase(); + if((len == 4 && val2.equals("-inf")) || (len == 9 && val2.equals("-infinity"))) + return ValueType.FP32; + } + } return null; } @@ -126,11 +169,12 @@ private static boolean simpleFloatMatch(final String val, final int len) { final char c = val.charAt(i); if(c >= '0' && c <= '9') continue; - else if(c == '.' || c == ',') + else if(c == '.' 
|| c == ','){ if(encounteredDot == true) return false; else encounteredDot = true; + } else return false; } @@ -165,7 +209,7 @@ public static ValueType isType(String val, ValueType minType) { switch(minType) { case UNKNOWN: case BOOLEAN: - case CHARACTER: + // case CHARACTER: if(isBooleanType(val, len) != null) return ValueType.BOOLEAN; case UINT8: @@ -179,6 +223,7 @@ public static ValueType isType(String val, ValueType minType) { r = isFloatType(val, len); if(r != null) return r; + case CHARACTER: if(len == 1) return ValueType.CHARACTER; case STRING: @@ -194,27 +239,32 @@ public static ValueType isType(String val) { public static ValueType isType(double val) { if(val == 1.0d || val == 0.0d) return ValueType.BOOLEAN; - else if(val < Integer.MAX_VALUE && Util.eq((int) val,val)) - return ValueType.INT32; - else if(val < Long.MAX_VALUE && Util.eq((long) val, val)) + else if((long) (val) == val) { + if((int) val == val) + return ValueType.INT32; + else return ValueType.INT64; + } else if(same(val, (float) val)) return ValueType.FP32; else return ValueType.FP64; + } public static ValueType isType(double val, ValueType min) { switch(min) { case BOOLEAN: return isType(val); - case UINT8: case INT32: - if(val < Integer.MAX_VALUE && Util.eq((int) val,val)) - return ValueType.INT32; + case UINT8: case INT64: - if(val < Long.MAX_VALUE && Util.eq((long) val, val)) - return ValueType.INT64; + if((long) (val) == val) { + if((int) val == val) + return ValueType.INT32; + else + return ValueType.INT64; + } case FP32: if(same(val, (float) val)) return ValueType.FP32; @@ -229,8 +279,7 @@ public static FrameBlock mergeSchema(FrameBlock temp1, FrameBlock temp2) { String[] rowTemp2 = IteratorFactory.getStringRowIterator(temp2).next(); if(rowTemp1.length != rowTemp2.length) - throw new DMLRuntimeException( - "Schema dimension " + "mismatch: " + rowTemp1.length + " vs " + rowTemp2.length); + throw new DMLRuntimeException("Schema dimension " + "mismatch: " + rowTemp1.length + " vs " + 
rowTemp2.length); for(int i = 0; i < rowTemp1.length; i++) { // modify schema1 if necessary (different schema2) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java index a4854695103..230ece072ce 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java @@ -18,11 +18,18 @@ */ package org.apache.sysds.runtime.frame.data.lib; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.CommonThreadPool; public interface MatrixBlockFromFrame { public static final Log LOG = LogFactory.getLog(MatrixBlockFromFrame.class.getName()); @@ -34,44 +41,84 @@ public interface MatrixBlockFromFrame { * double, we do a best effort conversion of non-double types which might result in errors for non-numerical data. 
* * @param frame frame block + * @param k parallelization degree * @return matrix block */ - public static MatrixBlock convertToMatrixBlock(FrameBlock frame) { + public static MatrixBlock convertToMatrixBlock(FrameBlock frame, int k) { final int m = frame.getNumRows(); final int n = frame.getNumColumns(); final MatrixBlock mb = new MatrixBlock(m, n, false); mb.allocateDenseBlock(); + if(k == -1) + k = InfrastructureAnalyzer.getLocalParallelism(); - if(mb.getDenseBlock().isContiguous()) - convertContiguous(frame, mb, m, n); + long nnz = 0; + if(k == 1) + nnz = convert(frame, mb, n, 0, m); else - convertGeneric(frame, mb, m, n); + nnz = convertParallel(frame, mb, m, n, k); + + mb.setNonZeros(nnz); mb.examSparsity(); return mb; } - private static void convertContiguous(final FrameBlock frame, final MatrixBlock mb, final int m, final int n) { + private static long convert(FrameBlock frame, MatrixBlock mb, int n, int rl, int ru) { + if(mb.getDenseBlock().isContiguous()) + return convertContiguous(frame, mb, n, rl, ru); + else + return convertGeneric(frame, mb, n, rl, ru); + } + + private static long convertParallel(FrameBlock frame, MatrixBlock mb, int m, int n, int k){ + ExecutorService pool = CommonThreadPool.get(k); + try{ + List> tasks = new ArrayList<>(); + final int blkz = Math.max(m / k, 1000); + + for( int i = 0; i < m; i+= blkz){ + final int start = i; + final int end = Math.min(i + blkz, m); + tasks.add(pool.submit(() -> convert(frame, mb, n, start, end))); + } + + long nnz = 0; + for( Future t : tasks) + nnz += t.get(); + return nnz; + } + catch(Exception e){ + throw new RuntimeException(e); + + } + finally{ + pool.shutdown(); + } + } + + private static long convertContiguous(final FrameBlock frame, final MatrixBlock mb, final int n, final int rl, + final int ru) { long lnnz = 0; double[] c = mb.getDenseBlockValues(); - for(int bi = 0; bi < m; bi += blocksizeIJ) { + for(int bi = rl; bi < ru; bi += blocksizeIJ) { for(int bj = 0; bj < n; bj += blocksizeIJ) { - 
int bimin = Math.min(bi + blocksizeIJ, m); + int bimin = Math.min(bi + blocksizeIJ, ru); int bjmin = Math.min(bj + blocksizeIJ, n); for(int i = bi, aix = bi * n; i < bimin; i++, aix += n) for(int j = bj; j < bjmin; j++) lnnz += (c[aix + j] = frame.getDoubleNaN(i, j)) != 0 ? 1 : 0; } } - mb.setNonZeros(lnnz); + return lnnz; } - private static void convertGeneric(final FrameBlock frame, final MatrixBlock mb, final int m, final int n) { + private static long convertGeneric(final FrameBlock frame, final MatrixBlock mb, final int n, final int rl, final int ru) { long lnnz = 0; final DenseBlock c = mb.getDenseBlock(); - for(int bi = 0; bi < m; bi += blocksizeIJ) { + for(int bi = rl; bi < ru; bi += blocksizeIJ) { for(int bj = 0; bj < n; bj += blocksizeIJ) { - int bimin = Math.min(bi + blocksizeIJ, m); + int bimin = Math.min(bi + blocksizeIJ, ru); int bjmin = Math.min(bj + blocksizeIJ, n); for(int i = bi; i < bimin; i++) { double[] cvals = c.values(i); @@ -81,6 +128,6 @@ private static void convertGeneric(final FrameBlock frame, final MatrixBlock mb, } } } - mb.setNonZeros(lnnz); + return lnnz; } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java index 07ce7d620bb..3de9fcd65df 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java @@ -254,6 +254,7 @@ public class CPInstructionParser extends InstructionParser { String2CPInstructionType.put( "rmfilevar" , CPType.Variable); String2CPInstructionType.put( OpOp1.CAST_AS_SCALAR.toString(), CPType.Variable); String2CPInstructionType.put( OpOp1.CAST_AS_MATRIX.toString(), CPType.Variable); + String2CPInstructionType.put( "cast_as_frame", CPType.Variable); String2CPInstructionType.put( OpOp1.CAST_AS_FRAME.toString(), CPType.Variable); String2CPInstructionType.put( OpOp1.CAST_AS_LIST.toString(), 
CPType.Variable); String2CPInstructionType.put( OpOp1.CAST_AS_DOUBLE.toString(), CPType.Variable); @@ -482,7 +483,7 @@ public static CPInstruction parseSingleInstruction ( CPType cptype, String str ) case Broadcast: return BroadcastCPInstruction.parseInstruction(str); - + default: throw new DMLRuntimeException("Invalid CP Instruction Type: " + cptype ); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java index 174e4f2d27f..3503b256f77 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java @@ -20,6 +20,8 @@ package org.apache.sysds.runtime.instructions.cp; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.conf.ConfigurationManager; @@ -34,15 +36,17 @@ import org.apache.sysds.runtime.matrix.operators.Operator; import org.apache.sysds.runtime.privacy.propagation.PrivacyPropagator; -public abstract class CPInstruction extends Instruction -{ +public abstract class CPInstruction extends Instruction { + protected static final Log LOG = LogFactory.getLog(CPInstruction.class.getName()); public enum CPType { AggregateUnary, AggregateBinary, AggregateTernary, Unary, Binary, Ternary, Quaternary, BuiltinNary, Ctable, MultiReturnParameterizedBuiltin, ParameterizedBuiltin, MultiReturnBuiltin, Builtin, Reorg, Variable, FCall, Append, Rand, QSort, QPick, Local, MatrixIndexing, MMTSJ, PMMJ, MMChain, Reshape, Partition, Compression, DeCompression, SpoofFused, - StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn, Sql, Prefetch, Broadcast, TrigRemote } + StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn, Sql, Prefetch, Broadcast, TrigRemote, + NoOp, + } protected final CPType _cptype; protected final 
boolean _requiresLabelUpdate; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java index 29f09930f04..cd7077ed9db 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java @@ -69,24 +69,24 @@ import org.apache.sysds.runtime.util.ProgramConverter; import org.apache.sysds.utils.Statistics; +/* + * Supported Operations + * -------------------- + * 1) assignvar x:type y:type + * assign value of y to x (both types should match) + * 2) rmvar x + * remove variable x + * 3) cpvar x y + * copy x to y (same as assignvar followed by rmvar, types are not required) + * 4) rmfilevar x:type b:type + * remove variable x, and if b=true then the file object associated with x (b's type should be boolean) + * 5) assignvarwithfile FN x + * assign x with the first value from the file whose name=FN + * 6) attachfiletovar FP x + * allocate a new file object with name FP, and associate it with variable x + * createvar x FP [dimensions] [formatinfo] + */ public class VariableCPInstruction extends CPInstruction implements LineageTraceable { - /* - * Supported Operations - * -------------------- - * 1) assignvar x:type y:type - * assign value of y to x (both types should match) - * 2) rmvar x - * remove variable x - * 3) cpvar x y - * copy x to y (same as assignvar followed by rmvar, types are not required) - * 4) rmfilevar x:type b:type - * remove variable x, and if b=true then the file object associated with x (b's type should be boolean) - * 5) assignvarwithfile FN x - * assign x with the first value from the file whose name=FN - * 6) attachfiletovar FP x - * allocate a new file object with name FP, and associate it with variable x - * createvar x FP [dimensions] [formatinfo] - */ public enum VariableOperationCode { CreateVariable, @@ -189,30 
+189,24 @@ private VariableCPInstruction(VariableOperationCode op, CPOperand in1, CPOperand } private static VariableOperationCode getVariableOperationCode ( String str ) { - if ( str.equalsIgnoreCase("createvar")) return VariableOperationCode.CreateVariable; - else if ( str.equalsIgnoreCase("assignvar")) return VariableOperationCode.AssignVariable; - else if ( str.equalsIgnoreCase("cpvar")) return VariableOperationCode.CopyVariable; - else if ( str.equalsIgnoreCase("mvvar")) return VariableOperationCode.MoveVariable; - else if ( str.equalsIgnoreCase("rmvar") ) return VariableOperationCode.RemoveVariable; - else if ( str.equalsIgnoreCase("rmfilevar") ) return VariableOperationCode.RemoveVariableAndFile; - else if ( str.equalsIgnoreCase(OpOp1.CAST_AS_SCALAR.toString()) ) return VariableOperationCode.CastAsScalarVariable; else if ( str.equalsIgnoreCase(OpOp1.CAST_AS_MATRIX.toString()) ) return VariableOperationCode.CastAsMatrixVariable; - else if ( str.equalsIgnoreCase(OpOp1.CAST_AS_FRAME.toString()) ) + else if ( str.equalsIgnoreCase(OpOp1.CAST_AS_FRAME.toString()) + || str.equalsIgnoreCase("cast_as_frame")) return VariableOperationCode.CastAsFrameVariable; else if ( str.equalsIgnoreCase(OpOp1.CAST_AS_LIST.toString()) ) return VariableOperationCode.CastAsListVariable; @@ -222,16 +216,12 @@ else if ( str.equalsIgnoreCase(OpOp1.CAST_AS_INT.toString()) ) return VariableOperationCode.CastAsIntegerVariable; else if ( str.equalsIgnoreCase(OpOp1.CAST_AS_BOOLEAN.toString()) ) return VariableOperationCode.CastAsBooleanVariable; - else if ( str.equalsIgnoreCase("write") ) return VariableOperationCode.Write; - else if ( str.equalsIgnoreCase("read") ) return VariableOperationCode.Read; - else if ( str.equalsIgnoreCase("setfilename") ) return VariableOperationCode.SetFileName; - else throw new DMLRuntimeException("Invalid function: " + str); } @@ -343,7 +333,7 @@ public static VariableCPInstruction parseInstruction ( String str ) { String[] parts = 
InstructionUtils.getInstructionPartsWithValueType ( str ); String opcode = parts[0]; VariableOperationCode voc = getVariableOperationCode(opcode); - + if ( voc == VariableOperationCode.CreateVariable ){ if ( parts.length < 5 ) //&& parts.length != 10 ) throw new DMLRuntimeException("Invalid number of operands in createvar instruction: " + str); @@ -361,6 +351,10 @@ else if ( voc == VariableOperationCode.Write ) { if ( parts.length != 6 && parts.length != 7 && parts.length != 9 ) throw new DMLRuntimeException("Invalid number of operands in write instruction: " + str); } + else if(voc == VariableOperationCode.CastAsFrameVariable){ + // LOG.error(parts); + InstructionUtils.checkNumFields(parts, 3, 4, 5); + } else { try{ if( voc != VariableOperationCode.RemoveVariable ) @@ -548,9 +542,16 @@ else if(fmt.equalsIgnoreCase("hdf5")) { throw new DMLRuntimeException("Unexpected value type for second argument in: " + str); break; + case CastAsFrameVariable: + if(parts.length==5){ + in1 = new CPOperand(parts[1]); // input to cast + in2 = new CPOperand(parts[2]); // list of column names + out = new CPOperand(parts[3]); // output + k = Integer.parseInt(parts[4]); + break; + } case CastAsScalarVariable: case CastAsMatrixVariable: - case CastAsFrameVariable: case CastAsListVariable: case CastAsDoubleVariable: case CastAsIntegerVariable: @@ -966,21 +967,39 @@ private void processCastAsFrameVariableInstruction(ExecutionContext ec){ out = new FrameBlock(1, getInput1().getValueType()); out.ensureAllocatedColumns(1); out.set(0, 0, scalarInput.getStringValue()); + setColumnNames(ec, out); ec.setFrameOutput(output.getName(), out); } else if(getInput1().getDataType()==DataType.MATRIX) { //DataType.FRAME MatrixBlock min = ec.getMatrixInput(getInput1().getName()); out = DataConverter.convertToFrameBlock(min, k); ec.releaseMatrixInput(getInput1().getName()); + setColumnNames(ec, out); ec.setFrameOutput(output.getName(), out); } else { //convert list ListObject list = 
(ListObject)ec.getVariable(getInput1().getName()); Data tmp = list.slice(0); + if(getInput2() != null){ + throw new RuntimeException("List does not support as.frame column names arguments"); + } ec.setVariable(output.getName(), tmp); } } + private void setColumnNames(ExecutionContext ec, FrameBlock out){ + if(getInput2() != null){ + ListObject colNames = (ListObject)ec.getVariable(getInput2().getName()); + String[] names = new String[out.getNumColumns()]; + List dat = colNames.getData(); + LOG.error(dat); + for(int i = 0; i < out.getNumColumns();i++){ + names[i] = ((StringObject)dat.get(i)).getStringValue(); + } + out.setColumnNames(names); + } + } + /** * Handler for Read instruction * diff --git a/src/main/java/org/apache/sysds/runtime/util/DataConverter.java b/src/main/java/org/apache/sysds/runtime/util/DataConverter.java index 72428163f49..c575baf0219 100644 --- a/src/main/java/org/apache/sysds/runtime/util/DataConverter.java +++ b/src/main/java/org/apache/sysds/runtime/util/DataConverter.java @@ -602,7 +602,7 @@ public static MatrixBlock convertToMatrixBlock( CTableMap map, int rlen, int cle * @return matrix block */ public static MatrixBlock convertToMatrixBlock(FrameBlock frame){ - return MatrixBlockFromFrame.convertToMatrixBlock(frame); + return MatrixBlockFromFrame.convertToMatrixBlock(frame, 1); } /** diff --git a/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java b/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java index b627c2042b4..746762ce1b3 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java +++ b/src/test/java/org/apache/sysds/test/component/frame/FrameUtilTest.java @@ -133,7 +133,7 @@ public void testIsTypeMinimumString_5() { @Test public void testIsIntLongString() { - assertEquals(ValueType.STRING, FrameUtil.isType("11111111111111111111111111111")); + assertEquals(ValueType.STRING, FrameUtil.isType("111111111111111111111111111111111")); } @Test diff --git 
a/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java b/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java index 24e1895e43a..f0dcbf9c6eb 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java +++ b/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java @@ -852,6 +852,40 @@ public void testSetBChangeType() { assertEquals(c.size(), 4); } + @Test + public void testDDCIn() { + try { + Array a = null; + Array b = new DDCArray(new LongArray(new long[] {1, 2, 3, 4}), // + MapToFactory.create(10, new int[] {0, 0, 0, 0, 1, 1, 1, 2, 2, 3,3}, 4)); + Array c = ArrayFactory.set(a, b, 10, 19, 20); + assertEquals((long) c.get(0), 0L); + assertEquals((long) c.get(10), 1L); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testDDCInOptional() { + try { + Array a = null; + Array b = new DDCArray(new OptionalArray(new Long[] {1L, 2L, 3L, 4L}), // + MapToFactory.create(10, new int[] {0, 0, 0, 0, 1, 1, 1, 2, 2, 3,3}, 4)); + Array c = ArrayFactory.set(a, b, 10, 19, 20); + assertEquals(c.get(0), null); + assertEquals((long) c.get(10), 1L); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + + @Test public void testSetOptionalB() { try { diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java index 91a17d2d2cb..35d4d0e87c9 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java +++ b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java @@ -54,6 +54,7 @@ import org.apache.sysds.runtime.frame.data.columns.OptionalArray; import org.apache.sysds.runtime.frame.data.columns.RaggedArray; import org.apache.sysds.runtime.frame.data.columns.StringArray; +import 
org.apache.sysds.runtime.frame.data.compress.ArrayCompressionStatistics; import org.apache.sysds.runtime.frame.data.lib.FrameLibRemoveEmpty; import org.apache.sysds.runtime.matrix.data.Pair; import org.junit.Test; @@ -101,17 +102,13 @@ public static Collection data() { tests.add(new Object[] {ArrayFactory.create(new String[] {"1", "0", "1"}), FrameArrayType.STRING}); tests.add(new Object[] {ArrayFactory.create(new String[] {"1", "0", "null"}), FrameArrayType.STRING}); tests.add(new Object[] {ArrayFactory.create(new String[] {"0", "0", "null"}), FrameArrayType.STRING}); - tests.add( - new Object[] {ArrayFactory.create(new String[] {"true", "false", "false"}), FrameArrayType.STRING}); - tests.add( - new Object[] {ArrayFactory.create(new String[] {"True", "False", "False"}), FrameArrayType.STRING}); - tests.add( - new Object[] {ArrayFactory.create(new String[] {"False", "False", "False"}), FrameArrayType.STRING}); + tests.add(new Object[] {ArrayFactory.create(new String[] {"true", "false", "false"}), FrameArrayType.STRING}); + tests.add(new Object[] {ArrayFactory.create(new String[] {"True", "False", "False"}), FrameArrayType.STRING}); + tests.add(new Object[] {ArrayFactory.create(new String[] {"False", "False", "False"}), FrameArrayType.STRING}); tests.add(new Object[] {ArrayFactory.create(new String[] {"T", "F", "F"}), FrameArrayType.STRING}); tests.add(new Object[] {ArrayFactory.create(new String[] {"t", "f", "f"}), FrameArrayType.STRING}); tests.add(new Object[] {ArrayFactory.create(new String[] {"f", "t", "t"}), FrameArrayType.STRING}); - tests - .add(new Object[] {ArrayFactory.create(new String[] {"true", "false", "BLAA"}), FrameArrayType.STRING}); + tests.add(new Object[] {ArrayFactory.create(new String[] {"true", "false", "BLAA"}), FrameArrayType.STRING}); tests.add(new Object[] {ArrayFactory.create(new float[] {0.0f, 1.0f, 1.0f, 0.0f}), FrameArrayType.FP32}); tests.add(new Object[] {ArrayFactory.create(new double[] {0.0, 1.0, 1.0, 0.0}), 
FrameArrayType.FP64}); tests.add(new Object[] {ArrayFactory.create(new long[] {0, 1, 1, 0, 0, 1}), FrameArrayType.INT64}); @@ -119,12 +116,9 @@ public static Collection data() { tests.add(new Object[] {ArrayFactory.create(generateRandom01String(100, 324)), FrameArrayType.STRING}); tests.add(new Object[] {ArrayFactory.create(generateRandom01String(80, 22)), FrameArrayType.STRING}); tests.add(new Object[] {ArrayFactory.create(generateRandom01String(32, 221)), FrameArrayType.STRING}); - tests - .add(new Object[] {ArrayFactory.create(generateRandomTrueFalseString(32, 221)), FrameArrayType.STRING}); - tests - .add(new Object[] {ArrayFactory.create(generateRandomTrueFalseString(80, 221)), FrameArrayType.STRING}); - tests.add( - new Object[] {ArrayFactory.create(generateRandomTrueFalseString(150, 221)), FrameArrayType.STRING}); + tests.add(new Object[] {ArrayFactory.create(generateRandomTrueFalseString(32, 221)), FrameArrayType.STRING}); + tests.add(new Object[] {ArrayFactory.create(generateRandomTrueFalseString(80, 221)), FrameArrayType.STRING}); + tests.add(new Object[] {ArrayFactory.create(generateRandomTrueFalseString(150, 221)), FrameArrayType.STRING}); tests.add(new Object[] {ArrayFactory.create(generateRandomTFString(150, 221)), FrameArrayType.STRING}); tests.add(new Object[] {ArrayFactory.create(generateRandomTFString(22, 2)), FrameArrayType.STRING}); tests.add(new Object[] {ArrayFactory.create(generateRandomTFString(142, 4)), FrameArrayType.STRING}); @@ -137,10 +131,8 @@ public static Collection data() { tests.add(new Object[] {ArrayFactory.create(generateRandomNullFloatString(67, 21)), FrameArrayType.STRING}); tests.add(new Object[] {ArrayFactory.create(new String[30]), FrameArrayType.STRING}); // all null tests.add(new Object[] {ArrayFactory.create(new char[] {0, 0, 0, 0, 1, 1, 1}), FrameArrayType.CHARACTER}); - tests.add( - new Object[] {ArrayFactory.create(new char[] {'t', 't', 'f', 'f', 'T'}), FrameArrayType.CHARACTER}); - tests.add( - new Object[] 
{ArrayFactory.create(new char[] {'0', '2', '3', '4', '9'}), FrameArrayType.CHARACTER}); + tests.add(new Object[] {ArrayFactory.create(new char[] {'t', 't', 'f', 'f', 'T'}), FrameArrayType.CHARACTER}); + tests.add(new Object[] {ArrayFactory.create(new char[] {'0', '2', '3', '4', '9'}), FrameArrayType.CHARACTER}); tests.add(new Object[] {ArrayFactory.create(generateRandom01chars(150, 221)), FrameArrayType.CHARACTER}); tests.add(new Object[] {ArrayFactory.create(generateRandom01chars(67, 221)), FrameArrayType.CHARACTER}); tests.add(new Object[] {DDCArray.compressToDDC(ArrayFactory.create(generateRandom01chars(67, 221))), @@ -148,8 +140,7 @@ public static Collection data() { tests.add(new Object[] {DDCArray.compressToDDC(ArrayFactory.create(generateRandom01chars(30, 221))), FrameArrayType.CHARACTER}); // Long to int - tests.add( - new Object[] {ArrayFactory.create(new long[] {3214, 424, 13, 22, 111, 134}), FrameArrayType.INT64}); + tests.add(new Object[] {ArrayFactory.create(new long[] {3214, 424, 13, 22, 111, 134}), FrameArrayType.INT64}); tests.add(new Object[] {ArrayFactory.create(new double[] {// Double.NaN, 424, 13, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 134}), FrameArrayType.FP64}); @@ -311,6 +302,113 @@ public void testToString() { a.toString(); } + @Test + public void equalsOtherType() { + try { + + switch(a.getValueType()) { + case BOOLEAN: + assertFalse(a.equals((Object) ArrayFactory.create(new char[] {'a', 'b'}))); + break; + default: + assertFalse(a.equals((Object) ArrayFactory.create(new boolean[] {true, false}))); + } + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void equalsSelf() { + try { + assertTrue(a.equals((Object) a)); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void equalsClone() { + try { + assertTrue(a.equals((Object) a.clone())); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + 
public void notEqualsRandomObject() { + try { + assertFalse(a.equals((Object) Double.valueOf(4213.2))); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void sameValueTypeNotEquals() { + try { + Array b = ArrayFactory.allocate(a.getValueType(), a.size() == 1 ? 2 : 1); + assertFalse(a.equals((Object) b)); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void getStatistics() { + ArrayCompressionStatistics s = (a.size() < 1000) ? // + a.statistics(a.size()) : a.statistics(1000); + if(s != null) { + assertTrue(s.compressedSizeEstimate < s.originalSize); + } + } + + @Test + public void setWithDDC() { + if(a.size() > 31) { + try{ + + Array t = a.clone(); + Array ddc = DDCArray.compressToDDC(// + ArrayFactory.allocate(t.getValueType(), 30)); + ArrayFactory.set(t, ddc, 0, 29, t.size()); + switch(t.getValueType()) { + case BOOLEAN: + assertEquals(t.get(0), (Boolean) false); + break; + default: + + } + } + catch(DMLCompressionException e){ + // valid error, Illegal to set range in a compressed array. + } + catch(DMLRuntimeException e){ + // is intentional here. 
+ if(!e.getMessage().contains("RaggedArray")){ + e.printStackTrace(); + fail(e.getMessage()); + } + } + catch(Exception e){ + e.printStackTrace(); + fail(e.getMessage()); + } + // assertEquals(t.get(0), ) + } + } + @Test public void getFrameArrayType() { if(t == FrameArrayType.BITSET) @@ -641,71 +739,79 @@ public void setDouble_2() { @Test public void analyzeValueType() { - ValueType av = a.analyzeValueType().getKey(); - switch(a.getValueType()) { - case BOOLEAN: - switch(av) { - case BOOLEAN: - return; - default: - fail("Invalid type returned from analyze valueType " + av); - } - case INT32: - switch(av) { - case BOOLEAN: - case INT32: - case UINT8: - return; - default: - fail("Invalid type returned from analyze valueType " + av); - } - case INT64: - switch(av) { - case BOOLEAN: - case INT32: - case UINT8: - case INT64: - return; - default: - fail("Invalid type returned from analyze valueType " + av + " " + a); - } - case UINT8: - switch(av) { - case BOOLEAN: - case UINT8: - return; - default: - fail("Invalid type returned from analyze valueType " + av); - } - case FP32: - switch(av) { - case BOOLEAN: - case INT32: - case UINT8: - case INT64: - case FP32: - return; - default: - fail("Invalid type returned from analyze valueType " + av); - } - case FP64: - switch(av) { - case BOOLEAN: - case INT32: - case UINT8: - case INT64: - case FP32: - case FP64: - return; - default: - fail("Invalid type returned from analyze valueType " + av); - } - case STRING: - break;// all allowed - case UNKNOWN: - fail("Not allowed to be unknown"); - default: - break; + try { + + ValueType av = a.analyzeValueType().getKey(); + + switch(a.getValueType()) { + case BOOLEAN: + switch(av) { + case BOOLEAN: + return; + default: + fail("Invalid type returned from analyze valueType " + av); + } + case INT32: + switch(av) { + case BOOLEAN: + case INT32: + case UINT8: + return; + default: + fail("Invalid type returned from analyze valueType " + av); + } + case INT64: + switch(av) { + case BOOLEAN: 
+ case INT32: + case UINT8: + case INT64: + return; + default: + fail("Invalid type returned from analyze valueType " + av + " " + a); + } + case UINT8: + switch(av) { + case BOOLEAN: + case UINT8: + return; + default: + fail("Invalid type returned from analyze valueType " + av); + } + case FP32: + switch(av) { + case BOOLEAN: + case INT32: + case UINT8: + case INT64: + case FP32: + return; + default: + fail("Invalid type returned from analyze valueType " + av); + } + case FP64: + switch(av) { + case BOOLEAN: + case INT32: + case UINT8: + case INT64: + case FP32: + case FP64: + return; + default: + fail("Invalid type returned from analyze valueType " + av); + } + case STRING: + break;// all allowed + case UNKNOWN: + fail("Not allowed to be unknown"); + default: + break; + } + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); } } @@ -1747,7 +1853,7 @@ protected static void compareSetSubRange(Array out, Array in, int rl, int Object bv = in.get(off); if((av == null && bv != null) || (bv == null && av != null)) fail("not both null"); - else if(av != null && bv != null){ + else if(av != null && bv != null) { String v1 = av.toString(); String v2 = bv.toString(); assertEquals("i: " + i + " args: " + rl + " " + ru + " " + (off - i) + " " + out.size(), v1, v2); @@ -1762,7 +1868,8 @@ protected static Array serializeAndBack(Array g) { DataOutputStream fos = new DataOutputStream(bos); g.write(fos); DataInputStream fis = new DataInputStream(new ByteArrayInputStream(bos.toByteArray())); - return ArrayFactory.read(fis, nRow); + Array gr = ArrayFactory.read(fis, nRow); + return gr; } catch(Exception e) { e.printStackTrace(); @@ -1783,8 +1890,7 @@ protected static Array createDDC(FrameArrayType t, int size, int seed) { return DDCArray .compressToDDC(ArrayFactory.create(generateRandomIntegerNUniqueLengthOpt(size, seed, nUnique))); case INT64: - return DDCArray - .compressToDDC(ArrayFactory.create(generateRandomLongNUniqueLengthOpt(size, seed, nUnique))); + 
return DDCArray.compressToDDC(ArrayFactory.create(generateRandomLongNUniqueLengthOpt(size, seed, nUnique))); case FP32: return DDCArray .compressToDDC(ArrayFactory.create(generateRandomFloatNUniqueLengthOpt(size, seed, nUnique))); @@ -1798,20 +1904,20 @@ protected static Array createDDC(FrameArrayType t, int size, int seed) { Random r = new Random(seed); switch(r.nextInt(7)) { case 0: - return DDCArray.compressToDDC( - ArrayFactory.create(generateRandomIntegerNUniqueLengthOpt(size, seed, nUnique))); + return DDCArray + .compressToDDC(ArrayFactory.create(generateRandomIntegerNUniqueLengthOpt(size, seed, nUnique))); case 1: - return DDCArray.compressToDDC( - ArrayFactory.create(generateRandomLongNUniqueLengthOpt(size, seed, nUnique))); + return DDCArray + .compressToDDC(ArrayFactory.create(generateRandomLongNUniqueLengthOpt(size, seed, nUnique))); case 2: - return DDCArray.compressToDDC( - ArrayFactory.create(generateRandomDoubleNUniqueLengthOpt(size, seed, nUnique))); + return DDCArray + .compressToDDC(ArrayFactory.create(generateRandomDoubleNUniqueLengthOpt(size, seed, nUnique))); case 3: - return DDCArray.compressToDDC( - ArrayFactory.create(generateRandomFloatNUniqueLengthOpt(size, seed, nUnique))); + return DDCArray + .compressToDDC(ArrayFactory.create(generateRandomFloatNUniqueLengthOpt(size, seed, nUnique))); case 4: - return DDCArray.compressToDDC( - ArrayFactory.create(generateRandomCharacterNUniqueLengthOpt(size, seed, nUnique))); + return DDCArray + .compressToDDC(ArrayFactory.create(generateRandomCharacterNUniqueLengthOpt(size, seed, nUnique))); default: return DDCArray.compressToDDC(ArrayFactory.create(generateRandomBooleanOpt(size, seed))); } @@ -1835,23 +1941,23 @@ protected static Array createDDC(FrameArrayType t, int size, int seed) { Random r2 = new Random(seed); switch(r2.nextInt(7)) { case 0: - return DDCArray.compressToDDC( - ArrayFactory.create(generateRandomIntegerNUniqueLengthOpt(size, seed, nUnique))); + return DDCArray + 
.compressToDDC(ArrayFactory.create(generateRandomIntegerNUniqueLengthOpt(size, seed, nUnique))); case 1: - return DDCArray.compressToDDC( - ArrayFactory.create(generateRandomLongNUniqueLengthOpt(size, seed, nUnique))); + return DDCArray + .compressToDDC(ArrayFactory.create(generateRandomLongNUniqueLengthOpt(size, seed, nUnique))); case 2: - return DDCArray.compressToDDC( - ArrayFactory.create(generateRandomDoubleNUniqueLengthOpt(size, seed, nUnique))); + return DDCArray + .compressToDDC(ArrayFactory.create(generateRandomDoubleNUniqueLengthOpt(size, seed, nUnique))); case 3: - return DDCArray.compressToDDC( - ArrayFactory.create(generateRandomFloatNUniqueLengthOpt(size, seed, nUnique))); + return DDCArray + .compressToDDC(ArrayFactory.create(generateRandomFloatNUniqueLengthOpt(size, seed, nUnique))); case 4: - return DDCArray.compressToDDC( - ArrayFactory.create(generateRandomCharacterNUniqueLengthOpt(size, seed, nUnique))); + return DDCArray + .compressToDDC(ArrayFactory.create(generateRandomCharacterNUniqueLengthOpt(size, seed, nUnique))); case 5: - return DDCArray.compressToDDC( - ArrayFactory.create(generateRandomStringNUniqueLengthOpt(size, seed, nUnique, 32))); + return DDCArray + .compressToDDC(ArrayFactory.create(generateRandomStringNUniqueLengthOpt(size, seed, nUnique, 32))); default: return DDCArray.compressToDDC(ArrayFactory.create(generateRandomBooleanOpt(size, seed))); } @@ -1901,23 +2007,23 @@ protected static Array createOptional(FrameArrayType t, int size, int seed) { int nUnique = Math.max(size / 100, 2); switch(r2.nextInt(7)) { case 0: - return DDCArray.compressToDDC( - ArrayFactory.create(generateRandomIntegerNUniqueLengthOpt(size, seed, nUnique))); + return DDCArray + .compressToDDC(ArrayFactory.create(generateRandomIntegerNUniqueLengthOpt(size, seed, nUnique))); case 1: - return DDCArray.compressToDDC( - ArrayFactory.create(generateRandomLongNUniqueLengthOpt(size, seed, nUnique))); + return DDCArray + 
.compressToDDC(ArrayFactory.create(generateRandomLongNUniqueLengthOpt(size, seed, nUnique))); case 2: - return DDCArray.compressToDDC( - ArrayFactory.create(generateRandomDoubleNUniqueLengthOpt(size, seed, nUnique))); + return DDCArray + .compressToDDC(ArrayFactory.create(generateRandomDoubleNUniqueLengthOpt(size, seed, nUnique))); case 3: - return DDCArray.compressToDDC( - ArrayFactory.create(generateRandomFloatNUniqueLengthOpt(size, seed, nUnique))); + return DDCArray + .compressToDDC(ArrayFactory.create(generateRandomFloatNUniqueLengthOpt(size, seed, nUnique))); case 4: - return DDCArray.compressToDDC( - ArrayFactory.create(generateRandomCharacterNUniqueLengthOpt(size, seed, nUnique))); + return DDCArray + .compressToDDC(ArrayFactory.create(generateRandomCharacterNUniqueLengthOpt(size, seed, nUnique))); case 5: - return DDCArray.compressToDDC( - ArrayFactory.create(generateRandomStringNUniqueLengthOpt(size, seed, nUnique, 32))); + return DDCArray + .compressToDDC(ArrayFactory.create(generateRandomStringNUniqueLengthOpt(size, seed, nUnique, 32))); default: return DDCArray.compressToDDC(ArrayFactory.create(generateRandomBooleanOpt(size, seed))); } @@ -1984,23 +2090,23 @@ protected static Array create(FrameArrayType t, int size, int seed) { int nUnique = Math.max(size / 100, 2); switch(r2.nextInt(7)) { case 0: - return DDCArray.compressToDDC( - ArrayFactory.create(generateRandomIntegerNUniqueLengthOpt(size, seed, nUnique))); + return DDCArray + .compressToDDC(ArrayFactory.create(generateRandomIntegerNUniqueLengthOpt(size, seed, nUnique))); case 1: - return DDCArray.compressToDDC( - ArrayFactory.create(generateRandomLongNUniqueLengthOpt(size, seed, nUnique))); + return DDCArray + .compressToDDC(ArrayFactory.create(generateRandomLongNUniqueLengthOpt(size, seed, nUnique))); case 2: - return DDCArray.compressToDDC( - ArrayFactory.create(generateRandomDoubleNUniqueLengthOpt(size, seed, nUnique))); + return DDCArray + 
.compressToDDC(ArrayFactory.create(generateRandomDoubleNUniqueLengthOpt(size, seed, nUnique))); case 3: - return DDCArray.compressToDDC( - ArrayFactory.create(generateRandomFloatNUniqueLengthOpt(size, seed, nUnique))); + return DDCArray + .compressToDDC(ArrayFactory.create(generateRandomFloatNUniqueLengthOpt(size, seed, nUnique))); case 4: - return DDCArray.compressToDDC( - ArrayFactory.create(generateRandomCharacterNUniqueLengthOpt(size, seed, nUnique))); + return DDCArray + .compressToDDC(ArrayFactory.create(generateRandomCharacterNUniqueLengthOpt(size, seed, nUnique))); case 5: - return DDCArray.compressToDDC( - ArrayFactory.create(generateRandomStringNUniqueLengthOpt(size, seed, nUnique, 32))); + return DDCArray + .compressToDDC(ArrayFactory.create(generateRandomStringNUniqueLengthOpt(size, seed, nUnique, 32))); default: return DDCArray.compressToDDC(ArrayFactory.create(generateRandomBooleanOpt(size, seed))); } diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/NegativeArrayTests.java b/src/test/java/org/apache/sysds/test/component/frame/array/NegativeArrayTests.java index 68900fe5c99..dd471e3159f 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/array/NegativeArrayTests.java +++ b/src/test/java/org/apache/sysds/test/component/frame/array/NegativeArrayTests.java @@ -40,6 +40,7 @@ import org.apache.sysds.runtime.frame.data.columns.IntegerArray; import org.apache.sysds.runtime.frame.data.columns.LongArray; import org.apache.sysds.runtime.frame.data.columns.OptionalArray; +import org.apache.sysds.runtime.frame.data.columns.RaggedArray; import org.apache.sysds.runtime.frame.data.columns.StringArray; import org.junit.Test; import org.mockito.Mockito; @@ -207,7 +208,7 @@ public void createOptionalWithOptionalConstructor2() { } @Test(expected = DMLRuntimeException.class) - public void readFields() { + public void readFieldsOpt() { try { new OptionalArray<>(new Integer[1]).readFields(null); } @@ -216,6 +217,16 @@ public void 
readFields() { } } + @Test(expected = DMLRuntimeException.class) + public void readFieldsRagged() { + try { + new RaggedArray<>(ArrayFactory.create(new Integer[]{1,2,3}),10).readFields(null); + } + catch(IOException e) { + fail("not correct exception"); + } + } + @Test(expected = NullPointerException.class) public void invalidConstructOptional1() { new OptionalArray<>(ArrayFactory.allocate(ValueType.CHARACTER, 10), null); @@ -279,4 +290,39 @@ public void set3() { ACompressedArray a = mock(ACompressedArray.class, Mockito.CALLS_REAL_METHODS); a.set(0, Integer.valueOf(13)); } + + @Test(expected = DMLRuntimeException.class) + public void testInvalidRLen() { + Array a = null; + Array b = new OptionalArray(new Long[] {1L, 2L, 3L, 4L}); + ArrayFactory.set(a, b, 10, 20, 20); + } + + @Test(expected = NullPointerException.class) + public void testNull() { + Array a = null; + Array b = null; + ArrayFactory.set(a, b, 10, 15, 20); + } + + @Test(expected = DMLRuntimeException.class) + public void testInvalidBLength() { + Array a = null; + Array b = new OptionalArray(new Long[] {1L, 2L, 3L, 4L}); + ArrayFactory.set(a, b, 10, 15, 20);// one to short + } + + @Test(expected = DMLRuntimeException.class) + public void testInvalidALength() { + Array a = ArrayFactory.allocate( ValueType.INT32, 10); + Array b = new OptionalArray(new Long[] {1L, 2L, 3L, 4L}); + ArrayFactory.set(a, b, 10, 14, 20);// one to short + } + + @Test(expected = DMLRuntimeException.class) + public void testInvalidRL() { + Array a = ArrayFactory.allocate( ValueType.INT32, 10); + Array b = new OptionalArray(new Long[] {1L, 2L, 3L, 4L}); + ArrayFactory.set(a, b, -1, 15, 20);// one to short + } } From 902484c811b9362b6ca9bb0a2b599ea5db16732e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 24 Oct 2023 10:27:36 +0000 Subject: [PATCH 02/28] Bump actions/setup-node from 3 to 4 (#1931) --- .github/workflows/monitoringUITests.yml | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/.github/workflows/monitoringUITests.yml b/.github/workflows/monitoringUITests.yml index a54384807fb..6e37d194b1b 100644 --- a/.github/workflows/monitoringUITests.yml +++ b/.github/workflows/monitoringUITests.yml @@ -54,7 +54,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Build the application, with Node.js ${{ matrix.node-version }} - uses: actions/setup-node@v3 + uses: actions/setup-node@v4 with: # Set always-auth in npmrc always-auth: false # optional, default is false From e7f1640f10c2d7534941b3913bd9363253d461d4 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Thu, 7 Sep 2023 18:21:47 +0200 Subject: [PATCH 03/28] [SYSTEMDS-3619] IO save Encoded This commit adds settings to systemds for writing and reading in compressed encoded formats via the Hadoop compression blocks. The implementation works with all writers used in the system that interact with the filesystem via HDFS sequence file writers. Initial experiments are promissing with: Uncompressed: StandardDisk, 6.345 ms, 1260913549 Byte/s, 1278545356 Byte/s Snappy: StandardDisk, 6.021 ms, 1328816357Byte/s, 291445747 Byte/s This indicate that at least for matrices such as the above we should compress on the IO path to disk improving write speeds by 100MB/s and reduce the saved size by 4.3x. Also included in this PR is various fixes and upgrades including - A new Sparse to Dense blocks library file to parallelize the transformation - Compression compilation and Hop insertion of compression instructions - Local File IO shortcuts when IO is not using Hadoop to reduce folder and file scan look-ups. - Fixes to reading and writing CLA compressed matrix blocks. 
Closes #1902 --- scripts/builtin/auc.dml | 2 +- .../java/org/apache/sysds/conf/DMLConfig.java | 4 +- .../org/apache/sysds/hops/FunctionOp.java | 36 +- src/main/java/org/apache/sysds/hops/Hop.java | 69 +- .../java/org/apache/sysds/hops/LiteralOp.java | 5 - .../java/org/apache/sysds/hops/UnaryOp.java | 5 +- .../rewrite/RewriteCompressedReblock.java | 36 +- .../RewriteRemovePersistentReadWrite.java | 3 +- .../org/apache/sysds/lops/Compression.java | 16 +- .../CompressedMatrixBlockFactory.java | 38 +- .../compress/bitmap/BitmapEncoder.java | 17 +- .../runtime/compress/cocode/CoCodeGreedy.java | 6 +- .../runtime/compress/colgroup/AColGroup.java | 16 +- .../compress/colgroup/ColGroupConst.java | 40 +- .../compress/colgroup/ColGroupDDC.java | 2 +- .../compress/colgroup/ColGroupDDCFOR.java | 4 +- .../compress/colgroup/ColGroupEmpty.java | 32 +- .../compress/colgroup/ColGroupFactory.java | 149 ++- .../colgroup/ColGroupLinearFunctional.java | 4 +- .../compress/colgroup/ColGroupOLE.java | 4 +- .../compress/colgroup/ColGroupRLE.java | 4 +- .../compress/colgroup/ColGroupSDC.java | 42 +- .../compress/colgroup/ColGroupSDCFOR.java | 6 +- .../compress/colgroup/ColGroupSDCSingle.java | 38 +- .../colgroup/ColGroupSDCSingleZeros.java | 32 +- .../compress/colgroup/ColGroupSDCZeros.java | 26 +- .../colgroup/ColGroupUncompressed.java | 38 +- .../colgroup/dictionary/ADictionary.java | 5 + .../colgroup/dictionary/Dictionary.java | 13 +- .../colgroup/dictionary/IDictionary.java | 22 +- .../colgroup/dictionary/PlaceHolderDict.java | 932 +++++++++--------- .../compress/colgroup/indexes/ArrayIndex.java | 37 +- .../colgroup/indexes/SingleIndex.java | 7 +- .../compress/colgroup/mapping/MapToBit.java | 23 +- .../compress/colgroup/mapping/MapToByte.java | 22 +- .../compress/colgroup/mapping/MapToChar.java | 46 +- .../colgroup/mapping/MapToCharPByte.java | 21 +- .../compress/colgroup/mapping/MapToInt.java | 22 +- .../compress/colgroup/mapping/MapToZero.java | 16 +- 
.../compress/colgroup/offset/AOffset.java | 98 +- .../compress/colgroup/offset/OffsetByte.java | 13 +- .../colgroup/offset/OffsetByteNZ.java | 11 +- .../colgroup/offset/OffsetByteUNZ.java | 11 +- .../colgroup/offset/OffsetFactory.java | 1 - .../compress/colgroup/offset/OffsetTwo.java | 4 +- .../colgroup/scheme/CompressionScheme.java | 10 +- .../compress/colgroup/scheme/SDCSchemeSC.java | 32 +- .../compress/estim/EstimationFactors.java | 2 +- .../runtime/compress/io/ReaderCompressed.java | 78 +- .../runtime/compress/io/WriterCompressed.java | 51 +- .../compress/lib/CLALibBinaryCellOp.java | 9 + .../runtime/compress/lib/CLALibCompAgg.java | 2 +- .../compress/lib/CLALibMatrixMult.java | 6 +- .../runtime/compress/lib/CLALibSeparator.java | 23 +- .../runtime/compress/lib/CLALibSlice.java | 10 +- .../runtime/compress/lib/CLALibStack.java | 48 +- .../sysds/runtime/compress/utils/ACount.java | 5 +- .../runtime/compress/utils/ACountHashMap.java | 498 +++++----- .../compress/utils/CompressRDDClean.java | 10 +- .../compress/utils/DblArrayCountHashMap.java | 11 +- .../utils/DblArrayIntListHashMap.java | 185 ++-- .../compress/utils/DoubleCountHashMap.java | 24 +- .../runtime/compress/utils/IntArrayList.java | 24 +- .../context/SparkExecutionContext.java | 5 +- .../sysds/runtime/frame/data/FrameBlock.java | 29 +- .../runtime/frame/data/columns/Array.java | 2 +- .../frame/data/columns/StringArray.java | 37 +- .../spark/CompressionSPInstruction.java | 27 +- .../sysds/runtime/io/IOUtilFunctions.java | 170 +++- .../runtime/io/ReaderBinaryBlockParallel.java | 6 +- .../runtime/matrix/data/LibMatrixReorg.java | 244 +++-- .../matrix/data/LibMatrixSparseToDense.java | 185 ++++ .../runtime/transform/decode/Decoder.java | 33 +- .../runtime/transform/decode/DecoderBin.java | 33 +- .../transform/decode/DecoderComposite.java | 37 + .../transform/decode/DecoderDummycode.java | 11 +- .../transform/decode/DecoderPassThrough.java | 11 +- .../transform/decode/DecoderRecode.java | 11 +- 
.../transform/encode/ColumnEncoderBin.java | 34 +- .../encode/ColumnEncoderComposite.java | 3 +- .../encode/ColumnEncoderDummycode.java | 12 +- .../encode/ColumnEncoderPassThrough.java | 15 +- .../transform/encode/CompressedEncode.java | 98 +- .../transform/encode/MultiColumnEncoder.java | 35 +- .../apache/sysds/runtime/util/HDFSTool.java | 57 +- .../sysds/runtime/util/UtilFunctions.java | 12 +- .../micro/FrameCompressedTransform.java | 175 ++++ .../performance/micro/InformationLoss.java | 186 ++++ .../colgroup/ColGroupNegativeTests.java | 4 +- .../compress/colgroup/ColGroupTest.java | 6 +- .../compress/colgroup/CustomColGroupTest.java | 69 ++ .../compress/indexes/IndexesTest.java | 10 +- .../compress/io/IOCompressionTestUtils.java | 4 +- .../test/component/compress/io/IOEmpty.java | 4 +- .../test/component/compress/io/IOSpark.java | 58 +- .../test/component/compress/io/IOTest.java | 6 +- .../compress/offset/CustomOffsetTest.java | 39 + .../compress/util/ArrCountMapTest.java | 2 +- .../compress/util/ArrayListTest.java | 214 ++++ .../component/compress/util/CountTest.java | 16 + .../compress/util/ListHashMapTest.java | 180 ++++ .../federated/FederatedTestUtils.java | 1 - .../TransformCompressedTestLogger.java | 18 +- .../TransformCompressedTestMultiCol.java | 34 +- .../TransformCompressedTestSingleCol.java | 29 +- .../test/component/matrix/SparseCSRTest.java | 2 - .../io/compressed/WriteCompressedTest.java | 3 +- .../transform/TransformApplyUnknownsTest.java | 18 +- 108 files changed, 3623 insertions(+), 1568 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSparseToDense.java create mode 100644 src/test/java/org/apache/sysds/performance/micro/FrameCompressedTransform.java create mode 100644 src/test/java/org/apache/sysds/performance/micro/InformationLoss.java create mode 100644 src/test/java/org/apache/sysds/test/component/compress/colgroup/CustomColGroupTest.java create mode 100644 
src/test/java/org/apache/sysds/test/component/compress/offset/CustomOffsetTest.java create mode 100644 src/test/java/org/apache/sysds/test/component/compress/util/ArrayListTest.java create mode 100644 src/test/java/org/apache/sysds/test/component/compress/util/ListHashMapTest.java diff --git a/scripts/builtin/auc.dml b/scripts/builtin/auc.dml index 7a874034194..8b05456fcbd 100644 --- a/scripts/builtin/auc.dml +++ b/scripts/builtin/auc.dml @@ -19,7 +19,7 @@ # #------------------------------------------------------------- -# This builting function computes the area under the ROC curve (AUC) +# This builtin function computes the area under the ROC curve (AUC) # for binary classifiers. # # INPUT: diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java b/src/main/java/org/apache/sysds/conf/DMLConfig.java index 767026a1614..c5ca98d3a8e 100644 --- a/src/main/java/org/apache/sysds/conf/DMLConfig.java +++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java @@ -70,6 +70,7 @@ public class DMLConfig public static final String DEFAULT_BLOCK_SIZE = "sysds.defaultblocksize"; public static final String CP_PARALLEL_OPS = "sysds.cp.parallel.ops"; public static final String CP_PARALLEL_IO = "sysds.cp.parallel.io"; + public static final String IO_COMPRESSION_CODEC = "sysds.io.compression.encoding"; public static final String PARALLEL_ENCODE = "sysds.parallel.encode"; // boolean: enable multi-threaded transformencode and apply public static final String PARALLEL_ENCODE_STAGED = "sysds.parallel.encode.staged"; public static final String PARALLEL_ENCODE_APPLY_BLOCKS = "sysds.parallel.encode.applyBlocks"; @@ -154,6 +155,7 @@ public class DMLConfig _defaultVals.put(DEFAULT_BLOCK_SIZE, String.valueOf(OptimizerUtils.DEFAULT_BLOCKSIZE) ); _defaultVals.put(CP_PARALLEL_OPS, "true" ); _defaultVals.put(CP_PARALLEL_IO, "true" ); + _defaultVals.put(IO_COMPRESSION_CODEC, "none"); _defaultVals.put(PARALLEL_TOKENIZE, "false"); _defaultVals.put(PARALLEL_TOKENIZE_NUM_BLOCKS, "64"); 
_defaultVals.put(PARALLEL_ENCODE, "true" ); @@ -463,7 +465,7 @@ public String getConfigInfo() { FLOATING_POINT_PRECISION, GPU_EVICTION_POLICY, LOCAL_SPARK_NUM_THREADS, EVICTION_SHADOW_BUFFERSIZE, GPU_MEMORY_ALLOCATOR, GPU_MEMORY_UTILIZATION_FACTOR, USE_SSL_FEDERATED_COMMUNICATION, DEFAULT_FEDERATED_INITIALIZATION_TIMEOUT, FEDERATED_TIMEOUT, FEDERATED_MONITOR_FREQUENCY, FEDERATED_COMPRESSION, - ASYNC_PREFETCH, ASYNC_SPARK_BROADCAST, ASYNC_SPARK_CHECKPOINT + ASYNC_PREFETCH, ASYNC_SPARK_BROADCAST, ASYNC_SPARK_CHECKPOINT, IO_COMPRESSION_CODEC }; StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/hops/FunctionOp.java b/src/main/java/org/apache/sysds/hops/FunctionOp.java index 367b845532a..e5e66f2ccfb 100644 --- a/src/main/java/org/apache/sysds/hops/FunctionOp.java +++ b/src/main/java/org/apache/sysds/hops/FunctionOp.java @@ -24,12 +24,16 @@ import java.util.List; import org.apache.sysds.api.DMLScript; +import org.apache.sysds.lops.Compression; +import org.apache.sysds.lops.Data; +import org.apache.sysds.lops.DeCompression; import org.apache.sysds.lops.FunctionCallCP; import org.apache.sysds.lops.Lop; import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.parser.DMLProgram; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.compress.SingletonLookupHashMap; import org.apache.sysds.runtime.controlprogram.Program; import org.apache.sysds.runtime.controlprogram.parfor.opt.CostEstimatorHops; import org.apache.sysds.runtime.meta.DataCharacteristics; @@ -295,18 +299,42 @@ public Lop constructLops() tmp.add( in.constructLops() ); //construct function call - Lop fcall = new FunctionCallCP(tmp, _fnamespace, _fname, _inputNames, _outputNames, _outputHops, _opt, et); + FunctionCallCP fcall = new FunctionCallCP(tmp, _fnamespace, _fname, _inputNames, _outputNames, _outputHops, _opt, et); setLineNumbers(fcall); setLops(fcall); //note: no reblock lop because 
outputs directly bound + constructAndSetCompressionLopFunctionalIfRequired(et); return getLops(); } + protected void constructAndSetCompressionLopFunctionalIfRequired(ExecType et) { + if((requiresCompression()) && ((FunctionCallCP) getLops()).getFunctionName().equalsIgnoreCase("transformencode")){ // xor + + // Lop matrixOut = lop.getFunctionOutputs().get(0); + Lop compressionInstruction = null; + + if(_compressedWorkloadTree != null) { + SingletonLookupHashMap m = SingletonLookupHashMap.getMap(); + int singletonID = m.put(_compressedWorkloadTree); + compressionInstruction = new Compression(getLops(), DataType.MATRIX, ValueType.FP64, et, singletonID); + } + else + compressionInstruction = new Compression(getLops(), DataType.MATRIX, ValueType.FP64, et, 0); + + + setOutputDimensions( compressionInstruction ); + setLineNumbers( compressionInstruction ); + setLops( compressionInstruction ); + + } + } + + @Override public String getOpString() { - return OPCODE; + return OPCODE + " " + _fnamespace + " " + _fname; } @Override @@ -385,8 +413,4 @@ public boolean compare(Hop that) { return false; } - @Override - public String toString(){ - return getOpString(); - } } diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java index bb46b7c2e38..265ba672e96 100644 --- a/src/main/java/org/apache/sysds/hops/Hop.java +++ b/src/main/java/org/apache/sysds/hops/Hop.java @@ -441,26 +441,18 @@ private void constructAndSetReblockLopIfRequired() if( _requiresReblock && et != ExecType.CP ) { Lop input = getLops(); + Lop reblock = null; - - try - { - if( this instanceof DataOp // CSV - && ((DataOp)this).getOp() == OpOpData.PERSISTENTREAD - && ((DataOp)this).getFileFormat() == FileFormat.CSV ) - { - reblock = new CSVReBlock( input, getBlocksize(), - getDataType(), getValueType(), et); - } - else { //ALL OTHER - reblock = new ReBlock( input, getBlocksize(), - getDataType(), getValueType(), _outputEmptyBlocks, et); - } + if(this instanceof DataOp 
// CSV + && ((DataOp) this).getOp() == OpOpData.PERSISTENTREAD && + ((DataOp) this).getFileFormat() == FileFormat.CSV) { + reblock = new CSVReBlock(input, getBlocksize(), getDataType(), getValueType(), et); } - catch( LopsException ex ) { - throw new HopsException(ex); + else { // ALL OTHER + reblock = new ReBlock(input, getBlocksize(), getDataType(), getValueType(), _outputEmptyBlocks, et); } - + + // replace this lop with the reblock instruction setOutputDimensions(reblock); setLineNumbers(reblock); setPrivacy(reblock); @@ -513,43 +505,40 @@ else if( !dimsKnown(true) ) { } } - private void constructAndSetCompressionLopIfRequired() { - if((requiresCompression() && ! hasCompressedInput()) ^ _requiresDeCompression){ // xor + protected void constructAndSetCompressionLopIfRequired() { + if((requiresCompression()) ^ _requiresDeCompression){ // xor ExecType et = getExecutionModeForCompression(); Lop compressionInstruction = null; - try{ - if( requiresCompression() ){ - if(_compressedWorkloadTree != null){ - SingletonLookupHashMap m = SingletonLookupHashMap.getMap(); - int singletonID = m.put(_compressedWorkloadTree); - compressionInstruction = new Compression(getLops(), getDataType(), getValueType(), et, singletonID); - } - else - compressionInstruction = new Compression(getLops(), getDataType(), getValueType(), et, 0); + + if(requiresCompression()) { + if(_compressedWorkloadTree != null) { + SingletonLookupHashMap m = SingletonLookupHashMap.getMap(); + int singletonID = m.put(_compressedWorkloadTree); + compressionInstruction = new Compression(getLops(), getDataType(), getValueType(), et, singletonID); } - else if( _requiresDeCompression && et != ExecType.SPARK ) // Disabled spark decompression instruction. 
- compressionInstruction = new DeCompression(getLops(), getDataType(), getValueType(), et); else - return; - } - catch (LopsException ex) { - throw new HopsException(ex); + compressionInstruction = new Compression(getLops(), getDataType(), getValueType(), et, 0); } + else if(_requiresDeCompression && et != ExecType.SPARK) // Disabled spark decompression instruction. + compressionInstruction = new DeCompression(getLops(), getDataType(), getValueType(), et); + else + return; + setOutputDimensions( compressionInstruction ); setLineNumbers( compressionInstruction ); setLops( compressionInstruction ); } } - private ExecType getExecutionModeForCompression(){ + protected ExecType getExecutionModeForCompression(){ ExecType et = ExecType.CP; // conditional checkpoint based on memory estimate in order to avoid unnecessary // persist and unpersist calls (4x the memory budget is conservative) if( OptimizerUtils.isSparkExecutionMode() && getDataType()!=DataType.SCALAR ) if( OptimizerUtils.isHybridExecutionMode() && 2 * _outputMemEstimate < OptimizerUtils.getLocalMemBudget() - || _etypeForced == ExecType.CP ) + || _etypeForced == ExecType.CP || getLops().isExecCP() ) et = ExecType.CP; else et = ExecType.SPARK; @@ -1050,8 +1039,12 @@ protected final ExecType optFindExecType() { public abstract String getOpString(); @Override - public String toString(){ - return super.getClass().getSimpleName() + " " + getOpString(); + public final String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()); + sb.append(" "); + sb.append(getOpString()); + return sb.toString(); } // ======================================================================================== diff --git a/src/main/java/org/apache/sysds/hops/LiteralOp.java b/src/main/java/org/apache/sysds/hops/LiteralOp.java index 5d3f06bd663..0feb0525381 100644 --- a/src/main/java/org/apache/sysds/hops/LiteralOp.java +++ b/src/main/java/org/apache/sysds/hops/LiteralOp.java @@ -287,9 +287,4 @@ 
public boolean compare( Hop that ) { return false; } - - @Override - public String toString(){ - return getOpString(); - } } diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java b/src/main/java/org/apache/sysds/hops/UnaryOp.java index 72f63e99ed3..5dbb55a3035 100644 --- a/src/main/java/org/apache/sysds/hops/UnaryOp.java +++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java @@ -21,6 +21,7 @@ import java.util.ArrayList; +import org.apache.commons.lang.NotImplementedException; import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types.AggOp; import org.apache.sysds.common.Types.DataType; @@ -158,12 +159,14 @@ else if(_op == OpOp1.MEDIAN) } else { // general case MATRIX ExecType et = optFindExecType(); - // special handling cumsum/cumprod/cummin/cumsum if(isCumulativeUnaryOperation() && !(et == ExecType.CP || et == ExecType.GPU)) { // TODO additional physical operation if offsets fit in memory ret = constructLopsSparkCumulativeUnary(); } + else if(_op == OpOp1.CAST_AS_FRAME && getInput().size() == 2) { + throw new NotImplementedException(); + } else {// default unary final boolean inplace = OptimizerUtils.ALLOW_UNARY_UPDATE_IN_PLACE && input.getParent().size() == 1 && (!(input instanceof DataOp) || !((DataOp) input).isRead()); diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java index cf353a7eca7..8dd323dd44e 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java @@ -27,6 +27,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.AggOp; +import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.OpOp1; import org.apache.sysds.common.Types.OpOp2; import org.apache.sysds.common.Types.OpOp3; @@ -35,6 +36,7 @@ import 
org.apache.sysds.hops.AggBinaryOp; import org.apache.sysds.hops.FunctionOp; import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.LiteralOp; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.lops.Compression.CompressConfig; import org.apache.sysds.parser.DMLProgram; @@ -92,12 +94,18 @@ public List rewriteStatementBlocks(List sbs, Pro } private static void injectCompressionDirective(Hop hop, CompressConfig compress, DMLProgram prog) { - if(hop.isVisited() || hop.requiresCompression() || hop.hasCompressedInput()) + if(hop.isVisited() || hop.requiresCompression()) return; - // recursively process children + // recursion for inputs. for(Hop hi : hop.getInput()) injectCompressionDirective(hi, compress, prog); + + // filter candidates. + if((HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD, DataType.SCALAR))// + || HopRewriteUtils.isData(hop, OpOpData.TRANSIENTREAD, OpOpData.TRANSIENTWRITE)// + || hop instanceof LiteralOp) + return; // check for compression conditions switch(compress) { @@ -128,29 +136,41 @@ public static boolean satisfiesSizeConstraintsForCompression(Hop hop) { if(hop.getDim2() >= 1) { final long x = hop.getDim1(); final long y = hop.getDim2(); - return + final boolean ret = // If the Cube of the number of rows is greater than multiplying the number of columns by 1024. y << 10 <= x * x // is very sparse and at least 100 rows. || (hop.getSparsity() < 0.0001 && y > 100); + return ret; + } + else if(hop.getDim1() >= 1){ + // known rows. but not cols; + boolean ret = hop.getDim1() > 10000; + return ret; + } + else{ + return true; // unknown dimensions lets always try. 
} - return false; } public static boolean satisfiesCompressionCondition(Hop hop) { boolean satisfies = false; - if(satisfiesSizeConstraintsForCompression(hop)) + if(satisfiesSizeConstraintsForCompression(hop)){ satisfies |= HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD); - + satisfies |= HopRewriteUtils.isTransformEncode(hop); + } return satisfies; } public static boolean satisfiesAggressiveCompressionCondition(Hop hop) { //size-independent conditions (robust against unknowns) - boolean satisfies = HopRewriteUtils.isTernary(hop, OpOp3.CTABLE) //matrix (no vector) ctable - && hop.getInput(0).getDataType().isMatrix() && hop.getInput(1).getDataType().isMatrix(); + boolean satisfies = false; //size-dependent conditions if(satisfiesSizeConstraintsForCompression(hop)) { + //matrix (no vector) ctable + satisfies |= HopRewriteUtils.isTernary(hop, OpOp3.CTABLE) + && hop.getInput(0).getDataType().isMatrix() + && hop.getInput(1).getDataType().isMatrix(); satisfies |= HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD); satisfies |= HopRewriteUtils.isUnary(hop, OpOp1.ROUND, OpOp1.FLOOR, OpOp1.NOT, OpOp1.CEIL); satisfies |= HopRewriteUtils.isBinary(hop, OpOp2.EQUAL, OpOp2.NOTEQUAL, OpOp2.LESS, diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemovePersistentReadWrite.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemovePersistentReadWrite.java index 9265b2fc35e..e0d9033add6 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemovePersistentReadWrite.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemovePersistentReadWrite.java @@ -100,8 +100,7 @@ private void rule_RemovePersistentDataOp( Hop hop ) { rule_RemovePersistentDataOp( inputs.get(i) ); //remove cast if unnecessary - if( hop instanceof DataOp ) - { + if( hop instanceof DataOp ) { DataOp dop = (DataOp) hop; OpOpData dotype = dop.getOp(); diff --git a/src/main/java/org/apache/sysds/lops/Compression.java b/src/main/java/org/apache/sysds/lops/Compression.java index 
6871dc8771b..e636fc8b848 100644 --- a/src/main/java/org/apache/sysds/lops/Compression.java +++ b/src/main/java/org/apache/sysds/lops/Compression.java @@ -61,9 +61,21 @@ public String getInstructions(String input1, String output) { sb.append(Lop.OPERAND_DELIMITOR); sb.append(OPCODE); sb.append(OPERAND_DELIMITOR); - sb.append(getInputs().get(0).prepInputOperand(input1)); + if(getInputs().get(0) instanceof FunctionCallCP && + ((FunctionCallCP)getInputs().get(0)).getFunctionName().equalsIgnoreCase("transformencode") ){ + sb.append(getInputs().get(0).getOutputs().get(0).getOutputParameters().getLabel()); + } + else{ + sb.append(getInputs().get(0).prepInputOperand(input1)); + } sb.append(OPERAND_DELIMITOR); - sb.append(prepOutputOperand(output)); + if(getInputs().get(0) instanceof FunctionCallCP && + ((FunctionCallCP)getInputs().get(0)).getFunctionName().equalsIgnoreCase("transformencode") ){ + sb.append(getInputs().get(0).getOutputs().get(0).getOutputParameters().getLabel()); + } + else{ + sb.append(prepOutputOperand(output)); + } if(_singletonLookupKey != 0){ sb.append(OPERAND_DELIMITOR); sb.append(_singletonLookupKey); diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java index 74b6cc8e0a8..cc5f5465fb4 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java @@ -396,24 +396,28 @@ private void transposeHeuristics() { compSettings.transposed = false; break; default: - if(mb.isInSparseFormat()) { - if(mb.getNumColumns() > 10000) - // many sparse columns we have to... 
- compSettings.transposed = true; - else if(mb.getNonZeros() < 1000) - // low nnz trivial to transpose - compSettings.transposed = true; - else { - // is enough rows to make it usable - boolean isAboveRowNumbers = mb.getNumRows() > 500000; - // Make sure that it is not more efficient to extract the rows. - boolean isAboveThreadToColumnRatio = compressionGroups.getNumberColGroups() > mb.getNumColumns() / 30; - compSettings.transposed = isAboveRowNumbers && isAboveThreadToColumnRatio; - } - } - else - compSettings.transposed = false; + compSettings.transposed = transposeHeuristics(compressionGroups.getNumberColGroups() , mb); + } + } + + public static boolean transposeHeuristics(int nGroups, MatrixBlock mb) { + if(mb.isInSparseFormat()) { + if(mb.getNumColumns() > 10000 || mb.getNumRows() > 10000) + // many sparse columns or rows we have to... + return true; + else if(mb.getNonZeros() < 1000) + // low nnz trivial to transpose + return true; + else { + // is enough rows to make it usable + boolean isAboveRowNumbers = mb.getNumRows() > 500000; + // Make sure that it is not more efficient to extract the rows. + boolean isAboveThreadToColumnRatio = nGroups > mb.getNumColumns() / 30; + return isAboveRowNumbers && isAboveThreadToColumnRatio; + } } + else + return false; } private void compressPhase() { diff --git a/src/main/java/org/apache/sysds/runtime/compress/bitmap/BitmapEncoder.java b/src/main/java/org/apache/sysds/runtime/compress/bitmap/BitmapEncoder.java index 0dc3e0e754b..96bfe7a2fb0 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/bitmap/BitmapEncoder.java +++ b/src/main/java/org/apache/sysds/runtime/compress/bitmap/BitmapEncoder.java @@ -73,7 +73,8 @@ public static ABitmap extractBitmap(IColIndex colIndices, MatrixBlock rawBlock, final int numRows = transposed ? 
rawBlock.getNumColumns() : rawBlock.getNumRows(); final int estimatedNumber = Math.max(estimatedNumberOfUniques, 8); if(colIndices.size() == 1) - return extractBitmapSingleColumn(colIndices.get(0), rawBlock, numRows, transposed, estimatedNumber, sortedEntries); + return extractBitmapSingleColumn(colIndices.get(0), rawBlock, numRows, transposed, estimatedNumber, + sortedEntries); else return extractBitmapMultiColumns(colIndices, rawBlock, numRows, transposed, estimatedNumber, sortedEntries); } @@ -165,10 +166,18 @@ private static ABitmap extractBitmapMultiColumns(IColIndex colIndices, MatrixBlo boolean transposed, int estimatedUnique, boolean sort) { final DblArrayIntListHashMap map = new DblArrayIntListHashMap(estimatedUnique); final ReaderColumnSelection reader = ReaderColumnSelection.createReader(rawBlock, colIndices, transposed); - DblArray cellVals = null; - while((cellVals = reader.nextRow()) != null) - map.appendValue(cellVals, reader.getCurrentRowIndex()); + try { + DblArray empty = new DblArray(new double[colIndices.size()]); + while((cellVals = reader.nextRow()) != null) { + if(!cellVals.equals(empty)) + map.appendValue(cellVals, reader.getCurrentRowIndex()); + } + + } + catch(Exception e) { + throw new RuntimeException("failed extracting bitmap and adding. 
" + map + " \n " + cellVals, e); + } return makeMultiColBitmap(map, numRows, colIndices.size(), sort); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java index e2c2252018b..f8fe0287542 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java +++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java @@ -32,6 +32,7 @@ import org.apache.sysds.runtime.compress.estim.AComEst; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; +import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysds.runtime.util.CommonThreadPool; public class CoCodeGreedy extends AColumnCoCoder { @@ -43,8 +44,7 @@ protected CoCodeGreedy(AComEst sizeEstimator, ACostEstimate costEstimator, Compr mem = new Memorizer(sizeEstimator); } - protected CoCodeGreedy(AComEst sizeEstimator, ACostEstimate costEstimator, CompressionSettings cs, - Memorizer mem) { + protected CoCodeGreedy(AComEst sizeEstimator, ACostEstimate costEstimator, CompressionSettings cs, Memorizer mem) { super(sizeEstimator, costEstimator, cs); this.mem = mem; } @@ -64,7 +64,7 @@ protected List combine(List coCodeBruteForce(List inputColumns, int k) { final List workSet = new ArrayList<>(inputColumns.size()); - + k = k <= 0 ? 
InfrastructureAnalyzer.getLocalParallelism() : k; final ExecutorService pool = CommonThreadPool.get(k); for(int i = 0; i < inputColumns.size(); i++) { CompressedSizeInfoColGroup g = inputColumns.get(i); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java index 471fbbcc135..a4030d95612 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java @@ -178,7 +178,9 @@ public final void decompressToDenseBlock(DenseBlock db, int rl, int ru) { * @throws IOException if IOException occurs */ protected void write(DataOutput out) throws IOException { - out.writeByte(getColGroupType().ordinal()); + final byte[] o = new byte[1]; + o[0] = (byte) getColGroupType().ordinal(); + out.write(o); _colIndexes.write(out); } @@ -622,10 +624,12 @@ public AColGroup addVector(double[] v) { * If it is not possible or very inefficient null is returned. * * @param groups The groups to combine. + * @param blen The normal number of rows in the groups + * @param rlen The total number of rows of all combined. * @return A combined column group or null */ - public static AColGroup appendN(AColGroup[] groups) { - return groups[0].appendNInternal(groups); + public static AColGroup appendN(AColGroup[] groups, int blen, int rlen) { + return groups[0].appendNInternal(groups, blen, rlen); } /** @@ -636,9 +640,11 @@ public static AColGroup appendN(AColGroup[] groups) { * If it is not possible or very inefficient null is returned. * * @param groups The groups to combine. + * @param blen The normal number of rows in the groups + * @param rlen The total number of rows of all combined. 
* @return A combined column group or null */ - protected abstract AColGroup appendNInternal(AColGroup[] groups); + protected abstract AColGroup appendNInternal(AColGroup[] groups, int blen, int rlen); /** * Get the compression scheme for this column group to enable compression of other data. @@ -713,7 +719,7 @@ public AColGroup sortColumnIndexes() { @Override public String toString() { StringBuilder sb = new StringBuilder(); - sb.append(String.format("%15s", "Type: ")); + sb.append(String.format("\n%15s", "Type: ")); sb.append(this.getClass().getSimpleName()); sb.append(String.format("\n%15s", "Columns: ")); sb.append(_colIndexes); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java index 493a2a71192..0ef7a423503 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java @@ -22,6 +22,7 @@ import java.io.DataInput; import java.io.IOException; +import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; @@ -30,6 +31,10 @@ import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; +import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory; +import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; +import org.apache.sysds.runtime.compress.colgroup.offset.OffsetEmpty; import org.apache.sysds.runtime.compress.colgroup.scheme.ConstScheme; import 
org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; @@ -48,7 +53,7 @@ import org.apache.sysds.runtime.matrix.operators.ScalarOperator; import org.apache.sysds.runtime.matrix.operators.UnaryOperator; -public class ColGroupConst extends ADictBasedColGroup implements IContainDefaultTuple { +public class ColGroupConst extends ADictBasedColGroup implements IContainDefaultTuple, AOffsetsGroup, IMapToDataGroup { private static final long serialVersionUID = -7387793538322386611L; @@ -566,10 +571,22 @@ public AColGroup append(AColGroup g) { } @Override - public AColGroup appendNInternal(AColGroup[] g) { - for(int i = 0; i < g.length; i++) - if(!_colIndexes.equals(g[i]._colIndexes) || !this._dict.equals(((ColGroupConst) g[i])._dict)) - return null; + public AColGroup appendNInternal(AColGroup[] g, int blen, int rlen) { + for(int i = 0; i < g.length; i++) { + final AColGroup gs = g[i]; + if(!_colIndexes.equals(gs._colIndexes)) + throw new DMLCompressionException("Invalid columns not matching " + gs._colIndexes + " " + _colIndexes); + if(gs instanceof ColGroupConst) { + if(this._dict.equals(((ColGroupConst) gs)._dict)) + continue; // common case + else + throw new NotImplementedException("Appending const not equivalent"); + } + else if(gs instanceof ColGroupEmpty) + throw new NotImplementedException("Appending empty and const"); + else + return gs.appendNInternal(g, blen, rlen); + } return this; } @@ -614,9 +631,20 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) { public String toString() { StringBuilder sb = new StringBuilder(); sb.append(super.toString()); - sb.append(String.format("\n%15s", "Values: " + _dict.getClass().getSimpleName())); + sb.append(String.format("\n%15s", "Values: ")); + sb.append(_dict.getClass().getSimpleName()); sb.append(_dict.getString(_colIndexes.size())); return sb.toString(); } + + @Override + public AOffset getOffsets() { + return new 
OffsetEmpty(); + } + + @Override + public AMapToData getMapToData() { + return MapToFactory.create(0, 0); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java index 26e33178c29..8f5fccaf7d6 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java @@ -544,7 +544,7 @@ public AColGroup append(AColGroup g) { } @Override - public AColGroup appendNInternal(AColGroup[] g) { + public AColGroup appendNInternal(AColGroup[] g, int blen, int rlen) { for(int i = 1; i < g.length; i++) { if(!_colIndexes.equals(g[i]._colIndexes)) { LOG.warn("Not same columns therefore not appending DDC\n" + _colIndexes + "\n\n" + g[i]._colIndexes); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java index de361123f38..85f48ae5b6f 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java @@ -452,8 +452,8 @@ public AColGroup append(AColGroup g) { } @Override - public AColGroup appendNInternal(AColGroup[] g) { - return null; + public AColGroup appendNInternal(AColGroup[] g, int blen, int rlen) { + throw new NotImplementedException(); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java index bb4a63afbbe..a8d8e6840e3 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java @@ -24,11 +24,16 @@ import java.util.Arrays; import org.apache.sysds.runtime.DMLRuntimeException; +import 
org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.IIterate; +import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; +import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory; +import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; +import org.apache.sysds.runtime.compress.colgroup.offset.OffsetEmpty; import org.apache.sysds.runtime.compress.colgroup.scheme.EmptyScheme; import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; @@ -47,7 +52,8 @@ import org.apache.sysds.runtime.matrix.operators.ScalarOperator; import org.apache.sysds.runtime.matrix.operators.UnaryOperator; -public class ColGroupEmpty extends AColGroupCompressed implements IContainADictionary, IContainDefaultTuple { +public class ColGroupEmpty extends AColGroupCompressed + implements IContainADictionary, IContainDefaultTuple, AOffsetsGroup ,IMapToDataGroup{ private static final long serialVersionUID = -2307677253622099958L; /** @@ -332,10 +338,16 @@ public AColGroup append(AColGroup g) { } @Override - public AColGroup appendNInternal(AColGroup[] g) { - for(int i = 0; i < g.length; i++) - if(!_colIndexes.equals(g[i]._colIndexes)) - return null; + public AColGroup appendNInternal(AColGroup[] g, int blen, int rlen) { + for(int i = 0; i < g.length; i++) { + final AColGroup gs = g[i]; + if(!_colIndexes.equals(gs._colIndexes)) + throw new DMLCompressionException("Invalid columns not matching " + gs._colIndexes + " " + _colIndexes); + if(gs instanceof ColGroupEmpty) + continue; + else + return gs.appendNInternal(g, blen, rlen); + } 
return this; } @@ -381,4 +393,14 @@ protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) { return new ColGroupEmpty(newColIndex); } + @Override + public AOffset getOffsets() { + return new OffsetEmpty(); + } + + @Override + public AMapToData getMapToData() { + return MapToFactory.create(0, 0); + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java index 72934b19963..23ba7d6fc4c 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java @@ -59,6 +59,7 @@ import org.apache.sysds.runtime.compress.utils.DblArrayCountHashMap; import org.apache.sysds.runtime.compress.utils.DoubleCountHashMap; import org.apache.sysds.runtime.compress.utils.IntArrayList; +import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; @@ -225,12 +226,12 @@ private void logEstVsActual(double time, AColGroup act, CompressedSizeInfoColGro if(estC < actC * 0.75) { String warning = "The estimate cost is significantly off : " + est; LOG.debug( - String.format("time[ms]: %10.2f %25s est %10.0f -- act %10.0f distinct:%5d cols:%s wanted:%s\n\t\t%s", - time, retType, estC, actC, act.getNumValues(), cols, wanted, warning)); + String.format("time[ms]: %10.2f %25s est %10.0f -- act %10.0f distinct:%5d cols:%s wanted:%s\n\t\t%s", time, + retType, estC, actC, act.getNumValues(), cols, wanted, warning)); } else { - LOG.debug(String.format("time[ms]: %10.2f %25s est %10.0f -- act %10.0f distinct:%5d cols:%s wanted:%s", - time, retType, estC, actC, act.getNumValues(), cols, wanted)); + LOG.debug(String.format("time[ms]: %10.2f %25s est %10.0f -- act %10.0f distinct:%5d cols:%s wanted:%s", time, + 
retType, estC, actC, act.getNumValues(), cols, wanted)); } } @@ -277,6 +278,9 @@ else if(ct == CompressionType.DDCFOR) { return ColGroupDDCFOR.sparsifyFOR((ColGroupDDC) g); return g; } + else if(ct == CompressionType.SDC && colIndexes.size() == 1 && !t) { + return compressSDCSingleColDirectBlock(colIndexes, cg.getNumVals()); + } final ABitmap ubm = BitmapEncoder.extractBitmap(colIndexes, in, cg.getNumVals(), cs); if(ubm == null) // no values ... therefore empty @@ -308,6 +312,143 @@ else if(ct == CompressionType.DDCFOR) { } } + private AColGroup compressSDCSingleColDirectBlock(IColIndex colIndexes, int nVal) { + final DoubleCountHashMap cMap = new DoubleCountHashMap(nVal); + final int col = colIndexes.get(0); + + countElements(cMap, col); + + double def = cMap.getMostFrequent(); + final int dictSize = cMap.size() - 1; + if(dictSize == 0) + return ColGroupConst.create(colIndexes, def); + + int defCount = cMap.getC(def).count; + cMap.replaceWithUIDs(def); + IDictionary dict = Dictionary.create(cMap.getDictionary(dictSize)); + IntArrayList offs = new IntArrayList(nRow - defCount); + AMapToData map = MapToFactory.create(nRow - defCount, dictSize); + getOffsets(offs, map, cMap, col, def); + + AOffset aoff = OffsetFactory.createOffset(offs); + + return ColGroupSDC.create(colIndexes, nRow, dict, new double[] {def}, aoff, map, null); + + } + + private void getOffsets(IntArrayList offs, AMapToData map, DoubleCountHashMap cMap, int col, double def) { + + if(in.isInSparseFormat()) { + if(def == 0) { + final SparseBlock sb = in.getSparseBlock(); + for(int r = 0; r < nRow; r++) { + if(sb.isEmpty(r)) + continue; + + final int apos = sb.pos(r); + final int alen = sb.size(r) + apos; + final int[] aix = sb.indexes(r); + final int idx = Arrays.binarySearch(aix, apos, alen, col); + if(!(idx < 0)) { + map.set(offs.size(), cMap.getId(sb.values(r)[idx])); + offs.appendValue(r); + } + } + } + else { + + final SparseBlock sb = in.getSparseBlock(); + for(int r = 0; r < nRow; r++) { + 
if(sb.isEmpty(r)) { + + map.set(offs.size(), cMap.getId(0.0)); + offs.appendValue(r); + } + else { + final int apos = sb.pos(r); + final int alen = sb.size(r) + apos; + final int[] aix = sb.indexes(r); + final int idx = Arrays.binarySearch(aix, apos, alen, col); + if(idx < 0) { + map.set(offs.size(), cMap.getId(0.0)); + offs.appendValue(r); + } + else { + double v = sb.values(r)[idx]; + if(!Util.eq(sb.values(r)[idx], def)) { + map.set(offs.size(), cMap.getId(v)); + offs.appendValue(r); + } + } + } + } + } + } + else if(in.getDenseBlock().isContiguous()) { + final double[] dv = in.getDenseBlockValues(); + int off = col; + for(int r = 0; r < nRow; r++, off += nCol) + if(!Util.eq(dv[off], def)) { + map.set(offs.size(), cMap.getId(dv[off])); + offs.appendValue(r); + } + } + else { + final DenseBlock db = in.getDenseBlock(); + for(int r = 0; r < nRow; r++) { + final double[] dv = db.values(r); + int off = db.pos(r) + col; + if(!Util.eq(dv[off], def)) { + map.set(offs.size(), cMap.getId(dv[off])); + offs.appendValue(r); + } + } + } + } + + private void countElements(DoubleCountHashMap map, int col) { + if(in.isInSparseFormat()) + countElementsSparse(map, col); + else if(in.getDenseBlock().isContiguous()) + countElementsDenseContiguous(map, col); + else + countElementsDenseGeneric(map, col); + } + + private void countElementsSparse(DoubleCountHashMap map, int col) { + final SparseBlock sb = in.getSparseBlock(); + for(int r = 0; r < nRow; r++) { + if(sb.isEmpty(r)) + map.increment(0.0); + else { + final int apos = sb.pos(r); + final int alen = sb.size(r) + apos; + final int[] aix = sb.indexes(r); + final int idx = Arrays.binarySearch(aix, apos, alen, col); + if(idx < 0) + map.increment(0.0); + else + map.increment(sb.values(r)[idx]); + } + } + } + + private void countElementsDenseContiguous(DoubleCountHashMap map, int col) { + final double[] dv = in.getDenseBlockValues(); + int off = col; + for(int r = 0; r < nRow; r++, off += nCol) + map.increment(dv[off]); + } + + 
private void countElementsDenseGeneric(DoubleCountHashMap map, int col) { + final DenseBlock db = in.getDenseBlock(); + for(int r = 0; r < nRow; r++) { + final double[] dv = db.values(r); + int off = db.pos(r) + col; + map.increment(dv[off]); + } + } + private AColGroup directCompressDDC(IColIndex colIndexes, CompressedSizeInfoColGroup cg) throws Exception { if(colIndexes.size() > 1) return directCompressDDCMultiCol(colIndexes, cg); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java index e55baf5718f..708d3512f53 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java @@ -674,8 +674,8 @@ public AColGroup append(AColGroup g) { } @Override - public AColGroup appendNInternal(AColGroup[] g) { - return null; + public AColGroup appendNInternal(AColGroup[] g, int blen, int rlen) { + throw new NotImplementedException(); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java index 631da3edd17..8af0f959e0c 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java @@ -665,8 +665,8 @@ public AColGroup append(AColGroup g) { } @Override - public AColGroup appendNInternal(AColGroup[] g) { - return null; + public AColGroup appendNInternal(AColGroup[] g, int blen, int rlen) { + throw new NotImplementedException(); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java index d1c62387f1f..23596c1e190 100644 --- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java @@ -977,8 +977,8 @@ public AColGroup append(AColGroup g) { } @Override - public AColGroup appendNInternal(AColGroup[] g) { - return null; + public AColGroup appendNInternal(AColGroup[] g, int blen, int rlen) { + throw new NotImplementedException(); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java index 91e64468b9c..476e86c9730 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java @@ -26,10 +26,11 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.DMLCompressionException; -import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.PlaceHolderDict; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; @@ -67,6 +68,8 @@ public class ColGroupSDC extends ASDC implements IMapToDataGroup { protected ColGroupSDC(IColIndex colIndices, int numRows, IDictionary dict, double[] defaultTuple, AOffset offsets, AMapToData data, int[] cachedCounts) { super(colIndices, numRows, dict, offsets, cachedCounts); + _data = data; + _defaultTuple = defaultTuple; if(data.getUnique() != 
dict.getNumberOfValues(colIndices.size())) { if(data.getUnique() != data.getMax()) throw new DMLCompressionException( @@ -77,8 +80,6 @@ protected ColGroupSDC(IColIndex colIndices, int numRows, IDictionary dict, doubl if(defaultTuple.length != colIndices.size()) throw new DMLCompressionException("Invalid construction of SDC group"); - _data = data; - _defaultTuple = defaultTuple; } public static AColGroup create(IColIndex colIndices, int numRows, IDictionary dict, double[] defaultTuple, @@ -596,30 +597,33 @@ public AColGroup append(AColGroup g) { } @Override - public AColGroup appendNInternal(AColGroup[] g) { - int sumRows = getNumRows(); + public AColGroup appendNInternal(AColGroup[] g, int blen, int rlen) { for(int i = 1; i < g.length; i++) { - if(!_colIndexes.equals(g[i]._colIndexes)) { - LOG.warn("Not same columns therefore not appending \n" + _colIndexes + "\n\n" + g[i]._colIndexes); - return null; - } + final AColGroup gs = g[i]; + if(!_colIndexes.equals(gs._colIndexes)) + throw new DMLCompressionException( + "Not same columns therefore not appending \n" + _colIndexes + "\n\n" + gs._colIndexes); - if(!(g[i] instanceof ColGroupSDC)) { - LOG.warn("Not SDC but " + g[i].getClass().getSimpleName()); - return null; - } + if(!(gs instanceof AOffsetsGroup)) + throw new DMLCompressionException("Not SDC but " + gs.getClass().getSimpleName()); - final ColGroupSDC gc = (ColGroupSDC) g[i]; - if(!gc._dict.equals(_dict)) { - LOG.warn("Not same Dictionaries therefore not appending \n" + _dict + "\n\n" + gc._dict); - return null; + if(gs instanceof ColGroupSDC) { + final ColGroupSDC gc = (ColGroupSDC) gs; + if(!gc._dict.equals(_dict)) + throw new DMLCompressionException( + "Not same Dictionaries therefore not appending \n" + _dict + "\n\n" + gc._dict); + } + else if(gs instanceof ColGroupConst) { + final ColGroupConst gc = (ColGroupConst) gs; + if(!(gc._dict instanceof PlaceHolderDict) && gc._dict.equals(_defaultTuple)) + throw new DMLCompressionException("Not same default 
values therefore not appending:\n" + gc._dict + + "\n\n" + Arrays.toString(_defaultTuple)); } - sumRows += gc.getNumRows(); } AMapToData nd = _data.appendN(Arrays.copyOf(g, g.length, IMapToDataGroup[].class)); AOffset no = _indexes.appendN(Arrays.copyOf(g, g.length, AOffsetsGroup[].class), getNumRows()); - return create(_colIndexes, sumRows, _dict, _defaultTuple, no, nd, null); + return create(_colIndexes, rlen, _dict, _defaultTuple, no, nd, null); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java index 4c46f907efc..65a5a42aa40 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java @@ -461,8 +461,7 @@ public AColGroup append(AColGroup g) { } @Override - public AColGroup appendNInternal(AColGroup[] g) { - int sumRows = getNumRows(); + public AColGroup appendNInternal(AColGroup[] g, int blen, int rlen) { for(int i = 1; i < g.length; i++) { if(!_colIndexes.equals(g[i]._colIndexes)) { LOG.warn("Not same columns therefore not appending \n" + _colIndexes + "\n\n" + g[i]._colIndexes); @@ -479,11 +478,10 @@ public AColGroup appendNInternal(AColGroup[] g) { LOG.warn("Not same Dictionaries therefore not appending \n" + _dict + "\n\n" + gc._dict); return null; } - sumRows += gc.getNumRows(); } AMapToData nd = _data.appendN(Arrays.copyOf(g, g.length, IMapToDataGroup[].class)); AOffset no = _indexes.appendN(Arrays.copyOf(g, g.length, AOffsetsGroup[].class), getNumRows()); - return create(_colIndexes, sumRows, _dict, no, nd, null, _reference); + return create(_colIndexes, rlen, _dict, no, nd, null, _reference); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java index f2aa3167fc9..5d3be0d3f11 100644 
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java @@ -27,6 +27,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.PlaceHolderDict; +import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; @@ -556,28 +557,31 @@ public AColGroup append(AColGroup g) { } @Override - public AColGroup appendNInternal(AColGroup[] g) { - int sumRows = getNumRows(); + public AColGroup appendNInternal(AColGroup[] g, int blen, int rlen) { for(int i = 1; i < g.length; i++) { - if(!_colIndexes.equals(g[i]._colIndexes)) { - LOG.warn("Not same columns therefore not appending \n" + _colIndexes + "\n\n" + g[i]._colIndexes); - return null; + final AColGroup gs = g[i]; + if(!_colIndexes.equals(gs._colIndexes)) + throw new DMLCompressionException( + "Not same columns therefore not appending \n" + _colIndexes + "\n\n" + gs._colIndexes); + + if(!(gs instanceof AOffsetsGroup)) + throw new DMLCompressionException("Not SDC but " + gs.getClass().getSimpleName()); + + if(gs instanceof ColGroupSDC) { + final ColGroupSDC gc = (ColGroupSDC) gs; + if(!gc._dict.equals(_dict)) + throw new DMLCompressionException( + "Not same Dictionaries therefore not appending \n" + _dict + "\n\n" + gc._dict); } - - if(!(g[i] instanceof ColGroupSDCSingle)) { - LOG.warn("Not SDCFOR but " + g[i].getClass().getSimpleName()); - return null; - } - - final ColGroupSDCSingle gc = (ColGroupSDCSingle) g[i]; - if(!gc._dict.equals(_dict)) { - LOG.warn("Not same Dictionaries therefore not appending \n" + _dict + "\n\n" + gc._dict); - return null; + 
else if(gs instanceof ColGroupConst) { + final ColGroupConst gc = (ColGroupConst) gs; + if(!(gc._dict instanceof PlaceHolderDict) && gc._dict.equals(_defaultTuple)) + throw new DMLCompressionException("Not same default values therefore not appending:\n" + gc._dict + + "\n\n" + Arrays.toString(_defaultTuple)); } - sumRows += gc.getNumRows(); } AOffset no = _indexes.appendN(Arrays.copyOf(g, g.length, AOffsetsGroup[].class), getNumRows()); - return create(_colIndexes, sumRows, _dict, _defaultTuple, no, null); + return create(_colIndexes, rlen, _dict, _defaultTuple, no, null); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java index 7da660e6cf6..7a3309aafa0 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java @@ -826,6 +826,10 @@ public AColGroup sliceRows(int rl, int ru) { OffsetSliceInfo off = _indexes.slice(rl, ru); if(off.lIndex == -1) return null; + if(CompressedMatrixBlock.debug){ + if(off.offsetSlice.getOffsetToFirst() < 0 || off.offsetSlice.getOffsetToLast() > ru-rl) + throw new DMLCompressionException("Failed to slice : " + rl + " " + ru + " in: " + this); + } return create(_colIndexes, ru - rl, _dict, off.offsetSlice, null); } @@ -840,28 +844,30 @@ public AColGroup append(AColGroup g) { } @Override - public AColGroup appendNInternal(AColGroup[] g) { - int sumRows = getNumRows(); + public AColGroup appendNInternal(AColGroup[] g, int blen, int rlen) { + for(int i = 1; i < g.length; i++) { - if(!_colIndexes.equals(g[i]._colIndexes)) { - LOG.warn("Not same columns therefore not appending \n" + _colIndexes + "\n\n" + g[i]._colIndexes); + final AColGroup gs = g[i]; + if(!_colIndexes.equals(gs._colIndexes)) { + LOG.warn("Not same columns therefore not appending \n" + _colIndexes + "\n\n" + 
gs._colIndexes); return null; } - if(!(g[i] instanceof ColGroupSDCSingleZeros)) { - LOG.warn("Not SDCFOR but " + g[i].getClass().getSimpleName()); + if(!(gs instanceof AOffsetsGroup )) { + LOG.warn("Not SDCFOR but " + gs.getClass().getSimpleName()); return null; } - final ColGroupSDCSingleZeros gc = (ColGroupSDCSingleZeros) g[i]; - if(!gc._dict.equals(_dict)) { - LOG.warn("Not same Dictionaries therefore not appending \n" + _dict + "\n\n" + gc._dict); - return null; + if( gs instanceof ColGroupSDCSingleZeros){ + final ColGroupSDCSingleZeros gc = (ColGroupSDCSingleZeros) gs; + if(!gc._dict.equals(_dict)) { + LOG.warn("Not same Dictionaries therefore not appending \n" + _dict + "\n\n" + gc._dict); + return null; + } } - sumRows += gc.getNumRows(); } - AOffset no = _indexes.appendN(Arrays.copyOf(g, g.length, AOffsetsGroup[].class), getNumRows()); - return create(_colIndexes, sumRows, _dict, no, null); + AOffset no = _indexes.appendN(Arrays.copyOf(g, g.length, AOffsetsGroup[].class), blen); + return create(_colIndexes, rlen, _dict, no, null); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java index c26c39c0b85..f4c7c6f615b 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java @@ -766,30 +766,32 @@ public AColGroup append(AColGroup g) { } @Override - public AColGroup appendNInternal(AColGroup[] g) { - int sumRows = getNumRows(); + public AColGroup appendNInternal(AColGroup[] g, int blen, int rlen) { + for(int i = 1; i < g.length; i++) { - if(!_colIndexes.equals(g[i]._colIndexes)) { + final AColGroup gs = g[i]; + if(!_colIndexes.equals(gs._colIndexes)) { LOG.warn("Not same columns therefore not appending \n" + _colIndexes + "\n\n" + g[i]._colIndexes); return null; } - if(!(g[i] instanceof ColGroupSDCZeros)) { - 
LOG.warn("Not SDCFOR but " + g[i].getClass().getSimpleName()); + if(!(gs instanceof AOffsetsGroup )) { + LOG.warn("Not valid OffsetGroup but " + gs.getClass().getSimpleName()); return null; } - final ColGroupSDCZeros gc = (ColGroupSDCZeros) g[i]; - if(!gc._dict.equals(_dict)) { - LOG.warn("Not same Dictionaries therefore not appending \n" + _dict + "\n\n" + gc._dict); - return null; + if( gs instanceof ColGroupSDCZeros){ + final ColGroupSDCZeros gc = (ColGroupSDCZeros) gs; + if(!gc._dict.equals(_dict)) { + LOG.warn("Not same Dictionaries therefore not appending \n" + _dict + "\n\n" + gc._dict); + return null; + } } - sumRows += gc.getNumRows(); } AMapToData nd = _data.appendN(Arrays.copyOf(g, g.length, IMapToDataGroup[].class)); - AOffset no = _indexes.appendN(Arrays.copyOf(g, g.length, AOffsetsGroup[].class), getNumRows()); + AOffset no = _indexes.appendN(Arrays.copyOf(g, g.length, AOffsetsGroup[].class), blen); - return create(_colIndexes, sumRows, _dict, no, nd, null); + return create(_colIndexes, rlen, _dict, no, nd, null); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java index caba3c3a2e7..17e954b2197 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java @@ -73,8 +73,8 @@ public class ColGroupUncompressed extends AColGroup { private static final long serialVersionUID = -8254271148043292199L; /** - * We store the contents of the columns as a MatrixBlock to take advantage of high-performance routines available - * for this data structure. + * We store the contents of the columns as a MatrixBlock to take advantage of high-performance routines available for + * this data structure. 
*/ private final MatrixBlock _data; @@ -216,8 +216,10 @@ private void decompressToDenseBlockDenseDataAllColumns(DenseBlock db, int rl, in final int offS = tb.pos(row); final double[] c = db.values(offT); final int off = db.pos(offT); - for(int j = 0; j < nCol; j++) - c[off + j] += values[offS + j]; + for(int j = 0; j < nCol; j++) { + double v = values[offS + j]; + c[off + j] += v; + } } } @@ -558,8 +560,7 @@ public void leftMultByAColGroup(AColGroup lhs, MatrixBlock result, int nRows) { else if(lhs instanceof APreAgg) leftMultByAPreAggColGroup((APreAgg) lhs, result); else - throw new DMLCompressionException( - "Not supported leftMult colgroup type: " + lhs.getClass().getSimpleName()); + throw new DMLCompressionException("Not supported leftMult colgroup type: " + lhs.getClass().getSimpleName()); } private void leftMultByAPreAggColGroup(APreAgg paCG, MatrixBlock result) { @@ -678,8 +679,8 @@ private void tsmmUncompressedColGroup(ColGroupUncompressed lhs, MatrixBlock resu final int[] aix = sb.indexes(row); final double[] avals = sb.values(row); for(int col = apos; col < alen; col++) - DictLibMatrixMult.addToUpperTriangle(nCols, lhs._colIndexes.get(row), _colIndexes.get(aix[col]), - resV, avals[col]); + DictLibMatrixMult.addToUpperTriangle(nCols, lhs._colIndexes.get(row), _colIndexes.get(aix[col]), resV, + avals[col]); } } else { @@ -816,8 +817,23 @@ public AColGroup append(AColGroup g) { } @Override - public AColGroup appendNInternal(AColGroup[] g) { - return null; + public AColGroup appendNInternal(AColGroup[] g, int blen, int rlen) { + final MatrixBlock ret = new MatrixBlock(rlen, _colIndexes.size(), _data.isInSparseFormat()); + ret.allocateBlock(); + final SparseBlock sb = ret.getSparseBlock(); + final DenseBlock db = ret.getDenseBlock(); + final IColIndex target = ColIndexFactory.create(_colIndexes.size()); + for(int i = 0; i < g.length; i++) { + final int start = i * blen; + final int end = Math.min(i * blen + blen, rlen); + final AColGroup gs = g[i]; + 
if(_data.isInSparseFormat()) + gs.copyAndSet(target).decompressToSparseBlock(sb, 0, end - start, start, 0); + else + gs.copyAndSet(target).decompressToDenseBlock(db, 0, end - start, start, 0); + } + ret.recomputeNonZeros(); + return new ColGroupUncompressed(ret, _colIndexes); } @Override @@ -855,7 +871,7 @@ public CompressedSizeInfoColGroup getCompressionInfo(int nRow) { @Override public AColGroup copyAndSet(IColIndex colIndexes) { - return ColGroupUncompressed.create(_data, colIndexes); + return new ColGroupUncompressed(_data, colIndexes); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java index 9859455307b..67f546c6ac5 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java @@ -52,6 +52,11 @@ public final boolean equals(Object o) { return false; } + @Override + public final boolean equals(double[] v) { + return equals(new Dictionary(v)); + } + /** * Make a double into a string, if the double is a whole number then return it without decimal points * diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java index 12a703627aa..983dc84b507 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java @@ -1079,15 +1079,16 @@ public void TSMMToUpperTriangleDenseScaling(double[] left, IColIndex rowsLeft, I } @Override - public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, - int[] scale, MatrixBlock result) { + public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLeft, IColIndex 
colsRight, int[] scale, + MatrixBlock result) { DictLibMatrixMult.TSMMToUpperTriangleSparseDenseScaling(left, _values, rowsLeft, colsRight, scale, result); } @Override public boolean equals(IDictionary o) { - if(o instanceof Dictionary) + if(o instanceof Dictionary) { return Arrays.equals(_values, ((Dictionary) o)._values); + } else if(o instanceof MatrixBlockDictionary) { final MatrixBlock mb = ((MatrixBlockDictionary) o).getMatrixBlock(); if(mb.isInSparseFormat()) @@ -1099,7 +1100,7 @@ else if(o instanceof MatrixBlockDictionary) { } @Override - public IDictionary cbind(IDictionary that, int nCol){ + public IDictionary cbind(IDictionary that, int nCol) { int nRowThat = that.getNumberOfValues(nCol); int nColThis = _values.length / nRowThat; MatrixBlockDictionary mbd = getMBDict(nColThis); @@ -1107,12 +1108,12 @@ public IDictionary cbind(IDictionary that, int nCol){ } @Override - public IDictionary reorder(int[] reorder){ + public IDictionary reorder(int[] reorder) { double[] retV = new double[_values.length]; Dictionary ret = new Dictionary(retV); int nRows = _values.length / reorder.length; - for(int r = 0; r < nRows; r++){ + for(int r = 0; r < nRows; r++) { int off = r * reorder.length; for(int c = 0; c < reorder.length; c++) retV[off + c] = _values[off + reorder[c]]; diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java index d4843fc6d76..2f3d435673a 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java @@ -75,8 +75,8 @@ public static enum DictType { public long getInMemorySize(); /** - * Aggregate all the contained values, useful in value only computations where the operation is iterating through - * all values contained in the dictionary. 
+ * Aggregate all the contained values, useful in value only computations where the operation is iterating through all + * values contained in the dictionary. * * @param init The initial Value, in cases such as Max value, this could be -infinity * @param fn The Function to apply to values @@ -572,8 +572,7 @@ public void addToEntryVectorized(double[] v, int f1, int f2, int f3, int f4, int /** * Allocate a new dictionary where the tuple given is subtracted from all tuples in the previous dictionary. * - * @param tuple a double list representing a tuple, it is given that the tuple with is the same as this - * dictionaries. + * @param tuple a double list representing a tuple, it is given that the tuple with is the same as this dictionaries. * @return a new instance of dictionary with the tuple subtracted. */ public IDictionary subtractTuple(double[] tuple); @@ -788,8 +787,8 @@ public CM_COV_Object centralMomentWithReference(CM_COV_Object ret, ValueFunction public double getSparsity(); /** - * Multiply the v value with the dictionary entry at dictIdx and add it to the ret matrix at the columns specified - * in the int array. + * Multiply the v value with the dictionary entry at dictIdx and add it to the ret matrix at the columns specified in + * the int array. 
* * @param v Value to multiply * @param ret Output dense double array location @@ -869,8 +868,7 @@ public CM_COV_Object centralMomentWithReference(CM_COV_Object ret, ValueFunction * @param colsRight Offset cols on the right * @param result The output matrix block */ - public void TSMMToUpperTriangleSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, - MatrixBlock result); + public void TSMMToUpperTriangleSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result); /** * Matrix multiplication but allocate output in upper triangle and twice if on diagonal, note this is left @@ -925,6 +923,14 @@ public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLef */ public boolean equals(Object o); + /** + * Indicate if this object is equal to the given array of doubles. + * + * @param v The list of double values + * @return If they are equal to this. + */ + public boolean equals(double[] v); + /** * Indicate if the other dictionary is equal to this. 
* diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java index 067396ffbd9..94fa9ef5289 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java @@ -29,6 +29,7 @@ import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.ValueFunction; import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; +import org.apache.sysds.runtime.io.IOUtilFunctions; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.matrix.operators.ScalarOperator; @@ -36,467 +37,474 @@ public class PlaceHolderDict implements IDictionary, Serializable { - private static final long serialVersionUID = 9176356558592L; - - private static final String errMessage = "PlaceHolderDict does not support Operations, and is purely intended for serialization"; - - /** The number of values supposed to be contained in this dictionary */ - private final int nVal; - - public PlaceHolderDict(int nVal) { - this.nVal = nVal; - } - - @Override - public double[] getValues() { - throw new RuntimeException(errMessage); - } - - @Override - public double getValue(int i) { - throw new RuntimeException(errMessage); - } - - @Override - public double getValue(int r, int col, int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public long getInMemorySize() { - return 16 + 4; - } - - @Override - public double aggregate(double init, Builtin fn) { - throw new RuntimeException(errMessage); - } - - @Override - public double aggregateWithReference(double init, Builtin fn, double[] reference, boolean def) { - throw new RuntimeException(errMessage); - } - - @Override - public double[] 
aggregateRows(Builtin fn, int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public double[] aggregateRowsWithDefault(Builtin fn, double[] defaultTuple) { - throw new RuntimeException(errMessage); - } - - @Override - public double[] aggregateRowsWithReference(Builtin fn, double[] reference) { - throw new RuntimeException(errMessage); - } - - @Override - public void aggregateCols(double[] c, Builtin fn, IColIndex colIndexes) { - throw new RuntimeException(errMessage); - } - - @Override - public void aggregateColsWithReference(double[] c, Builtin fn, IColIndex colIndexes, double[] reference, - boolean def) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary applyScalarOp(ScalarOperator op) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary applyScalarOpAndAppend(ScalarOperator op, double v0, int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary applyUnaryOp(UnaryOperator op) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary applyUnaryOpAndAppend(UnaryOperator op, double v0, int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary applyScalarOpWithReference(ScalarOperator op, double[] reference, double[] newReference) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary applyUnaryOpWithReference(UnaryOperator op, double[] reference, double[] newReference) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary binOpLeft(BinaryOperator op, double[] v, IColIndex colIndexes) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary binOpLeftAndAppend(BinaryOperator op, double[] v, IColIndex colIndexes) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary binOpLeftWithReference(BinaryOperator op, double[] v, IColIndex colIndexes, double[] reference, - double[] 
newReference) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary binOpRight(BinaryOperator op, double[] v, IColIndex colIndexes) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary binOpRightAndAppend(BinaryOperator op, double[] v, IColIndex colIndexes) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary binOpRight(BinaryOperator op, double[] v) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary binOpRightWithReference(BinaryOperator op, double[] v, IColIndex colIndexes, double[] reference, - double[] newReference) { - throw new RuntimeException(errMessage); - } - - @Override - public void write(DataOutput out) throws IOException { - out.writeByte(DictionaryFactory.Type.PLACE_HOLDER.ordinal()); - out.writeInt(nVal); - } - - public static PlaceHolderDict read(DataInput in) throws IOException { - int nVals = in.readInt(); - return new PlaceHolderDict(nVals); - } - - @Override - public long getExactSizeOnDisk() { - return 1 + 4; - } - - @Override - public DictType getDictType() { - throw new RuntimeException(errMessage); - } - - @Override - public int getNumberOfValues(int ncol) { - return nVal; - } - - @Override - public double[] sumAllRowsToDouble(int nrColumns) { - throw new RuntimeException(errMessage); - } - - @Override - public double[] sumAllRowsToDoubleWithDefault(double[] defaultTuple) { - throw new RuntimeException(errMessage); - } - - @Override - public double[] sumAllRowsToDoubleWithReference(double[] reference) { - throw new RuntimeException(errMessage); - } - - @Override - public double[] sumAllRowsToDoubleSq(int nrColumns) { - throw new RuntimeException(errMessage); - } - - @Override - public double[] sumAllRowsToDoubleSqWithDefault(double[] defaultTuple) { - throw new RuntimeException(errMessage); - } - - @Override - public double[] sumAllRowsToDoubleSqWithReference(double[] reference) { - throw new RuntimeException(errMessage); 
- } - - @Override - public double[] productAllRowsToDouble(int nrColumns) { - throw new RuntimeException(errMessage); - } - - @Override - public double[] productAllRowsToDoubleWithDefault(double[] defaultTuple) { - throw new RuntimeException(errMessage); - } - - @Override - public double[] productAllRowsToDoubleWithReference(double[] reference) { - throw new RuntimeException(errMessage); - } - - @Override - public void colSum(double[] c, int[] counts, IColIndex colIndexes) { - throw new RuntimeException(errMessage); - } - - @Override - public void colSumSq(double[] c, int[] counts, IColIndex colIndexes) { - throw new RuntimeException(errMessage); - } - - @Override - public void colSumSqWithReference(double[] c, int[] counts, IColIndex colIndexes, double[] reference) { - throw new RuntimeException(errMessage); - } - - @Override - public double sum(int[] counts, int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public double sumSq(int[] counts, int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public double sumSqWithReference(int[] counts, double[] reference) { - throw new RuntimeException(errMessage); - } - - @Override - public String getString(int colIndexes) { - return ""; // get string empty - } - - @Override - public IDictionary sliceOutColumnRange(int idxStart, int idxEnd, int previousNumberOfColumns) { - throw new RuntimeException(errMessage); - } - - @Override - public boolean containsValue(double pattern) { - throw new RuntimeException(errMessage); - } - - @Override - public boolean containsValueWithReference(double pattern, double[] reference) { - throw new RuntimeException(errMessage); - } - - @Override - public long getNumberNonZeros(int[] counts, int nCol) { - return -1; - } - - @Override - public long getNumberNonZerosWithReference(int[] counts, double[] reference, int nRows) { - throw new RuntimeException(errMessage); - } - - @Override - public void addToEntry(double[] v, int fr, int to, int nCol) { - 
throw new RuntimeException(errMessage); - } - - @Override - public void addToEntry(double[] v, int fr, int to, int nCol, int rep) { - throw new RuntimeException(errMessage); - } - - @Override - public void addToEntryVectorized(double[] v, int f1, int f2, int f3, int f4, int f5, int f6, int f7, int f8, int t1, - int t2, int t3, int t4, int t5, int t6, int t7, int t8, int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary subtractTuple(double[] tuple) { - throw new RuntimeException(errMessage); - } - - @Override - public MatrixBlockDictionary getMBDict(int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary scaleTuples(int[] scaling, int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary preaggValuesFromDense(int numVals, IColIndex colIndexes, IColIndex aggregateColumns, double[] b, - int cut) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary replace(double pattern, double replace, int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary replaceWithReference(double pattern, double replace, double[] reference) { - throw new RuntimeException(errMessage); - } - - @Override - public void product(double[] ret, int[] counts, int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public void productWithDefault(double[] ret, int[] counts, double[] def, int defCount) { - throw new RuntimeException(errMessage); - } - - @Override - public void productWithReference(double[] ret, int[] counts, double[] reference, int refCount) { - throw new RuntimeException(errMessage); - } - - @Override - public void colProduct(double[] res, int[] counts, IColIndex colIndexes) { - throw new RuntimeException(errMessage); - } - - @Override - public void colProductWithReference(double[] res, int[] counts, IColIndex colIndexes, double[] reference) { - throw new RuntimeException(errMessage); - } - - 
@Override - public CM_COV_Object centralMoment(ValueFunction fn, int[] counts, int nRows) { - throw new RuntimeException(errMessage); - } - - @Override - public CM_COV_Object centralMoment(CM_COV_Object ret, ValueFunction fn, int[] counts, int nRows) { - throw new RuntimeException(errMessage); - } - - @Override - public CM_COV_Object centralMomentWithDefault(ValueFunction fn, int[] counts, double def, int nRows) { - throw new RuntimeException(errMessage); - } - - @Override - public CM_COV_Object centralMomentWithDefault(CM_COV_Object ret, ValueFunction fn, int[] counts, double def, - int nRows) { - throw new RuntimeException(errMessage); - } - - @Override - public CM_COV_Object centralMomentWithReference(ValueFunction fn, int[] counts, double reference, int nRows) { - throw new RuntimeException(errMessage); - } - - @Override - public CM_COV_Object centralMomentWithReference(CM_COV_Object ret, ValueFunction fn, int[] counts, double reference, - int nRows) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary rexpandCols(int max, boolean ignore, boolean cast, int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary rexpandColsWithReference(int max, boolean ignore, boolean cast, int reference) { - throw new RuntimeException(errMessage); - } - - @Override - public double getSparsity() { - throw new RuntimeException(errMessage); - } - - @Override - public void multiplyScalar(double v, double[] ret, int off, int dictIdx, IColIndex cols) { - throw new RuntimeException(errMessage); - } - - @Override - public void TSMMWithScaling(int[] counts, IColIndex rows, IColIndex cols, MatrixBlock ret) { - throw new RuntimeException(errMessage); - } - - @Override - public void MMDict(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { - throw new RuntimeException(errMessage); - } - - @Override - public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock 
result) { - throw new RuntimeException(errMessage); - } - - @Override - public void MMDictSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { - throw new RuntimeException(errMessage); - } - - @Override - public void TSMMToUpperTriangle(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { - throw new RuntimeException(errMessage); - } - - @Override - public void TSMMToUpperTriangleDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { - throw new RuntimeException(errMessage); - } - - @Override - public void TSMMToUpperTriangleSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, - MatrixBlock result) { - throw new RuntimeException(errMessage); - } - - @Override - public void TSMMToUpperTriangleScaling(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, int[] scale, - MatrixBlock result) { - throw new RuntimeException(errMessage); - } - - @Override - public void TSMMToUpperTriangleDenseScaling(double[] left, IColIndex rowsLeft, IColIndex colsRight, int[] scale, - MatrixBlock result) { - throw new RuntimeException(errMessage); - } - - @Override - public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, int[] scale, - MatrixBlock result) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary cbind(IDictionary that, int nCol) { - throw new RuntimeException(errMessage); - } - - @Override - public boolean equals(IDictionary o) { - return o instanceof PlaceHolderDict; - } - - @Override - public IDictionary reorder(int[] reorder) { - throw new RuntimeException(errMessage); - } - - @Override - public IDictionary clone() { - return new PlaceHolderDict(nVal); - } + private static final long serialVersionUID = 9176356558592L; + + private static final String errMessage = "PlaceHolderDict does not support Operations, and is purely intended for serialization"; + + /** The number of values 
supposed to be contained in this dictionary */ + private final int nVal; + + public PlaceHolderDict(int nVal) { + this.nVal = nVal; + } + + @Override + public double[] getValues() { + throw new RuntimeException(errMessage); + } + + @Override + public double getValue(int i) { + throw new RuntimeException(errMessage); + } + + @Override + public double getValue(int r, int col, int nCol) { + throw new RuntimeException(errMessage); + } + + @Override + public long getInMemorySize() { + return 16 + 4; + } + + @Override + public double aggregate(double init, Builtin fn) { + throw new RuntimeException(errMessage); + } + + @Override + public double aggregateWithReference(double init, Builtin fn, double[] reference, boolean def) { + throw new RuntimeException(errMessage); + } + + @Override + public double[] aggregateRows(Builtin fn, int nCol) { + throw new RuntimeException(errMessage); + } + + @Override + public double[] aggregateRowsWithDefault(Builtin fn, double[] defaultTuple) { + throw new RuntimeException(errMessage); + } + + @Override + public double[] aggregateRowsWithReference(Builtin fn, double[] reference) { + throw new RuntimeException(errMessage); + } + + @Override + public void aggregateCols(double[] c, Builtin fn, IColIndex colIndexes) { + throw new RuntimeException(errMessage); + } + + @Override + public void aggregateColsWithReference(double[] c, Builtin fn, IColIndex colIndexes, double[] reference, + boolean def) { + throw new RuntimeException(errMessage); + } + + @Override + public IDictionary applyScalarOp(ScalarOperator op) { + throw new RuntimeException(errMessage); + } + + @Override + public IDictionary applyScalarOpAndAppend(ScalarOperator op, double v0, int nCol) { + throw new RuntimeException(errMessage); + } + + @Override + public IDictionary applyUnaryOp(UnaryOperator op) { + throw new RuntimeException(errMessage); + } + + @Override + public IDictionary applyUnaryOpAndAppend(UnaryOperator op, double v0, int nCol) { + throw new 
RuntimeException(errMessage); + } + + @Override + public IDictionary applyScalarOpWithReference(ScalarOperator op, double[] reference, double[] newReference) { + throw new RuntimeException(errMessage); + } + + @Override + public IDictionary applyUnaryOpWithReference(UnaryOperator op, double[] reference, double[] newReference) { + throw new RuntimeException(errMessage); + } + + @Override + public IDictionary binOpLeft(BinaryOperator op, double[] v, IColIndex colIndexes) { + throw new RuntimeException(errMessage); + } + + @Override + public IDictionary binOpLeftAndAppend(BinaryOperator op, double[] v, IColIndex colIndexes) { + throw new RuntimeException(errMessage); + } + + @Override + public IDictionary binOpLeftWithReference(BinaryOperator op, double[] v, IColIndex colIndexes, double[] reference, + double[] newReference) { + throw new RuntimeException(errMessage); + } + + @Override + public IDictionary binOpRight(BinaryOperator op, double[] v, IColIndex colIndexes) { + throw new RuntimeException(errMessage); + } + + @Override + public IDictionary binOpRightAndAppend(BinaryOperator op, double[] v, IColIndex colIndexes) { + throw new RuntimeException(errMessage); + } + + @Override + public IDictionary binOpRight(BinaryOperator op, double[] v) { + throw new RuntimeException(errMessage); + } + + @Override + public IDictionary binOpRightWithReference(BinaryOperator op, double[] v, IColIndex colIndexes, double[] reference, + double[] newReference) { + throw new RuntimeException(errMessage); + } + + @Override + public void write(DataOutput out) throws IOException { + byte[] o = new byte[5]; + o[0] = (byte) DictionaryFactory.Type.PLACE_HOLDER.ordinal(); + IOUtilFunctions.intToBa(nVal, o, 1); + out.write(o); + } + + public static PlaceHolderDict read(DataInput in) throws IOException { + int nVals = in.readInt(); + return new PlaceHolderDict(nVals); + } + + @Override + public long getExactSizeOnDisk() { + return 1 + 4; + } + + @Override + public DictType getDictType() { + 
throw new RuntimeException(errMessage); + } + + @Override + public int getNumberOfValues(int ncol) { + return nVal; + } + + @Override + public double[] sumAllRowsToDouble(int nrColumns) { + throw new RuntimeException(errMessage); + } + + @Override + public double[] sumAllRowsToDoubleWithDefault(double[] defaultTuple) { + throw new RuntimeException(errMessage); + } + + @Override + public double[] sumAllRowsToDoubleWithReference(double[] reference) { + throw new RuntimeException(errMessage); + } + + @Override + public double[] sumAllRowsToDoubleSq(int nrColumns) { + throw new RuntimeException(errMessage); + } + + @Override + public double[] sumAllRowsToDoubleSqWithDefault(double[] defaultTuple) { + throw new RuntimeException(errMessage); + } + + @Override + public double[] sumAllRowsToDoubleSqWithReference(double[] reference) { + throw new RuntimeException(errMessage); + } + + @Override + public double[] productAllRowsToDouble(int nrColumns) { + throw new RuntimeException(errMessage); + } + + @Override + public double[] productAllRowsToDoubleWithDefault(double[] defaultTuple) { + throw new RuntimeException(errMessage); + } + + @Override + public double[] productAllRowsToDoubleWithReference(double[] reference) { + throw new RuntimeException(errMessage); + } + + @Override + public void colSum(double[] c, int[] counts, IColIndex colIndexes) { + throw new RuntimeException(errMessage); + } + + @Override + public void colSumSq(double[] c, int[] counts, IColIndex colIndexes) { + throw new RuntimeException(errMessage); + } + + @Override + public void colSumSqWithReference(double[] c, int[] counts, IColIndex colIndexes, double[] reference) { + throw new RuntimeException(errMessage); + } + + @Override + public double sum(int[] counts, int nCol) { + throw new RuntimeException(errMessage); + } + + @Override + public double sumSq(int[] counts, int nCol) { + throw new RuntimeException(errMessage); + } + + @Override + public double sumSqWithReference(int[] counts, double[] 
reference) { + throw new RuntimeException(errMessage); + } + + @Override + public String getString(int colIndexes) { + return ""; // get string empty + } + + @Override + public IDictionary sliceOutColumnRange(int idxStart, int idxEnd, int previousNumberOfColumns) { + throw new RuntimeException(errMessage); + } + + @Override + public boolean containsValue(double pattern) { + throw new RuntimeException(errMessage); + } + + @Override + public boolean containsValueWithReference(double pattern, double[] reference) { + throw new RuntimeException(errMessage); + } + + @Override + public long getNumberNonZeros(int[] counts, int nCol) { + return -1; + } + + @Override + public long getNumberNonZerosWithReference(int[] counts, double[] reference, int nRows) { + throw new RuntimeException(errMessage); + } + + @Override + public void addToEntry(double[] v, int fr, int to, int nCol) { + throw new RuntimeException(errMessage); + } + + @Override + public void addToEntry(double[] v, int fr, int to, int nCol, int rep) { + throw new RuntimeException(errMessage); + } + + @Override + public void addToEntryVectorized(double[] v, int f1, int f2, int f3, int f4, int f5, int f6, int f7, int f8, int t1, + int t2, int t3, int t4, int t5, int t6, int t7, int t8, int nCol) { + throw new RuntimeException(errMessage); + } + + @Override + public IDictionary subtractTuple(double[] tuple) { + throw new RuntimeException(errMessage); + } + + @Override + public MatrixBlockDictionary getMBDict(int nCol) { + throw new RuntimeException(errMessage); + } + + @Override + public IDictionary scaleTuples(int[] scaling, int nCol) { + throw new RuntimeException(errMessage); + } + + @Override + public IDictionary preaggValuesFromDense(int numVals, IColIndex colIndexes, IColIndex aggregateColumns, double[] b, + int cut) { + throw new RuntimeException(errMessage); + } + + @Override + public IDictionary replace(double pattern, double replace, int nCol) { + throw new RuntimeException(errMessage); + } + + @Override + 
public IDictionary replaceWithReference(double pattern, double replace, double[] reference) { + throw new RuntimeException(errMessage); + } + + @Override + public void product(double[] ret, int[] counts, int nCol) { + throw new RuntimeException(errMessage); + } + + @Override + public void productWithDefault(double[] ret, int[] counts, double[] def, int defCount) { + throw new RuntimeException(errMessage); + } + + @Override + public void productWithReference(double[] ret, int[] counts, double[] reference, int refCount) { + throw new RuntimeException(errMessage); + } + + @Override + public void colProduct(double[] res, int[] counts, IColIndex colIndexes) { + throw new RuntimeException(errMessage); + } + + @Override + public void colProductWithReference(double[] res, int[] counts, IColIndex colIndexes, double[] reference) { + throw new RuntimeException(errMessage); + } + + @Override + public CM_COV_Object centralMoment(ValueFunction fn, int[] counts, int nRows) { + throw new RuntimeException(errMessage); + } + + @Override + public CM_COV_Object centralMoment(CM_COV_Object ret, ValueFunction fn, int[] counts, int nRows) { + throw new RuntimeException(errMessage); + } + + @Override + public CM_COV_Object centralMomentWithDefault(ValueFunction fn, int[] counts, double def, int nRows) { + throw new RuntimeException(errMessage); + } + + @Override + public CM_COV_Object centralMomentWithDefault(CM_COV_Object ret, ValueFunction fn, int[] counts, double def, + int nRows) { + throw new RuntimeException(errMessage); + } + + @Override + public CM_COV_Object centralMomentWithReference(ValueFunction fn, int[] counts, double reference, int nRows) { + throw new RuntimeException(errMessage); + } + + @Override + public CM_COV_Object centralMomentWithReference(CM_COV_Object ret, ValueFunction fn, int[] counts, double reference, + int nRows) { + throw new RuntimeException(errMessage); + } + + @Override + public IDictionary rexpandCols(int max, boolean ignore, boolean cast, int nCol) { + 
throw new RuntimeException(errMessage); + } + + @Override + public IDictionary rexpandColsWithReference(int max, boolean ignore, boolean cast, int reference) { + throw new RuntimeException(errMessage); + } + + @Override + public double getSparsity() { + throw new RuntimeException(errMessage); + } + + @Override + public void multiplyScalar(double v, double[] ret, int off, int dictIdx, IColIndex cols) { + throw new RuntimeException(errMessage); + } + + @Override + public void TSMMWithScaling(int[] counts, IColIndex rows, IColIndex cols, MatrixBlock ret) { + throw new RuntimeException(errMessage); + } + + @Override + public void MMDict(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { + throw new RuntimeException(errMessage); + } + + @Override + public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { + throw new RuntimeException(errMessage); + } + + @Override + public void MMDictSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { + throw new RuntimeException(errMessage); + } + + @Override + public void TSMMToUpperTriangle(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { + throw new RuntimeException(errMessage); + } + + @Override + public void TSMMToUpperTriangleDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { + throw new RuntimeException(errMessage); + } + + @Override + public void TSMMToUpperTriangleSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, + MatrixBlock result) { + throw new RuntimeException(errMessage); + } + + @Override + public void TSMMToUpperTriangleScaling(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, int[] scale, + MatrixBlock result) { + throw new RuntimeException(errMessage); + } + + @Override + public void TSMMToUpperTriangleDenseScaling(double[] left, IColIndex rowsLeft, IColIndex colsRight, int[] scale, + MatrixBlock result) { + throw new 
RuntimeException(errMessage); + } + + @Override + public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, int[] scale, + MatrixBlock result) { + throw new RuntimeException(errMessage); + } + + @Override + public IDictionary cbind(IDictionary that, int nCol) { + throw new RuntimeException(errMessage); + } + + @Override + public boolean equals(IDictionary o) { + return o instanceof PlaceHolderDict; + } + + @Override + public final boolean equals(double[] v) { + return false; + } + + @Override + public IDictionary reorder(int[] reorder) { + throw new RuntimeException(errMessage); + } + + @Override + public IDictionary clone() { + return new PlaceHolderDict(nVal); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ArrayIndex.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ArrayIndex.java index 605b5999c09..0c0693d53c8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ArrayIndex.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ArrayIndex.java @@ -25,6 +25,7 @@ import java.util.Arrays; import java.util.stream.IntStream; +import org.apache.sysds.runtime.io.IOUtilFunctions; import org.apache.sysds.utils.MemoryEstimates; public class ArrayIndex extends AColIndex { @@ -54,10 +55,38 @@ public IColIndex shift(int i) { @Override public void write(DataOutput out) throws IOException { - out.writeByte(ColIndexType.ARRAY.ordinal()); - out.writeInt(cols.length); + if(cols.length < 100) + writeBuffered(out); + else + writeGeneric(out); + } + + public void writeBuffered(DataOutput out) throws IOException { + byte[] o = new byte[cols.length * 4 + 4 + 1]; + o[0] = (byte) ColIndexType.ARRAY.ordinal(); + IOUtilFunctions.intToBa(cols.length, o, 1); for(int i = 0; i < cols.length; i++) - out.writeInt(cols[i]); + IOUtilFunctions.intToBa(cols[i], o, i * 4 + 5); + out.write(o); + } + + public void writeGeneric(DataOutput out) throws 
IOException { + byte[] o = new byte[512]; + + o[0] = (byte) ColIndexType.ARRAY.ordinal(); + IOUtilFunctions.intToBa(cols.length, o, 1); + out.write(o, 0, 5); + + int i = 0; + while(i + 512 / 4 < cols.length) { + for(int of = 0; of < o.length; of += 4, i++) + IOUtilFunctions.intToBa(cols[i], o, of); + out.write(o); + } + int of = 0; + for(; i < cols.length; of += 4, i++) + IOUtilFunctions.intToBa(cols[i], o, of); + out.write(o, 0, of); } public static ArrayIndex read(DataInput in) throws IOException { @@ -219,7 +248,7 @@ public boolean contains(int i) { @Override public double avgOfIndex() { double s = 0.0; - for(int i = 0; i < cols.length; i++) + for(int i = 0; i < cols.length; i++) s += cols[i]; return s / cols.length; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/SingleIndex.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/SingleIndex.java index 97325460537..2b14ecc3e73 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/SingleIndex.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/SingleIndex.java @@ -23,6 +23,7 @@ import java.io.IOException; import org.apache.sysds.runtime.compress.DMLCompressionException; +import org.apache.sysds.runtime.io.IOUtilFunctions; public class SingleIndex extends AColIndex { private final int idx; @@ -52,8 +53,10 @@ public IIterate iterator() { } public void write(DataOutput out) throws IOException { - out.writeByte(ColIndexType.SINGLE.ordinal()); - out.writeInt(idx); + byte[] o = new byte[5]; + o[0] = (byte) ColIndexType.SINGLE.ordinal(); + IOUtilFunctions.intToBa(idx, o, 1); + out.write(o); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToBit.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToBit.java index 083e793d0cc..0822f08f998 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToBit.java +++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToBit.java @@ -412,16 +412,17 @@ public AMapToData appendN(IMapToDataGroup[] d) { p += gd.getMapToData().size(); final long[] ret = new long[longSize(p)]; - long[] or = _data; - System.arraycopy(or, 0, ret, 0, or.length); - - p = size(); - for(int i = 1; i < d.length; i++) { - final MapToBit mm = (MapToBit) d[i].getMapToData(); - final int ms = mm.size(); - or = mm._data; - BitSetArray.setVectorizedLongs(p, p + ms, ret, or); - p += ms; + long[] or = null; + + p = 0; + for(int i = 0; i < d.length; i++) { + if(d[i].getMapToData().size() > 0) { + final MapToBit mm = (MapToBit) d[i].getMapToData(); + final int ms = mm.size(); + or = mm._data; + BitSetArray.setVectorizedLongs(p, p + ms, ret, or); + p += ms; + } } BitSet retBS = BitSet.valueOf(ret); @@ -441,7 +442,7 @@ public int getMaxPossible() { public String toString() { StringBuilder sb = new StringBuilder(); sb.append(super.toString()); - sb.append("size: " + _size); + sb.append(" size: " + _size); sb.append(" longLength:["); sb.append(_data.length); sb.append("]"); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java index bdc6e4e2a8a..837468d3ebf 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java @@ -238,14 +238,16 @@ public AMapToData appendN(IMapToDataGroup[] d) { int p = 0; // pointer for(IMapToDataGroup gd : d) p += gd.getMapToData().size(); - final byte[] ret = Arrays.copyOf(_data, p); - - p = size(); - for(int i = 1; i < d.length; i++) { - final MapToByte mm = (MapToByte) d[i].getMapToData(); - final int ms = mm.size(); - System.arraycopy(mm._data, 0, ret, p, ms); - p += ms; + final byte[] ret = new byte[p]; + + p = 0; + for(int i = 0; i < d.length; i++) { + if(d[i].getMapToData().size() > 0) 
{ + final MapToByte mm = (MapToByte) d[i].getMapToData(); + final int ms = mm.size(); + System.arraycopy(mm._data, 0, ret, p, ms); + p += ms; + } } if(getUnique() < 127) @@ -255,14 +257,14 @@ public AMapToData appendN(IMapToDataGroup[] d) { } @Override - public int getMaxPossible(){ + public int getMaxPossible() { return 256; } @Override public boolean equals(AMapToData e) { return e instanceof MapToByte && // - e.getUnique() == getUnique() &&// + e.getUnique() == getUnique() && // Arrays.equals(((MapToByte) e)._data, _data); } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java index ae203208091..bdab7891b82 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java @@ -28,6 +28,7 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.colgroup.IMapToDataGroup; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory.MAP_TYPE; +import org.apache.sysds.runtime.io.IOUtilFunctions; import org.apache.sysds.utils.MemoryEstimates; public class MapToChar extends AMapToData { @@ -110,8 +111,27 @@ public void write(DataOutput out) throws IOException { out.writeByte(MAP_TYPE.CHAR.ordinal()); out.writeInt(getUnique()); out.writeInt(_data.length); - for(int i = 0; i < _data.length; i++) - out.writeChar(_data[i]); + final int BS = 100; + if(_data.length > BS) { + final byte[] buff = new byte[BS*2]; + for(int i = 0; i < _data.length; ) { + if(i + BS <= _data.length) { + for(int o = 0; o < BS; o++) { + IOUtilFunctions.shortToBa(_data[i++], buff, o * 2); + } + out.write(buff); + } + else {// remaining. 
+ for(; i < _data.length; i++) + out.writeChar(_data[i]); + } + } + } + else { + for(int i = 0; i < _data.length; i++) + out.writeChar(_data[i]); + } + } protected static MapToChar readFields(DataInput in) throws IOException { @@ -258,28 +278,30 @@ public AMapToData appendN(IMapToDataGroup[] d) { int p = 0; // pointer for(IMapToDataGroup gd : d) p += gd.getMapToData().size(); - final char[] ret = Arrays.copyOf(_data, p); - - p = size(); - for(int i = 1; i < d.length; i++) { - final MapToChar mm = (MapToChar) d[i].getMapToData(); - final int ms = mm.size(); - System.arraycopy(mm._data, 0, ret, p, ms); - p += ms; + final char[] ret = new char[p]; + + p = 0; + for(int i = 0; i < d.length; i++) { + if(d[i].getMapToData().size() > 0) { + final MapToChar mm = (MapToChar) d[i].getMapToData(); + final int ms = mm.size(); + System.arraycopy(mm._data, 0, ret, p, ms); + p += ms; + } } return new MapToChar(getUnique(), ret); } @Override - public int getMaxPossible(){ + public int getMaxPossible() { return Character.MAX_VALUE; } @Override public boolean equals(AMapToData e) { return e instanceof MapToChar && // - e.getUnique() == getUnique() &&// + e.getUnique() == getUnique() && // Arrays.equals(((MapToChar) e)._data, _data); } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java index 02a8a6e4b40..cb7d6199cf2 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java @@ -244,7 +244,26 @@ public AMapToData append(AMapToData t) { @Override public AMapToData appendN(IMapToDataGroup[] d) { - throw new NotImplementedException(); + + int p = 0; // pointer + for(IMapToDataGroup gd : d) + p += gd.getMapToData().size(); + final char[] ret = new char[p]; + final byte[] retb = new byte[p]; + + p = 0; + for(int i = 0; i < 
d.length; i++) { + if(d[i].getMapToData().size() > 0) { + final MapToCharPByte mm = (MapToCharPByte) d[i].getMapToData(); + final int ms = mm.size(); + System.arraycopy(mm._data_c, 0, ret, p, ms); + System.arraycopy(mm._data_b, 0, retb, p, ms); + p += ms; + } + } + + return new MapToCharPByte(getUnique(), ret, retb); + } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java index a2ecf221f48..b3c509b78cf 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java @@ -259,28 +259,30 @@ public AMapToData appendN(IMapToDataGroup[] d) { int p = 0; // pointer for(IMapToDataGroup gd : d) p += gd.getMapToData().size(); - final int[] ret = Arrays.copyOf(_data, p); - - p = size(); - for(int i = 1; i < d.length; i++) { - final MapToInt mm = (MapToInt) d[i].getMapToData(); - final int ms = mm.size(); - System.arraycopy(mm._data, 0, ret, p, ms); - p += ms; + final int[] ret = new int[p]; + + p = 0; + for(int i = 0; i < d.length; i++) { + if(d[i].getMapToData().size() > 0) { + final MapToInt mm = (MapToInt) d[i].getMapToData(); + final int ms = mm.size(); + System.arraycopy(mm._data, 0, ret, p, ms); + p += ms; + } } return new MapToInt(getUnique(), ret); } @Override - public int getMaxPossible(){ + public int getMaxPossible() { return Integer.MAX_VALUE; } @Override public boolean equals(AMapToData e) { return e instanceof MapToInt && // - e.getUnique() == getUnique() &&// + e.getUnique() == getUnique() && // Arrays.equals(((MapToInt) e)._data, _data); } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToZero.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToZero.java index 090d6400117..e3797dce3fd 100644 --- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToZero.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToZero.java @@ -163,13 +163,23 @@ public AMapToData append(AMapToData t) { @Override public AMapToData appendN(IMapToDataGroup[] d) { int p = 0; // pointer - for(IMapToDataGroup gd : d) - p += gd.getMapToData().size(); + boolean allZ = true; + for(IMapToDataGroup gd : d) { + AMapToData m = gd.getMapToData(); + + p += m.size(); + if(!(m instanceof MapToZero)) + allZ = false; + } + + if(!allZ) + throw new RuntimeException("Not supported combining different types of map"); + return new MapToZero(p); } @Override - public int getMaxPossible(){ + public int getMaxPossible() { return 1; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java index 4c10fa7facd..adefe49a528 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java @@ -50,20 +50,20 @@ public abstract class AOffset implements Serializable { protected static final Log LOG = LogFactory.getLog(AOffset.class.getName()); + /** The skip list stride size, aka how many indexes skipped for each index. */ + protected static final int skipStride = 1000; + + /** SoftReference of the skip list to be dematerialized on memory pressure */ + private volatile SoftReference skipList = null; + /** Thread local cache for a single recently used Iterator, this is used for cache blocking */ - private ThreadLocal cacheRow = new ThreadLocal<>() { + private volatile ThreadLocal cacheRow = new ThreadLocal<>() { @Override protected OffsetCache initialValue() { return null; } }; - /** The skiplist stride size, aka how many indexes skipped for each index. 
*/ - protected static final int skipStride = 1000; - - /** SoftReference of the skip list to be dematerialized on memory pressure */ - private SoftReference skipList = null; - /** * Get an iterator of the offsets while also maintaining the data index pointer. * @@ -104,6 +104,17 @@ else if(getLength() < skipStride) return getIteratorLargeOffset(row); } + private AIterator getIteratorSkipCache(int row){ + if(row <= getOffsetToFirst()) + return getIterator(); + else if(row > getOffsetToLast()) + return null; + else if(getLength() < skipStride) + return getIteratorSmallOffset(row); + else + return getIteratorLargeOffset(row); + } + private AIterator getIteratorSmallOffset(int row) { AIterator it = getIterator(); it.skipTo(row); @@ -111,7 +122,7 @@ private AIterator getIteratorSmallOffset(int row) { return it; } - private AIterator getIteratorLargeOffset(int row) { + private final AIterator getIteratorLargeOffset(int row) { if(skipList == null || skipList.get() == null) constructSkipList(); final OffsetCacheV2[] skip = skipList.get(); @@ -125,22 +136,23 @@ private AIterator getIteratorLargeOffset(int row) { return it; } - private synchronized void constructSkipList() { + public synchronized void constructSkipList() { if(skipList != null && skipList.get() != null) return; // not actual accurate but applicable. 
- final int skipSize = getLength() / skipStride + 1; + final int last = getOffsetToLast(); + final int skipSize = last / skipStride + 1; if(skipSize == 0) return; final OffsetCacheV2[] skipListTmp = new OffsetCacheV2[skipSize]; final AIterator it = getIterator(); - final int last = getOffsetToLast(); int skipListIdx = 0; - while(it.value() < last) { - for(int i = 0; i < skipStride && it.value() < last; i++) + while(it.value() < last && skipListIdx < skipListTmp.length) { + int next = skipListIdx * skipStride + skipStride; + while(it.value() < next && it.value() < last) it.next(); skipListTmp[skipListIdx++] = new OffsetCacheV2(it.value(), it.getDataIndex(), it.getOffsetsIndex()); } @@ -241,8 +253,8 @@ else if(rl == ru - 1) { } } - protected final void preAggregateDenseMapRow(double[] mV, int off, double[] preAV, int cu, int nVal, - AMapToData data, AIterator it) { + protected final void preAggregateDenseMapRow(double[] mV, int off, double[] preAV, int cu, int nVal, AMapToData data, + AIterator it) { final int last = getOffsetToLast(); if(cu <= last) preAggregateDenseMapRowBellowEnd(mV, off, preAV, cu, nVal, data, it); @@ -250,8 +262,8 @@ protected final void preAggregateDenseMapRow(double[] mV, int off, double[] preA preAggregateDenseMapRowEnd(mV, off, preAV, last, nVal, data, it); } - protected final void preAggregateDenseMapRowBellowEnd(final double[] mV, final int off, final double[] preAV, - int cu, final int nVal, final AMapToData data, final AIterator it) { + protected final void preAggregateDenseMapRowBellowEnd(final double[] mV, final int off, final double[] preAV, int cu, + final int nVal, final AMapToData data, final AIterator it) { it.offset += off; cu += off; while(it.offset < cu) { @@ -426,6 +438,11 @@ private void preAggregateSparseMapRows(SparseBlock sb, double[] preAV, int rl, i } } + @Override + public boolean equals(Object o) { + return o instanceof AOffset && this.equals((AOffset) o); + } + public boolean equals(AOffset b) { if(getOffsetToLast() 
== b.getOffsetToLast()) { int last = getOffsetToLast(); @@ -461,7 +478,7 @@ public boolean equals(AOffset b) { public abstract int getLength(); public OffsetSliceInfo slice(int l, int u) { - AIterator it = getIterator(l); + AIterator it = getIteratorSkipCache(l); if(it == null || it.value() >= u) return new OffsetSliceInfo(-1, -1, new OffsetEmpty()); else if(l <= getOffsetToFirst() && u > getOffsetToLast()) { @@ -470,13 +487,13 @@ else if(l <= getOffsetToFirst() && u > getOffsetToLast()) { else return new OffsetSliceInfo(0, getSize(), moveIndex(l)); } - int low = it.getDataIndex(); - int lowOff = it.getOffsetsIndex(); - int lowValue = it.value(); + final int low = it.getDataIndex(); + final int lowOff = it.getOffsetsIndex(); + final int lowValue = it.value(); - int high = it.getDataIndex(); - int highOff = it.getOffsetsIndex(); - int highValue = it.value(); + int high = low; + int highOff = lowOff; + int highValue = lowValue; if(u >= getOffsetToLast()) { // If including the last do not iterate. high = getSize() - 1; highOff = getLength(); @@ -484,22 +501,20 @@ else if(l <= getOffsetToFirst() && u > getOffsetToLast()) { } else { // Have to iterate through until we find last. while(it.value() < u) { + // TODO add previous command that would allow us to simplify this loop. 
high = it.getDataIndex(); highOff = it.getOffsetsIndex(); highValue = it.value(); it.next(); } } - - lowValue -= l; - highValue -= l; - + if(low == high) - return new OffsetSliceInfo(low, high + 1, new OffsetSingle(lowValue)); + return new OffsetSliceInfo(low, high + 1, new OffsetSingle(lowValue - l)); else if(low + 1 == high) - return new OffsetSliceInfo(low, high + 1, new OffsetTwo(lowValue, highValue)); + return new OffsetSliceInfo(low, high + 1, new OffsetTwo(lowValue - l, highValue - l)); else - return ((ISliceOffset) this).slice(lowOff, highOff, lowValue, highValue, low, high); + return ((ISliceOffset) this).slice(lowOff, highOff, lowValue - l, highValue - l, low, high); } /** @@ -546,18 +561,25 @@ public AOffset appendN(AOffsetsGroup[] g, int s) { int ss = 0; for(AOffsetsGroup gs : g) { final AOffset tof = gs.getOffsets(); - final AOffsetIterator tofit = tof.getOffsetIterator(); - final int last = tof.getOffsetToLast() + ss; - int v = tofit.value() + ss; - while(v < last) { + if(!(tof instanceof OffsetEmpty)) { + final AOffsetIterator tofit = tof.getOffsetIterator(); + final int last = tof.getOffsetToLast() + ss; + int v = tofit.value() + ss; + while(v < last) { + r.appendValue(v); + v = tofit.next() + ss; + } r.appendValue(v); - v = tofit.next() + ss; } - r.appendValue(v); ss += s; } - return OffsetFactory.createOffset(r); + try { + return OffsetFactory.createOffset(r); + } + catch(Exception e) { + throw new DMLCompressionException("failed to combine" + Arrays.toString(g) + " with S sizes: " + s); + } } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByte.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByte.java index 22d5be5987b..9e009f47010 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByte.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByte.java @@ -25,6 +25,7 @@ import 
org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.colgroup.AOffsetsGroup; +import org.apache.sysds.runtime.io.IOUtilFunctions; import org.apache.sysds.utils.MemoryEstimates; public class OffsetByte extends AOffsetByte { @@ -71,11 +72,13 @@ public AOffsetIterator getOffsetIterator() { @Override public void write(DataOutput out) throws IOException { - out.writeByte(OffsetFactory.OFF_TYPE_SPECIALIZATIONS.BYTE.ordinal()); - out.writeInt(offsetToFirst); - out.writeInt(offsets.length); - out.writeInt(offsetToLast); - out.writeInt(size); + final byte[] its = new byte[4 *4 + 1]; + its[0] = (byte) OffsetFactory.OFF_TYPE_SPECIALIZATIONS.BYTE.ordinal(); + IOUtilFunctions.intToBa(offsetToFirst, its, 1); + IOUtilFunctions.intToBa(offsets.length, its, 5); + IOUtilFunctions.intToBa(offsetToLast, its, 9); + IOUtilFunctions.intToBa(size, its, 13); + out.write(its); out.write(offsets); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByteNZ.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByteNZ.java index 5999675b86a..1328ae328e2 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByteNZ.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByteNZ.java @@ -24,6 +24,7 @@ import java.util.Arrays; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.io.IOUtilFunctions; import org.apache.sysds.utils.MemoryEstimates; public class OffsetByteNZ extends AOffsetByte { @@ -55,10 +56,12 @@ public AOffsetIterator getOffsetIterator() { @Override public void write(DataOutput out) throws IOException { - out.writeByte(OffsetFactory.OFF_TYPE_SPECIALIZATIONS.BYTENZ.ordinal()); - out.writeInt(offsetToFirst); - out.writeInt(offsets.length); - out.writeInt(offsetToLast); + final byte[] its = new byte[4 *3 + 1]; + its[0] = (byte) OffsetFactory.OFF_TYPE_SPECIALIZATIONS.BYTENZ.ordinal(); + 
IOUtilFunctions.intToBa(offsetToFirst, its, 1); + IOUtilFunctions.intToBa(offsets.length, its, 5); + IOUtilFunctions.intToBa(offsetToLast, its, 9); + out.write(its); out.write(offsets); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByteUNZ.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByteUNZ.java index d476fd4dbc5..652ce8c1cdd 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByteUNZ.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByteUNZ.java @@ -24,6 +24,7 @@ import java.util.Arrays; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.io.IOUtilFunctions; import org.apache.sysds.utils.MemoryEstimates; public class OffsetByteUNZ extends AOffsetByte { @@ -55,10 +56,12 @@ public AOffsetIterator getOffsetIterator() { @Override public void write(DataOutput out) throws IOException { - out.writeByte(OffsetFactory.OFF_TYPE_SPECIALIZATIONS.BYTEUNZ.ordinal()); - out.writeInt(offsetToFirst); - out.writeInt(offsets.length); - out.writeInt(offsetToLast); + final byte[] its = new byte[4 * 3 + 1]; + its[0] = (byte) OffsetFactory.OFF_TYPE_SPECIALIZATIONS.BYTEUNZ.ordinal(); + IOUtilFunctions.intToBa(offsetToFirst, its, 1); + IOUtilFunctions.intToBa(offsets.length, its, 5); + IOUtilFunctions.intToBa(offsetToLast, its, 9); + out.write(its); out.write(offsets); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetFactory.java index 2d2cfb874c8..319b7ce89f9 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetFactory.java @@ -273,7 +273,6 @@ private static AOffset createByte(int[] indexes, int apos, int alen) { } boolean noOverHalf = getNoOverHalf(offsets); return 
OffsetByte.create(offsets, offsetToFirst, offsetToLast, alen - apos, noZero, noOverHalf); - } private static int calcSize(int[] indexes, int apos, int alen, int offMax) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetTwo.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetTwo.java index bd1545f8c7f..a76fc2112d0 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetTwo.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetTwo.java @@ -97,7 +97,9 @@ public static OffsetTwo readFields(DataInput in) throws IOException { @Override public OffsetSliceInfo slice(int l, int u) { if(l <= first) { - if(u > last) + if(u < first) + return new OffsetSliceInfo(-1, -1, new OffsetEmpty()); + else if(u > last) return new OffsetSliceInfo(0, 2, moveIndex(l)); else return new OffsetSliceInfo(0, 1, new OffsetSingle(first - l)); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/CompressionScheme.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/CompressionScheme.java index 42f3671c981..27f2a4e0235 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/CompressionScheme.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/CompressionScheme.java @@ -30,6 +30,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; @@ -184,15 +185,16 @@ public CompressedMatrixBlock updateAndEncode(MatrixBlock mb, int k) { return updateAndEncode(mb); validateInput(mb); + final int nRow = mb.getNumRows(); + final int nCol = mb.getNumColumns(); boolean 
transposed = false; - if(mb.getSparsity() < 0.1) { + if(CompressedMatrixBlockFactory.transposeHeuristics(encodings.length, mb)) { transposed = true; mb = LibMatrixReorg.transpose(mb, k, true); } final ExecutorService pool = CommonThreadPool.get(k); try { - final int nCol = mb.getNumColumns(); AColGroup[] ret = new AColGroup[encodings.length]; List tasks = new ArrayList<>(); int taskSize = Math.max(1, encodings.length / (4 * k)); @@ -203,7 +205,9 @@ public CompressedMatrixBlock updateAndEncode(MatrixBlock mb, int k) { t.get(); List retA = new ArrayList<>(Arrays.asList(ret)); - return new CompressedMatrixBlock(mb.getNumRows(), nCol, mb.getNonZeros(), false, retA); + + return new CompressedMatrixBlock(nRow, nCol, mb.getNonZeros(), false, retA); + } catch(Exception e) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/SDCSchemeSC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/SDCSchemeSC.java index fd893f40abf..e273b814d1f 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/SDCSchemeSC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/scheme/SDCSchemeSC.java @@ -237,10 +237,10 @@ protected Pair tryUpdateAndEncode(MatrixBlock data, IColI final int nRow = data.getNumRows(); Pair e = encodeAndUpdate(data, cols.get(0)); - if(lastDict == null || lastDict.getNumberOfValues(columns.size()) != map.size()) - lastDict = DictionaryFactory.create(map); + allocateDictionary(); final AOffset off = OffsetFactory.createOffset(e.getKey()); + off.constructSkipList(); AColGroup g = ColGroupSDC.create(columns, nRow, lastDict, // new double[] {def}, off, e.getValue(), null); return new Pair<>(this, g); @@ -280,7 +280,29 @@ private Pair encodeAndUpdateSparse(MatrixBlock data, i } private Pair encodeAndUpdateDense(MatrixBlock data, int col) { - throw new NotImplementedException(); + final int nRow = data.getNumRows(); + final double[] vals = data.getDenseBlockValues(); + final int nCol = 
data.getNumColumns(); + final int max = nRow * nCol; // guaranteed lower than intmax. + + IntArrayList off = getCachedArray(0); + IntArrayList val = getCachedArray(1); + + // full iteration + for(int i = 0, o = col; o < max; i++, o += nCol) + if(!Util.eq(vals[o], def)){ + off.appendValue(i); + val.appendValue(map.increment(vals[o])); + } + + // Only cells with non default values. + AMapToData d = MapToFactory.create(off.size(), map.size()); + for(int i = 0; i < off.size(); i++) { + int o = off.get(i) * nCol + col; + d.set(i, map.getId(vals[o])); + } + return new Pair<>(off, d); + } private Pair encodeAndUpdateGeneric(MatrixBlock data, int col) { @@ -500,8 +522,8 @@ protected Pair encodeAndUpdateSparseT(MatrixBlock data, I d.set(i, dt.get(i)); allocateDictionary(); - AColGroup g = ColGroupSDC.create(columns, nRow, lastDict, new double[] {def}, - OffsetFactory.createOffset(off), d, null); + AColGroup g = ColGroupSDC.create(columns, nRow, lastDict, new double[] {def}, OffsetFactory.createOffset(off), + d, null); return new Pair<>(this, g); } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/EstimationFactors.java b/src/main/java/org/apache/sysds/runtime/compress/estim/EstimationFactors.java index b0e188c7ce2..130d0f77f82 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/EstimationFactors.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/EstimationFactors.java @@ -36,7 +36,7 @@ public class EstimationFactors { protected final int largestOff; /** The frequencies of the Non zero tuples in the columns */ protected final int[] frequencies; - /** The Number of values in the collection not Zero , Also refered to as singletons */ + /** The Number of values in the collection not Zero, also referred to as singletons */ protected final int numSingle; /** The Number of rows in the column group */ protected final int numRows; diff --git a/src/main/java/org/apache/sysds/runtime/compress/io/ReaderCompressed.java 
b/src/main/java/org/apache/sysds/runtime/compress/io/ReaderCompressed.java index b25846ad7fc..da725ab0290 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/io/ReaderCompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/io/ReaderCompressed.java @@ -21,9 +21,12 @@ import java.io.IOException; import java.io.InputStream; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import org.apache.commons.lang3.NotImplementedException; import org.apache.hadoop.fs.FileSystem; @@ -43,27 +46,35 @@ import org.apache.sysds.runtime.io.MatrixReader; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.util.CommonThreadPool; public final class ReaderCompressed extends MatrixReader { + private final int k; + + public ReaderCompressed() { + this.k = OptimizerUtils.getParallelBinaryReadParallelism(); + } + + public ReaderCompressed(int k) { + this.k = k; + } + public static ReaderCompressed create() { - return new ReaderCompressed(); + int numThreads = OptimizerUtils.getParallelBinaryReadParallelism(); + return new ReaderCompressed(numThreads); } - public static MatrixBlock readCompressedMatrixFromHDFS(String fname) throws IOException { - return create().readMatrixFromHDFS(fname, 10, 10, 10, 100); + public static MatrixBlock readCompressedMatrixFromHDFS(String fname, long rlen, long clen, int blen) throws IOException { + return create().readMatrixFromHDFS(fname, rlen, clen, blen, 100); } @Override public MatrixBlock readMatrixFromHDFS(String fname, long rlen, long clen, int blen, long estnnz) throws IOException, DMLRuntimeException { - JobConf job = new JobConf(ConfigurationManager.getCachedJobConf()); Path path = new Path(fname); FileSystem fs = IOUtilFunctions.getFileSystem(path, job); - - checkValidInputFile(fs, path); - return 
readCompressedMatrix(fname, job, fs, (int) rlen, (int) clen, blen); } @@ -73,7 +84,56 @@ public MatrixBlock readMatrixFromInputStream(InputStream is, long rlen, long cle throw new NotImplementedException("Not implemented reading compressedMatrix from input stream"); } - private static MatrixBlock readCompressedMatrix(String fname, JobConf job, FileSystem fs, int rlen, int clen, + private MatrixBlock readCompressedMatrix(String fname, JobConf job, FileSystem fs, int rlen, int clen, int blen) + throws IOException { + if(k > 1) + return readCompressedMatrixParallel(fname, job, fs, rlen, clen, blen); + else + return readCompressedMatrixSingleThread(fname, job, fs, rlen, clen, blen); + } + + private MatrixBlock readCompressedMatrixParallel(String fname, final JobConf job, FileSystem fs, int rlen, int clen, + int blen) throws IOException { + + final Map data = new HashMap<>(); + Map> dicts = null; + final ExecutorService pool = CommonThreadPool.get(k); + try { + List>> rt = new ArrayList<>(); + List>>> dt = new ArrayList<>(); + for(Path subPath : IOUtilFunctions.getSequenceFilePaths(fs, new Path(fname))) { + final Path sp = subPath; + rt.add(pool.submit(() -> readColumnGroups(sp, job))); + } + + final Path dictPath = new Path(fname + ".dict"); + final boolean dictExists = fs.exists(dictPath); + if(dictExists) { + dicts = new HashMap<>(); + for(Path subPath : IOUtilFunctions.getSequenceFilePaths(fs, dictPath)) { + final Path sp = subPath; + dt.add(pool.submit(() -> readDictionaries(sp, job))); + } + } + + for(Future> e : rt) + data.putAll(e.get()); + + if(dictExists && dicts != null) + for(Future>> e : dt) + dicts.putAll(e.get()); + + return CLALibStack.combine(data, dicts, rlen, clen, blen, k); + } + catch(Exception e) { + throw new IOException("failed parallel reading ", e); + } + finally { + pool.shutdown(); + } + } + + private MatrixBlock readCompressedMatrixSingleThread(String fname, JobConf job, FileSystem fs, int rlen, int clen, int blen) throws IOException { final 
Map data = new HashMap<>(); @@ -94,7 +154,7 @@ private static MatrixBlock readCompressedMatrix(String fname, JobConf job, FileS if(data.containsValue(null)) throw new DMLCompressionException("Invalid read data contains null:"); - return CLALibStack.combine(data, dicts, OptimizerUtils.getParallelTextWriteParallelism()); + return CLALibStack.combine(data, dicts, k); } private static Map readColumnGroups(Path path, JobConf job) throws IOException { diff --git a/src/main/java/org/apache/sysds/runtime/compress/io/WriterCompressed.java b/src/main/java/org/apache/sysds/runtime/compress/io/WriterCompressed.java index 744a5e945a5..9c934592089 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/io/WriterCompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/io/WriterCompressed.java @@ -46,6 +46,7 @@ import org.apache.sysds.runtime.compress.lib.CLALibSeparator; import org.apache.sysds.runtime.compress.lib.CLALibSeparator.SeparatedGroups; import org.apache.sysds.runtime.compress.lib.CLALibSlice; +import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysds.runtime.instructions.spark.CompressionSPInstruction.CompressionFunction; import org.apache.sysds.runtime.instructions.spark.utils.RDDConverterUtils; import org.apache.sysds.runtime.io.FileFormatProperties; @@ -62,7 +63,7 @@ public final class WriterCompressed extends MatrixWriter { protected static final Log LOG = LogFactory.getLog(WriterCompressed.class.getName()); protected static int jobUse = 0; - protected static JobConf job = new JobConf(ConfigurationManager.getCachedJobConf()); + protected static JobConf job = new JobConf(ConfigurationManager.getCachedJobConf()); private String fname; @@ -133,13 +134,12 @@ public void writeEmptyMatrixToHDFS(String fname, long rlen, long clen, int blen) } private void write(MatrixBlock src, final String fname, final int blen) throws IOException { - jobUse ++; - if(jobUse > 30){ + jobUse++; + if(jobUse > 30) { job = 
new JobConf(ConfigurationManager.getCachedJobConf()); jobUse = 0; } - - + if(this.fname != fname) { this.fname = fname; this.writers = null; @@ -147,7 +147,9 @@ private void write(MatrixBlock src, final String fname, final int blen) throws I fs = IOUtilFunctions.getFileSystem(new Path(fname), job); - final int k = OptimizerUtils.getParallelBinaryWriteParallelism(); + int k = OptimizerUtils.getParallelBinaryWriteParallelism(); + + k = Math.min(k, (int)(src.getInMemorySize() / InfrastructureAnalyzer.getBlockSize(fs))); final int rlen = src.getNumRows(); final int clen = src.getNumColumns(); // Try to compress! @@ -200,7 +202,6 @@ private void writeMultiBlockUncompressed(MatrixBlock b, final int rlen, final in private void writeMultiBlockCompressed(MatrixBlock b, final int rlen, final int clen, final int blen, int k) throws IOException { - if(k > 1) writeMultiBlockCompressedParallel(b, rlen, clen, blen, k); else @@ -211,6 +212,7 @@ private void writeMultiBlockCompressed(MatrixBlock b, final int rlen, final int private void writeMultiBlockCompressedSingleThread(MatrixBlock mb, final int rlen, final int clen, final int blen) throws IOException { try { + final CompressedMatrixBlock cmb = (CompressedMatrixBlock) mb; setupWrite(); final Path path = new Path(fname); @@ -219,7 +221,7 @@ private void writeMultiBlockCompressedSingleThread(MatrixBlock mb, final int rle final int sC = bc * blen; final int mC = Math.min(sC + blen, clen) - 1; // slice out the current columns - final CompressedMatrixBlock mc = CLALibSlice.sliceColumns((CompressedMatrixBlock) mb, sC, mC); + final CompressedMatrixBlock mc = CLALibSlice.sliceColumns(cmb, sC, mC); final SeparatedGroups s = CLALibSeparator.split(mc.getColGroups()); final CompressedMatrixBlock rmc = new CompressedMatrixBlock(mc.getNumRows(), mc.getNumColumns(), mc.getNonZeros(), false, s.indexStructures); @@ -260,22 +262,18 @@ private void writeMultiBlockCompressedParallel(MatrixBlock b, final int rlen, fi writerLocks[i] = new 
ReentrantLock(); } + final int colBlocks = (int) Math.ceil((double) clen / blen ); + final int nBlocks = (int) Math.ceil((double) rlen / blen); + final int blocksPerThread = Math.max(1, nBlocks * colBlocks / k ); + int i = 0; for(int bc = 0; bc * blen < clen; bc++) {// column blocks final int sC = bc * blen; final int mC = Math.min(sC + blen, clen) - 1; - // slice out the current columns final CompressedMatrixBlock mc = CLALibSlice.sliceColumns((CompressedMatrixBlock) b, sC, mC); - // slice out row blocks in this. - // final List blocks = CLALibSlice.sliceBlocks(mc, blen, k); // Slice compressed blocks - final SeparatedGroups s = CLALibSeparator.split(mc.getColGroups()); final CompressedMatrixBlock rmc = new CompressedMatrixBlock(mc.getNumRows(), mc.getNumColumns(), mc.getNonZeros(), false, s.indexStructures); - // slice out row blocks in this. - // List blocks = CLALibSlice.sliceBlocks(rmc, blen, 1); // Slice compressed blocks - final int nBlocks = (int) Math.ceil((double) rlen / blen); - final int blocksPerThread = Math.max(1, nBlocks / k); for(int block = 0; block < nBlocks; block += blocksPerThread) { WriteTask we = new WriteTask(i++ % k, rmc, bc, block, Math.min(nBlocks, block + blocksPerThread), blen); @@ -320,16 +318,12 @@ private Path getPath(int id) { return new Path(fname, IOUtilFunctions.getPartFileName(id)); } - // private Writer getWriter(String fname) throws IOException { - // final Path path = new Path(fname); - // return generateWriter(job, path); - // } - private static Writer generateWriter(JobConf job, Path path, FileSystem fs) throws IOException { - - return SequenceFile.createWriter(job, Writer.file(path), Writer.bufferSize(4096), - Writer.keyClass(MatrixIndexes.class), Writer.valueClass(CompressedWriteBlock.class), - Writer.compression(SequenceFile.CompressionType.NONE), // No Compression type on disk + + return SequenceFile.createWriter(job, Writer.file(path), Writer.bufferSize(4096), // + Writer.keyClass(MatrixIndexes.class), // + 
Writer.valueClass(CompressedWriteBlock.class), // + Writer.compression(IOUtilFunctions.getCompressionEncodingType(), IOUtilFunctions.getCompressionCodec()), Writer.replication((short) 1)); } @@ -344,9 +338,12 @@ private static void cleanup(JobConf job, Path path, FileSystem fs) throws IOExce private static void write(Writer w, CompressedMatrixBlock rmc, int bc, int bl, int bu, int blen) throws IOException { final int nrow = rmc.getNumRows(); + final int nGroups = rmc.getColGroups().size(); for(int b = bl; b < bu; b++) { MatrixIndexes index = new MatrixIndexes(b, bc); MatrixBlock cb = CLALibSlice.sliceRowsCompressed(rmc, (b - 1) * blen, Math.min(b * blen, nrow) - 1); + if(cb instanceof CompressedMatrixBlock && ((CompressedMatrixBlock)cb).getColGroups().size() != nGroups) + throw new RuntimeException("invalid writing of different number of column groups"); CompressedWriteBlock blk = new CompressedWriteBlock(cb); w.append(index, blk); } @@ -404,7 +401,9 @@ public Object call() throws Exception { Writer.bufferSize(4096), // Writer.keyClass(DictWritable.K.class), // Writer.valueClass(DictWritable.class), // - Writer.compression(SequenceFile.CompressionType.NONE), // + Writer.compression(// + IOUtilFunctions.getCompressionEncodingType(), // + IOUtilFunctions.getCompressionCodec()), // Writer.replication((short) 1))) { w.append(new DictWritable.K(id), new DictWritable(dicts)); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java index 94487ea4abc..13e5e3c9381 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java @@ -48,6 +48,7 @@ import org.apache.sysds.runtime.functionobjects.Multiply; import org.apache.sysds.runtime.functionobjects.Plus; import org.apache.sysds.runtime.functionobjects.PlusMultiply; +import 
org.apache.sysds.runtime.functionobjects.ValueComparisonFunction; import org.apache.sysds.runtime.functionobjects.ValueFunction; import org.apache.sysds.runtime.matrix.data.LibMatrixBincell; import org.apache.sysds.runtime.matrix.data.LibMatrixBincell.BinaryAccessType; @@ -355,8 +356,16 @@ private static MatrixBlock binaryMVCol(CompressedMatrixBlock m1, MatrixBlock m2, else nnz = binaryMVColMultiThread(m1, m2, op, left, ret); + if(op.fn instanceof ValueComparisonFunction) { + if(nnz == (long) nRows * nCols) + return CompressedMatrixBlockFactory.createConstant(nRows, nCols, 1.0); + + else if(nnz == 0) + return CompressedMatrixBlockFactory.createConstant(nRows, nCols, 0.0); + } ret.setNonZeros(nnz); ret.examSparsity(); + return ret; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java index 109f7514576..95a460a2e0c 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java @@ -568,7 +568,7 @@ protected UAOverlappingTask(List filteredGroups, MatrixBlock ret, int _op = op; _rl = rl; _ru = ru; - _blklen = Math.max(65536 * 2 / ret.getNumColumns() / filteredGroups.size(), 64); + _blklen = Math.max(65536 / ret.getNumColumns() / filteredGroups.size(), 64); _ret = ret; _nCol = nCol; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMatrixMult.java index 185b684f652..5b392e8b7e0 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMatrixMult.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMatrixMult.java @@ -63,11 +63,11 @@ public static MatrixBlock matrixMultiply(MatrixBlock m1, MatrixBlock m2, MatrixB } if(!(m1 instanceof CompressedMatrixBlock) && transposeLeft) { - m1 = LibMatrixReorg.transpose(m1, k); + m1 = 
LibMatrixReorg.transpose(m1, k, true); transposeLeft = false; } else if(!(m2 instanceof CompressedMatrixBlock) && transposeRight) { - m2 = LibMatrixReorg.transpose(m2, k); + m2 = LibMatrixReorg.transpose(m2, k, true); transposeRight = false; } } @@ -87,7 +87,7 @@ else if(!(m2 instanceof CompressedMatrixBlock) && transposeRight) { LOG.warn("Transposing decompression"); ret = ((CompressedMatrixBlock) ret).decompress(k); } - ret = LibMatrixReorg.transpose(ret, k); + ret = LibMatrixReorg.transpose(ret, k, true); } return ret; diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSeparator.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSeparator.java index 42415bcce45..d78b053ad72 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSeparator.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSeparator.java @@ -44,8 +44,8 @@ public interface CLALibSeparator { * @return A split of the groups and their dictionaries. */ public static SeparatedGroups split(List gs) { - List dicts = new ArrayList<>(); - List indexStructures = new ArrayList<>(); + List dicts = new ArrayList<>(gs.size()); + List indexStructures = new ArrayList<>(gs.size()); for(AColGroup g : gs) { if(g instanceof ADictBasedColGroup) { ADictBasedColGroup dg = (ADictBasedColGroup) g; @@ -68,24 +68,25 @@ public static SeparatedGroups split(List gs) { * @param gs groups to combine with dictionaries * @param d dictionaries to combine back into the groups. * @param blen The block size. - * @return A combined list of columngroups. + * @return A combined list of column groups. 
*/ public static List combine(List gs, Map> d, int blen) { int gid = 0; + int s = 0; + for(List e : d.values()) + s += e.size(); + + if(gs.size() != s) + throw new RuntimeException( + "Invalid combine of of groups and dictionaries groups:" + gs.size() + " vs dicts" + s); for(int i = 0; i < d.size(); i++) { List dd = d.get(i); for(int j = 0; j < dd.size(); j++) { IDictionary ddd = dd.get(j); - if(!(ddd instanceof PlaceHolderDict)) { - - AColGroup g = gs.get(gid); - while(!(g instanceof ADictBasedColGroup)) { - gid++; - g = gs.get(gid); - } + AColGroup g = gs.get(gid); + if(g instanceof ADictBasedColGroup) { ADictBasedColGroup dg = (ADictBasedColGroup) g; - gs.set(gid, dg.copyAndSet(ddd)); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSlice.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSlice.java index 5ba36923b9a..06198c3f2e0 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSlice.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSlice.java @@ -32,8 +32,6 @@ import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; -import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; -import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.CommonThreadPool; @@ -146,7 +144,6 @@ private static MatrixBlock sliceRowsDecompress(CompressedMatrixBlock cmb, int rl public static MatrixBlock sliceRowsCompressed(CompressedMatrixBlock cmb, int rl, int ru) { final List groups = cmb.getColGroups(); final List newColGroups = new ArrayList<>(groups.size()); - final List emptyGroups = new ArrayList<>(); final int rue = ru + 1; final CompressedMatrixBlock ret = new CompressedMatrixBlock(rue - rl, 
cmb.getNumColumns()); @@ -156,17 +153,12 @@ public static MatrixBlock sliceRowsCompressed(CompressedMatrixBlock cmb, int rl, if(slice != null) newColGroups.add(slice); else - emptyGroups.add(grp.getColIndices()); + newColGroups.add(new ColGroupEmpty(grp.getColIndices())); } if(newColGroups.size() == 0) return new MatrixBlock(rue - rl, cmb.getNumColumns(), 0.0); - if(!emptyGroups.isEmpty()) { - IColIndex empties = ColIndexFactory.combineIndexes(emptyGroups); - newColGroups.add(new ColGroupEmpty(empties)); - } - ret.allocateColGroupList(newColGroups); ret.setNonZeros(-1); ret.setOverlapping(cmb.isOverlapping()); diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java index e6756604ace..e88913391f9 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java @@ -31,7 +31,6 @@ import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; -import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -51,7 +50,7 @@ private CLALibStack() { * The intension is that the combining is able to resolve differences in the different MatrixBlocks allocation. * * @param m The map of Index to MatrixBLocks - * @param d A map of the dictionaries contained in the compressionscheme + * @param d A map of the dictionaries contained in the compression scheme * @param k The parallelization degree allowed for this operation * @return The combined matrix. 
*/ @@ -113,8 +112,7 @@ private static MatrixBlock combine(final Map m, Map< return combineColumnGroups(m, d, lookup, rlen, clen, blen, k); } catch(Exception e) { - // throw new RuntimeException("failed normal combine", e); - LOG.error("Failed to combine compressed blocks, fallback to decompression.", e); + LOG.warn("Failed to combine compressed blocks, fallback to decompression.", e); return combineViaDecompression(m, rlen, clen, blen, k); } } @@ -144,14 +142,8 @@ private static MatrixBlock combineColumnGroups(final Map gs = cmb.getColGroups(); nGroups += gs.size(); @@ -165,20 +157,18 @@ private static MatrixBlock combineColumnGroups(final Map gs = cmb.getColGroups(); + if(cgid + gs.size() > nGroups) + return combineViaDecompression(m, rlen, clen, blen, k); for(int i = 0; i < gs.size(); i++) { AColGroup g = gs.get(i); final AColGroup gc = bc > 0 ? g.shiftColIndices(bc * blen) : g; - finalCols[cgid][br] = gc; - if(br != 0 && (finalCols[cgid][0] == null || - !finalCols[cgid][br].getColIndices().equals(finalCols[cgid][0].getColIndices()))) { - LOG.warn("Combining via decompression. There was an column with different index"); - return combineViaDecompression(m, rlen, clen, blen, k); - } cgid++; - } } if(cgid != finalCols.length) { @@ -195,18 +185,20 @@ private static MatrixBlock combineColumnGroups(final Map { - return combineN(x); + AColGroup r = AColGroup.appendN(x, blen, rlen); + return r; }).collect(Collectors.toList()); }).get(); - if(d != null) { - finalGroups = CLALibSeparator.combine(finalGroups, d, blen); - } - if(finalGroups.contains(null)) { LOG.warn("Combining via decompression. 
There was a column group that did not append "); return combineViaDecompression(m, rlen, clen, blen, k); } + + if(d != null) { + finalGroups = CLALibSeparator.combine(finalGroups, d, blen); + } + return new CompressedMatrixBlock(rlen, clen, -1, false, finalGroups); } catch(InterruptedException | ExecutionException e) { @@ -216,14 +208,4 @@ private static MatrixBlock combineColumnGroups(final Map get(double key) { @Override public DCounts inc(Double key, int c, int id) { - // return inc((double) key, c, id); - throw new NotImplementedException(); + return inc((double) key, c, id); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/utils/ACountHashMap.java b/src/main/java/org/apache/sysds/runtime/compress/utils/ACountHashMap.java index 5182dfde48d..9e8e87c83b6 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/utils/ACountHashMap.java +++ b/src/main/java/org/apache/sysds/runtime/compress/utils/ACountHashMap.java @@ -25,254 +25,254 @@ import org.apache.sysds.runtime.compress.utils.ACount.DCounts; public abstract class ACountHashMap implements Cloneable { - protected static final Log LOG = LogFactory.getLog(ACountHashMap.class.getName()); - protected static final int RESIZE_FACTOR = 2; - protected static final float LOAD_FACTOR = 0.80f; - protected static final int shortCutSize = 10; - - protected int size; - protected ACount[] data; - - public ACountHashMap() { - data = create(1); - size = 0; - } - - public ACountHashMap(int arrSize) { - if(arrSize < shortCutSize) - data = create(1); - else { - arrSize = (int) (arrSize * (1.0 / LOAD_FACTOR)); - arrSize += arrSize % 2 == 0 ? 1 : 0; - data = create(arrSize); - } - size = 0; - } - - public int size() { - return size; - } - - /** - * Increment and return the id of the incremeted index. - * - * @param key The key to increment - * @return The id of the incremented entry. 
- */ - public final int increment(T key) { - return increment(key, 1); - } - - public final int increment(double key) { - return increment(key, 1); - } - - /** - * Increment and return the id of the incremented index. - * - * @param key The key to increment - * @param count The number of times to increment the value - * @return The Id of the incremented entry. - */ - public int increment(final T key, final int count) { - // skip hash if data array is 1 length - final int ix = data.length < shortCutSize ? 0 : hash(key) % data.length; - - try { - return increment(key, ix, count); - } - catch(ArrayIndexOutOfBoundsException e) { - if(ix < 0) - return increment(key, 0, count); - else - throw new RuntimeException(e); - } - } - - private final int increment(final T key, final int ix, final int count) throws ArrayIndexOutOfBoundsException { - final ACount l = data[ix]; - if(l == null) { - data[ix] = create(key, size); - // never try to resize here since we use a new unused bucket. - return size++; - } - else { - final ACount v = l.inc(key, count, size); - if(v.id == size) { - size++; - resize(); - return size - 1; - } - else { - // do not resize if not new. - return v.id; - } - } - } - - public final int increment(final double key, final int count) { - // skip hash if data array is 1 length - final int ix = data.length < shortCutSize ? 0 : DCounts.hashIndex(key) % data.length; - - try { - return increment(key, ix, count); - } - catch(ArrayIndexOutOfBoundsException e) { - if(ix < 0) - return increment(key, 0, count); - else - throw new RuntimeException(e); - } - } - - private final int increment(final double key, final int ix, final int count) throws ArrayIndexOutOfBoundsException { - final ACount l = data[ix]; - if(l == null) { - data[ix] = create(key, size); - // never try to resize here since we use a new unused bucket. 
- return size++; - } - else { - final ACount v = l.inc(key, count, size); - if(v.id == size) { - size++; - resize(); - return size - 1; - } - else { - // do not resize if not new. - return v.id; - } - } - } - - public int get(T key) { - return getC(key).count; - } - - public int getId(T key) { - return getC(key).id; - } - - public ACount getC(T key) { - final int ix = data.length < shortCutSize ? 0 : hash(key) % data.length; - try { - ACount l = data[ix]; - return l != null ? l.get(key) : null; - } - catch(ArrayIndexOutOfBoundsException e) { - if(ix < 0) { - ACount l = data[0]; - return l != null ? l.get(key) : null; - } - else - throw new RuntimeException(e); - } - } - - public int getOrDefault(T key, int def) { - ACount e = getC(key); - return (e == null) ? def : e.count; - } - - public final ACount[] extractValues() { - final ACount[] ret = create(size); - int i = 0; - for(ACount e : data) { - while(e != null) { - ret[i++] = e; - e = e.next(); - } - } - return ret; - } - - public T getMostFrequent() { - T f = null; - int fq = 0; - for(ACount e : data) { - while(e != null) { - if(e.count > fq) { - fq = e.count; - f = e.key(); - } - e = e.next(); - } - } - return f; - } - - private void resize() { - if(size >= LOAD_FACTOR * data.length && size > shortCutSize) - // +1 to make the hash buckets better - resize(Math.max(data.length, shortCutSize) * RESIZE_FACTOR + 1); - } - - private void resize(int underlying_size) { - - // resize data array and copy existing contents - final ACount[] olddata = data; - data = create(underlying_size); - - // rehash all entries - for(ACount e : olddata) - appendValue(e); - } - - protected void appendValue(ACount ent) { - if(ent != null) { - // take the tail recursively first - appendValue(ent.next()); // append tail first - ent.setNext(null); // set this tail to null. 
- final int ix = hash(ent.key()) % data.length; - try { - appendValue(ent, ix); - } - catch(ArrayIndexOutOfBoundsException e) { - if(ix < 0) - appendValue(ent, 0); - else - throw new RuntimeException(e); - } - } - } - - private void appendValue(ACount ent, int ix) { - ACount l = data[ix]; - data[ix] = ent; - ent.setNext(l); - } - - public void sortBuckets() { - if(size > 10) - for(int i = 0; i < data.length; i++) - if(data[i] != null) - data[i] = data[i].sort(); - } - - public void reset(int size) { - this.data = create(size); - this.size = 0; - } - - protected abstract ACount[] create(int size); - - protected abstract int hash(T key); - - protected abstract ACount create(T key, int id); - - protected ACount create(double key, int id) { - throw new NotImplementedException(); - } - - @Override - public String toString() { - StringBuilder sb = new StringBuilder(); - sb.append(this.getClass().getSimpleName()); - for(int i = 0; i < data.length; i++) - if(data[i] != null) - sb.append(", " + data[i]); - return sb.toString(); - } + protected static final Log LOG = LogFactory.getLog(ACountHashMap.class.getName()); + protected static final int RESIZE_FACTOR = 2; + protected static final float LOAD_FACTOR = 0.80f; + protected static final int shortCutSize = 10; + + protected int size; + protected ACount[] data; + + public ACountHashMap() { + data = create(1); + size = 0; + } + + public ACountHashMap(int arrSize) { + if(arrSize < shortCutSize) + data = create(1); + else { + arrSize = (int) (arrSize * (1.0 / LOAD_FACTOR)); + arrSize += arrSize % 2 == 0 ? 1 : 0; + data = create(arrSize); + } + size = 0; + } + + public int size() { + return size; + } + + /** + * Increment and return the id of the incremeted index. + * + * @param key The key to increment + * @return The id of the incremented entry. 
+ */ + public final int increment(T key) { + return increment(key, 1); + } + + public final int increment(double key) { + return increment(key, 1); + } + + /** + * Increment and return the id of the incremented index. + * + * @param key The key to increment + * @param count The number of times to increment the value + * @return The Id of the incremented entry. + */ + public synchronized int increment(final T key, final int count) { + // skip hash if data array is 1 length + final int ix = data.length < shortCutSize ? 0 : hash(key) % data.length; + + try { + return increment(key, ix, count); + } + catch(ArrayIndexOutOfBoundsException e) { + if(ix < 0) + return increment(key, 0, count); + else + throw new RuntimeException(e); + } + } + + private final int increment(final T key, final int ix, final int count) throws ArrayIndexOutOfBoundsException { + final ACount l = data[ix]; + if(l == null) { + data[ix] = create(key, size); + // never try to resize here since we use a new unused bucket. + return size++; + } + else { + final ACount v = l.inc(key, count, size); + if(v.id == size) { + size++; + resize(); + return size - 1; + } + else { + // do not resize if not new. + return v.id; + } + } + } + + public synchronized final int increment(final double key, final int count) { + // skip hash if data array is 1 length + final int ix = data.length < shortCutSize ? 0 : DCounts.hashIndex(key) % data.length; + + try { + return increment(key, ix, count); + } + catch(ArrayIndexOutOfBoundsException e) { + if(ix < 0) + return increment(key, 0, count); + else + throw new RuntimeException(e); + } + } + + private final int increment(final double key, final int ix, final int count) throws ArrayIndexOutOfBoundsException { + final ACount l = data[ix]; + if(l == null) { + data[ix] = create(key, size); + // never try to resize here since we use a new unused bucket. 
+ return size++; + } + else { + final ACount v = l.inc(key, count, size); + if(v.id == size) { + size++; + resize(); + return size - 1; + } + else { + // do not resize if not new. + return v.id; + } + } + } + + public int get(T key) { + return getC(key).count; + } + + public int getId(T key) { + return getC(key).id; + } + + public ACount getC(T key) { + final int ix = data.length < shortCutSize ? 0 : hash(key) % data.length; + try { + ACount l = data[ix]; + return l != null ? l.get(key) : null; + } + catch(ArrayIndexOutOfBoundsException e) { + if(ix < 0) { + ACount l = data[0]; + return l != null ? l.get(key) : null; + } + else + throw new RuntimeException(e); + } + } + + public int getOrDefault(T key, int def) { + ACount e = getC(key); + return (e == null) ? def : e.count; + } + + public final ACount[] extractValues() { + final ACount[] ret = create(size); + int i = 0; + for(ACount e : data) { + while(e != null) { + ret[i++] = e; + e = e.next(); + } + } + return ret; + } + + public T getMostFrequent() { + T f = null; + int fq = 0; + for(ACount e : data) { + while(e != null) { + if(e.count > fq) { + fq = e.count; + f = e.key(); + } + e = e.next(); + } + } + return f; + } + + private void resize() { + if(size >= LOAD_FACTOR * data.length && size > shortCutSize) + // +1 to make the hash buckets better + resize(Math.max(data.length, shortCutSize) * RESIZE_FACTOR + 1); + } + + private void resize(int underlying_size) { + + // resize data array and copy existing contents + final ACount[] olddata = data; + data = create(underlying_size); + + // rehash all entries + for(ACount e : olddata) + appendValue(e); + } + + protected void appendValue(ACount ent) { + if(ent != null) { + // take the tail recursively first + appendValue(ent.next()); // append tail first + ent.setNext(null); // set this tail to null. 
+ final int ix = hash(ent.key()) % data.length; + try { + appendValue(ent, ix); + } + catch(ArrayIndexOutOfBoundsException e) { + if(ix < 0) + appendValue(ent, 0); + else + throw new RuntimeException(e); + } + } + } + + private void appendValue(ACount ent, int ix) { + ACount l = data[ix]; + data[ix] = ent; + ent.setNext(l); + } + + public void sortBuckets() { + if(size > 10) + for(int i = 0; i < data.length; i++) + if(data[i] != null) + data[i] = data[i].sort(); + } + + public void reset(int size) { + this.data = create(size); + this.size = 0; + } + + protected abstract ACount[] create(int size); + + protected abstract int hash(T key); + + protected abstract ACount create(T key, int id); + + protected ACount create(double key, int id) { + throw new NotImplementedException(); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(this.getClass().getSimpleName()); + for(int i = 0; i < data.length; i++) + if(data[i] != null) + sb.append(", " + data[i]); + return sb.toString(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/utils/CompressRDDClean.java b/src/main/java/org/apache/sysds/runtime/compress/utils/CompressRDDClean.java index 0355722ff7b..5b2e9ab3826 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/utils/CompressRDDClean.java +++ b/src/main/java/org/apache/sysds/runtime/compress/utils/CompressRDDClean.java @@ -17,23 +17,21 @@ * under the License. 
*/ - package org.apache.sysds.runtime.compress.utils; - import org.apache.spark.api.java.function.Function; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; public class CompressRDDClean implements Function { - + private static final long serialVersionUID = -704403012606821854L; @Override public MatrixBlock call(MatrixBlock mb) throws Exception { - - if(mb instanceof CompressedMatrixBlock){ - CompressedMatrixBlock cmb = (CompressedMatrixBlock)mb; + + if(mb instanceof CompressedMatrixBlock) { + CompressedMatrixBlock cmb = (CompressedMatrixBlock) mb; cmb.clearSoftReferenceToDecompressed(); return cmb; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/utils/DblArrayCountHashMap.java b/src/main/java/org/apache/sysds/runtime/compress/utils/DblArrayCountHashMap.java index 059834e965a..cf8771d83aa 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/utils/DblArrayCountHashMap.java +++ b/src/main/java/org/apache/sysds/runtime/compress/utils/DblArrayCountHashMap.java @@ -35,7 +35,7 @@ protected final DArrCounts[] create(int size) { return new DArrCounts[size]; } - protected int hash(DblArray key) { + protected final int hash(DblArray key) { return Math.abs(key.hashCode()); } @@ -44,13 +44,12 @@ protected final DArrCounts create(DblArray key, int id) { } @Override - public DblArrayCountHashMap clone() { + public DblArrayCountHashMap clone() { DblArrayCountHashMap ret = new DblArrayCountHashMap(size); - - for(ACount e : data) - ret.appendValue(e); + for(ACount e : data) + ret.appendValue(e); ret.size = size; return ret; - } + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/utils/DblArrayIntListHashMap.java b/src/main/java/org/apache/sysds/runtime/compress/utils/DblArrayIntListHashMap.java index 55fbf90b464..b26e41f648c 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/utils/DblArrayIntListHashMap.java +++ 
b/src/main/java/org/apache/sysds/runtime/compress/utils/DblArrayIntListHashMap.java @@ -20,7 +20,6 @@ package org.apache.sysds.runtime.compress.utils; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import org.apache.commons.logging.Log; @@ -35,12 +34,10 @@ public class DblArrayIntListHashMap { protected static final int INIT_CAPACITY = 8; protected static final int RESIZE_FACTOR = 2; - protected static final float LOAD_FACTOR = 0.5f; - public static int hashMissCount = 0; + protected static final float LOAD_FACTOR = 0.8f; - protected int _size = -1; - - protected DArrayIListEntry[] _data = null; + protected int _size; + protected DArrayIListEntry[] _data; public DblArrayIntListHashMap() { _data = new DArrayIListEntry[INIT_CAPACITY]; @@ -57,64 +54,21 @@ public int size() { } public IntArrayList get(DblArray key) { - // probe for early abort - if(_size == 0) - return null; - // compute entry index position - int hash = key.hashCode(); - int ix = indexFor(hash, _data.length); - - // find entry - - while(_data[ix] != null && !_data[ix].keyEquals(key)) { - hash = Integer.hashCode(hash + 1); // hash of hash - ix = indexFor(hash, _data.length); - hashMissCount++; - } - DArrayIListEntry e = _data[ix]; - if(e != null) - return e.value; - return null; - } - - private void appendValue(DblArray key, IntArrayList value) { - // compute entry index position - int hash = key.hashCode(); - int ix = indexFor(hash, _data.length); - - // add new table entry (constant time) - while(_data[ix] != null && !_data[ix].keyEquals(key)) { - hash = Integer.hashCode(hash + 1); // hash of hash - ix = indexFor(hash, _data.length); - hashMissCount++; - } - _data[ix] = new DArrayIListEntry(key, value); - _size++; + final int hash = key.hashCode(); + final int ix = indexFor(hash, _data.length); + return _data[ix] == null ? 
null: _data[ix].get(key); } public void appendValue(DblArray key, int value) { int hash = key.hashCode(); int ix = indexFor(hash, _data.length); - - while(_data[ix] != null && !_data[ix].keyEquals(key)) { - hash = Integer.hashCode(hash + 1); // hash of hash - ix = indexFor(hash, _data.length); - hashMissCount++; - } - - DArrayIListEntry e = _data[ix]; - if(e == null) { - final IntArrayList lstPtr = new IntArrayList(); - lstPtr.appendValue(value); - _data[ix] = new DArrayIListEntry(new DblArray(key), lstPtr); + if(_data[ix] == null) { + _data[ix] = new DArrayIListEntry(new DblArray(key), value); _size++; } - else { - final IntArrayList lstPtr = e.value; - lstPtr.appendValue(value); - } + else if(_data[ix].add(key, value)) + _size++; - // resize if necessary if(_size >= LOAD_FACTOR * _data.length) resize(); } @@ -122,18 +76,17 @@ public void appendValue(DblArray key, int value) { public List extractValues() { List ret = new ArrayList<>(); - for(DArrayIListEntry e : _data) - if(e != null) + for(DArrayIListEntry e : _data) { + while(e != null) { ret.add(e); + e = e.next; + } + } - // Collections.sort(ret); return ret; } private void resize() { - // check for integer overflow on resize - if(_data.length > Integer.MAX_VALUE / RESIZE_FACTOR) - return; // resize data array and copy existing contents DArrayIListEntry[] olddata = _data; @@ -141,66 +94,112 @@ private void resize() { _size = 0; // rehash all entries - for(DArrayIListEntry e : olddata) - if(e != null) - appendValue(e.key, e.value); - } - - public void reset() { - Arrays.fill(_data, null); - _size = 0; + for(DArrayIListEntry e : olddata) { + while(e != null) { + reinsert(e.key, e.value); + e = e.next; + } + } } - public void reset(int size) { - int newSize = Util.getPow2(size); - if(newSize > _data.length) { - _data = new DArrayIListEntry[newSize]; + private void reinsert(DblArray key, IntArrayList value) { + // compute entry index position + int hash = key.hashCode(); + int ix = indexFor(hash, _data.length); + 
if(_data[ix] == null) { + _data[ix] = new DArrayIListEntry(key, value); + _size++; } else { - Arrays.fill(_data, null); - // only allocate new if the size is smaller than 2x - if(size < _data.length / 2) - _data = new DArrayIListEntry[newSize]; + _data[ix].reinsert(key, value); + _size++; } - _size = 0; } - protected static int indexFor(int h, int length) { + private static int indexFor(int h, int length) { return h & (length - 1); } public static class DArrayIListEntry { - public DblArray key; - public IntArrayList value; + public final DblArray key; + public final IntArrayList value; + private DArrayIListEntry next; + + private DArrayIListEntry(DblArray key, int value) { + this.key = key; + this.value = new IntArrayList(); + this.value.appendValue(value); + next = null; + } + + private DArrayIListEntry(DblArray key, IntArrayList value) { + this.key = key; + this.value = value; + next = null; + } + + private final boolean reinsert(final DblArray key, final IntArrayList value) { + DArrayIListEntry e = this; + while(e.next != null) + e = e.next; + + e.next = new DArrayIListEntry(key, value); + return true; + } - public DArrayIListEntry(DblArray ekey, IntArrayList evalue) { - key = ekey; - value = evalue; + private final boolean add(final DblArray key, final int value) { + DArrayIListEntry e = this; + if(e.key.equals(key)) { + this.value.appendValue(value); + return false; + } + while(e.next != null) { + e = e.next; + if(e.key.equals(key)) { + e.value.appendValue(value); + return false; + } + } + e.next = new DArrayIListEntry(new DblArray(key), new IntArrayList()); + e.next.value.appendValue(value); + return true; } - @Override - public String toString() { - return key + ":" + value; + private IntArrayList get(DblArray key) { + DArrayIListEntry e = this; + boolean eq = e.key.equals(key); + while(e.next != null && !eq) { + e = e.next; + eq = e.key.equals(key); + } + return eq ? 
e.value : null; } - public boolean keyEquals(DblArray keyThat) { - return key.equals(keyThat); + private void toString(StringBuilder sb) { + DArrayIListEntry e = this; + while(e != null) { + sb.append(e.key); + sb.append(":"); + sb.append(e.value); + if(e.next != null) + sb.append(" -> "); + e = e.next; + } } } @Override public String toString() { StringBuilder sb = new StringBuilder(); - sb.append(this.getClass().getSimpleName() + this.hashCode()); + sb.append(this.getClass().getSimpleName()); sb.append(" " + _size); for(int i = 0; i < _data.length; i++) { DArrayIListEntry ent = _data[i]; if(ent != null) { sb.append("\n"); - sb.append("id:" + i); sb.append("["); - sb.append(ent); + ent.toString(sb); sb.append("]"); } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleCountHashMap.java b/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleCountHashMap.java index ba85964bc17..664d8485cfa 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleCountHashMap.java +++ b/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleCountHashMap.java @@ -48,18 +48,38 @@ protected final DCounts create(Double key, int id) { } public double[] getDictionary() { + return getDictionary(size); + } + + + public double[] getDictionary(int size) { double[] ret = new double[size]; for(int i = 0; i < data.length; i++) { ACount e = data[i]; while(e != null) { - ret[e.id] = e.key(); + if(e.id >= 0) + ret[e.id] = e.key(); e = e.next(); } } - return ret; } + + public void replaceWithUIDs(double v) { + int i = 0; + for(ACount e : data) { + while(e != null) { + if(!e.key().equals(v)) + e.id = i++; + else + e.id = -1; + e = e.next(); + } + } + + } + public void replaceWithUIDsNoZero() { int i = 0; Double z = Double.valueOf(0.0); diff --git a/src/main/java/org/apache/sysds/runtime/compress/utils/IntArrayList.java b/src/main/java/org/apache/sysds/runtime/compress/utils/IntArrayList.java index b787c2ee1d6..1edd33a74bc 100644 --- 
a/src/main/java/org/apache/sysds/runtime/compress/utils/IntArrayList.java +++ b/src/main/java/org/apache/sysds/runtime/compress/utils/IntArrayList.java @@ -52,7 +52,7 @@ public int size() { public void appendValue(int value) { // allocate or resize array if necessary - if(_size + 1 >= _data.length) + if(_size + 1 > _data.length) resize(); // append value @@ -60,6 +60,14 @@ public void appendValue(int value) { _size++; } + public void appendValue(IntArrayList value) { + // allocate or resize array if necessary + if(_size + value._size >= _data.length) + _data = Arrays.copyOf(_data, _size + value._size); + System.arraycopy(value._data, 0, _data, _size, value._size); + _size = _size + value._size; + } + /** * Returns the underlying array of offsets. Note that this array might be physically larger than the actual length of * the offset lists. Use size() to obtain the actual length. @@ -75,20 +83,22 @@ public int get(int index) { } public int[] extractValues(boolean trim) { - int[] ret = extractValues(); - return (trim && _size < ret.length) ? 
Arrays.copyOfRange(ret, 0, _size) : ret; + if(trim ){ + if(_data.length == _size) + return _data; + return Arrays.copyOfRange(_data, 0, _size); + } + else + return _data; } private void resize() { - // check for integer overflow on resize - if(_data.length > Integer.MAX_VALUE / RESIZE_FACTOR) - throw new RuntimeException("IntArrayList resize leads to integer overflow: size=" + _size); // resize data array and copy existing contents _data = Arrays.copyOf(_data, _data.length * RESIZE_FACTOR); } - public void reset(){ + public void reset() { _size = 0; } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java index 14eff0df5da..687377345cc 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java @@ -1769,7 +1769,7 @@ public static long getStorageSpaceUsed() { * * @return spark cluster configuration */ - public static SparkClusterConfig getSparkClusterConfig() { + public synchronized static SparkClusterConfig getSparkClusterConfig() { //lazy creation of spark cluster config if( _sconf == null ) _sconf = new SparkClusterConfig(); @@ -1782,8 +1782,7 @@ public static SparkClusterConfig getSparkClusterConfig() { * @return broadcast memory budget */ public static double getBroadcastMemoryBudget() { - return getSparkClusterConfig() - .getBroadcastMemoryBudget(); + return getSparkClusterConfig().getBroadcastMemoryBudget(); } /** diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java index e1dbb538630..456fc9afc7b 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java @@ -24,7 +24,9 @@ import java.io.Externalizable; 
import java.io.IOException; import java.io.ObjectInput; +import java.io.ObjectInputStream; import java.io.ObjectOutput; +import java.io.ObjectOutputStream; import java.io.Serializable; import java.lang.ref.SoftReference; import java.lang.reflect.InvocationTargetException; @@ -69,6 +71,7 @@ import org.apache.sysds.runtime.instructions.cp.ScalarObject; import org.apache.sysds.runtime.io.IOUtilFunctions; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixBlockDataInput; import org.apache.sysds.runtime.matrix.data.Pair; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.meta.DataCharacteristics; @@ -76,6 +79,8 @@ import org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.runtime.util.DMVUtils; import org.apache.sysds.runtime.util.EMAUtils; +import org.apache.sysds.runtime.util.FastBufferedDataInputStream; +import org.apache.sysds.runtime.util.FastBufferedDataOutputStream; import org.apache.sysds.runtime.util.IndexRange; import org.apache.sysds.runtime.util.UtilFunctions; import org.apache.sysds.utils.MemoryEstimates; @@ -815,14 +820,30 @@ public void readFields(DataInput in) throws IOException { @Override public void writeExternal(ObjectOutput out) throws IOException { - // redirect serialization to writable impl - write(out); + + // if((out instanceof ObjectOutputStream)){ + // ObjectOutputStream oos = (ObjectOutputStream)out; + // FastBufferedDataOutputStream fos = new FastBufferedDataOutputStream(oos); + // write(fos); //note: cannot close fos as this would close oos + // fos.flush(); + // } + // else{ + write(out); + // } } @Override public void readExternal(ObjectInput in) throws IOException { - // redirect deserialization to writable impl - readFields(in); + // if(in instanceof ObjectInputStream) { + // // fast deserialize of dense/sparse blocks + // ObjectInputStream ois = (ObjectInputStream) in; + // FastBufferedDataInputStream fis = new 
FastBufferedDataInputStream(ois); + // readFields(fis); // note: cannot close fos as this would close oos + // } + // else { + // redirect deserialization to writable impl + readFields(in); + // } } @Override diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java index 68ef739e65c..874364255f3 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java @@ -108,7 +108,7 @@ public synchronized final Map getRecodeMap() { */ protected Map createRecodeMap() { Map map = new HashMap<>(); - long id = 0; + long id = 1; for(int i = 0; i < size(); i++) { T val = get(i); if(val != null) { diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java index ef66a046f74..fd86286972b 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java @@ -22,6 +22,7 @@ import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; +import java.nio.charset.Charset; import java.util.Arrays; import java.util.BitSet; import java.util.HashMap; @@ -101,7 +102,7 @@ public void set(int rl, int ru, Array value, int rlSrc) { catch(Exception e) { super.set(rl, ru, value, rlSrc); } - finally{ + finally { materializedSize = -1; } } @@ -146,7 +147,21 @@ public Array append(Array other) { public void write(DataOutput out) throws IOException { out.writeByte(FrameArrayType.STRING.ordinal()); out.writeLong(getInMemorySize()); + + // final Charset cs = Charset.defaultCharset(); for(int i = 0; i < _size; i++) + // { + // if(_data[i] == null){ + // out.writeInt(0); + // } + // else{ + // // cs.encode(_data[i]); + // byte[] bs = _data[i].getBytes(cs); + // out.writeInt(bs.length); + // 
out.write(bs); + // } + // } + out.writeUTF((_data[i] != null) ? _data[i] : ""); } @@ -154,9 +169,25 @@ public void write(DataOutput out) throws IOException { public void readFields(DataInput in) throws IOException { _size = _data.length; materializedSize = in.readLong(); + // byte[] bs = new byte[16]; + // final Charset cs = Charset.defaultCharset(); for(int i = 0; i < _size; i++) { - String tmp = in.readUTF(); - _data[i] = (!tmp.isEmpty()) ? tmp : null; + // int l = in.readInt(); + // if(l == 0){ + // _data[i] = null; + // } + // else{ + // if(l > bs.length) + // bs = new byte[l]; + // in.readFully(bs, 0, l); + // String tmp = new String(bs, 0, l, cs); + // // String tmp = in.readUTF(); + // _data[i] = tmp; + // } + { + String tmp = in.readUTF(); + _data[i] = tmp.isEmpty() ? null : tmp; + } } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/CompressionSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/CompressionSPInstruction.java index 4aca912444a..2b409c1396c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/CompressionSPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/CompressionSPInstruction.java @@ -33,14 +33,19 @@ import org.apache.sysds.runtime.compress.cost.CostEstimatorBuilder; import org.apache.sysds.runtime.compress.cost.CostEstimatorFactory.CostType; import org.apache.sysds.runtime.compress.workload.WTreeRoot; +import org.apache.sysds.runtime.controlprogram.caching.CacheableData; +import org.apache.sysds.runtime.controlprogram.caching.FrameObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.lib.FrameLibCompress; import 
org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.utils.Statistics; import scala.Tuple2; @@ -74,6 +79,24 @@ public static CompressionSPInstruction parseInstruction(String str) { public void processInstruction(ExecutionContext ec) { SparkExecutionContext sec = (SparkExecutionContext) ec; + CacheableData obj = sec.getCacheableData(input1.getName()); + + if(obj instanceof FrameObject) + processFrame(sec); + else + processMatrix(sec); + } + + private void processFrame(SparkExecutionContext sec) { + Statistics.decrementNoOfExecutedSPInst(); // + FrameBlock fb = sec.getFrameInput(input1.getName()); + sec.releaseFrameInput(input1.getName()); + FrameBlock compResult = FrameLibCompress.compress(fb, InfrastructureAnalyzer.getLocalParallelism()); + sec.setFrameOutput(output.getName(), compResult); + } + + private void processMatrix(SparkExecutionContext sec) { + // get input rdd handle JavaPairRDD in = sec.getBinaryMatrixBlockRDDHandleForVariable(input1.getName()); @@ -139,8 +162,8 @@ public CompressionWorkloadFunction(CostEstimatorBuilder costBuilder) { @Override public MatrixBlock call(MatrixBlock arg0) throws Exception { CompressionSettingsBuilder csb = new CompressionSettingsBuilder().setIsInSparkInstruction(); - return CompressedMatrixBlockFactory.compress(arg0, InfrastructureAnalyzer.getLocalParallelism(), csb, costBuilder) - .getLeft(); + return CompressedMatrixBlockFactory + .compress(arg0, InfrastructureAnalyzer.getLocalParallelism(), csb, costBuilder).getLeft(); } } diff --git a/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java b/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java index 3735f9d7dbb..c188928ae00 100644 --- 
a/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java +++ b/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java @@ -22,6 +22,7 @@ import java.io.BufferedReader; import java.io.ByteArrayOutputStream; import java.io.Closeable; +import java.io.File; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; @@ -49,6 +50,15 @@ import org.apache.hadoop.io.SequenceFile; import org.apache.hadoop.io.SequenceFile.Writer; import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.compress.BZip2Codec; +import org.apache.hadoop.io.compress.CompressionCodec; +import org.apache.hadoop.io.compress.DefaultCodec; +import org.apache.hadoop.io.compress.DeflateCodec; +import org.apache.hadoop.io.compress.GzipCodec; +import org.apache.hadoop.io.compress.Lz4Codec; +import org.apache.hadoop.io.compress.PassthroughCodec; +import org.apache.hadoop.io.compress.SnappyCodec; +import org.apache.hadoop.io.compress.ZStandardCodec; import org.apache.hadoop.mapred.FileSplit; import org.apache.hadoop.mapred.InputFormat; import org.apache.hadoop.mapred.InputSplit; @@ -57,6 +67,7 @@ import org.apache.hadoop.mapred.Reporter; import org.apache.hadoop.mapred.TextInputFormat; import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.data.TensorBlock; import org.apache.sysds.runtime.data.TensorIndexes; @@ -66,11 +77,12 @@ import org.apache.sysds.runtime.matrix.data.MatrixIndexes; import org.apache.sysds.runtime.transform.TfUtils; import org.apache.sysds.runtime.util.LocalFileUtils; -import org.apache.sysds.runtime.util.UtilFunctions; -public class IOUtilFunctions -{ - private static final Log LOG = LogFactory.getLog(UtilFunctions.class.getName()); +import io.airlift.compress.lzo.LzoCodec; +import io.airlift.compress.lzo.LzopCodec; + +public class IOUtilFunctions { + private static final Log LOG = 
LogFactory.getLog(IOUtilFunctions.class.getName()); public static final PathFilter hiddenFileFilter = new PathFilter(){ @Override @@ -188,7 +200,6 @@ public static String[] split(String str, String delim) { //split by whole separator required for multi-character delimiters, preserve //all tokens required for empty cells and in order to keep cell alignment - return StringUtils.splitByWholeSeparatorPreserveAllTokens(str, delim); } @@ -252,13 +263,22 @@ else if(empty) { private static String[] splitCSVNonNullWithCache(final String str, final String delim, final String[] cache) { final int len = str.length(); final int delimLen = delim.length(); - + final boolean containsQuotationMarks = str.contains("\""); int from = 0; int id = 0; - while(from < len) { // for all tokens - final int to = getTo(str, from, delim); - cache[id++] =str.substring(from, to); - from = to + delimLen; + if(containsQuotationMarks){ + while(from < len) { // for all tokens + final int to = getTo(str, from, delim); + cache[id++] = str.substring(from, to); + from = to + delimLen; + } + } + else{ + while(from < len) { // for all tokens + final int to = getToNoQuote(str, from, delim); + cache[id++] = str.substring(from, to); + from = to + delimLen; + } } if(from == len) @@ -302,6 +322,21 @@ else if(isEmptyMatch(str, from, delim, dLen, len)) return to >= 0 ? to : len; } + private static int getToNoQuote(final String str, final int from, final String delim) { + final int len = str.length(); + final int dLen = delim.length(); + final int fromP1 = from + 1; + int to; + + if(isEmptyMatch(str, from, delim, dLen, len)) + return to = from; // empty string + else // default: unquoted non-empty + to = str.indexOf(delim, fromP1); + + // slice out token and advance position + return to >= 0 ? 
to : len; + } + public static String trim(String str) { return str.trim(); } @@ -561,32 +596,43 @@ public static int countNumColumnsCSV(InputSplit[] splits, InputFormat informat, return ncol; } - public static Path[] getSequenceFilePaths( FileSystem fs, Path file ) - throws IOException - { + public static Path[] getSequenceFilePaths(FileSystem fs, Path file) throws IOException { Path[] ret = null; - //Note on object stores: Since the object store file system implementations - //only emulate a file system, the directory of a multi-part file does not - //exist physically and hence the isDirectory call returns false. Furthermore, - //listStatus call returns all files with the given directory as prefix, which - //includes the mtd file which needs to be ignored accordingly. - - if( fs.getFileStatus(file).isDirectory() - || IOUtilFunctions.isObjectStoreFileScheme(file) ) - { + // Note on object stores: Since the object store file system implementations + // only emulate a file system, the directory of a multi-part file does not + // exist physically and hence the isDirectory call returns false. Furthermore, + // listStatus call returns all files with the given directory as prefix, which + // includes the mtd file which needs to be ignored accordingly. 
+ + if(fs instanceof LocalFileSystem) { + java.io.File f = new java.io.File(file.toString()); + if(f.isDirectory()){ + java.io.File[] r = new java.io.File(file.toString()).listFiles((a) -> { + final String s = a.getName(); + return !(s.startsWith("_") || (s.endsWith(".crc")) || s.endsWith(".mtd")); + }); + ret = new Path[r.length]; + for(int i = 0; i < r.length; i++) + ret[i] = new Path(r[i].toString()); + } + else{ + return new Path[]{file}; + } + } + else if(fs.getFileStatus(file).isDirectory() || IOUtilFunctions.isObjectStoreFileScheme(file)) { LinkedList tmp = new LinkedList<>(); FileStatus[] dStatus = fs.listStatus(file); - for( FileStatus fdStatus : dStatus ) - if( !fdStatus.getPath().getName().startsWith("_") //skip internal files - && !fdStatus.getPath().toString().equals(file.toString()+".mtd") ) //mtd file + for(FileStatus fdStatus : dStatus) + if(!fdStatus.getPath().getName().startsWith("_") // skip internal files + && !fdStatus.getPath().toString().equals(file.toString() + ".mtd")) // mtd file tmp.add(fdStatus.getPath()); ret = tmp.toArray(new Path[0]); } else { - ret = new Path[]{ file }; + ret = new Path[] {file}; } - + return ret; } @@ -703,6 +749,27 @@ public static T get(Future in) { throw new DMLRuntimeException(e); } } + + public static boolean isFileCPReadable(String path){ + try{ + + JobConf job = ConfigurationManager.getCachedJobConf(); + Path p = new Path(path); + FileSystem fs = getFileSystem(p,job); + if(fs instanceof LocalFileSystem){ + // fast java path. + File f = new File(path); + return ! 
f.isDirectory() && f.length() < Integer.MAX_VALUE; + } + else{ + FileStatus fstat = fs.getFileStatus(p); + return !fstat.isDirectory() && fstat.getLen() < Integer.MAX_VALUE; + } + } + catch(Exception e){ + return false; + } + } public static class CountRowsTask implements Callable { private final InputSplit _split; @@ -745,28 +812,69 @@ public Long call() throws Exception { public static Writer getSeqWriter(Path path, Configuration job, int replication) throws IOException { return SequenceFile.createWriter(job, Writer.file(path), Writer.bufferSize(4096), Writer.replication((short) (replication > 0 ? replication : 1)), - Writer.compression(SequenceFile.CompressionType.NONE), Writer.keyClass(MatrixIndexes.class), + Writer.compression(getCompressionEncodingType(), getCompressionCodec()), Writer.keyClass(MatrixIndexes.class), Writer.valueClass(MatrixBlock.class)); } public static Writer getSeqWriterFrame(Path path, Configuration job, int replication) throws IOException { return SequenceFile.createWriter(job, Writer.file(path), Writer.bufferSize(4096), Writer.keyClass(LongWritable.class), Writer.valueClass(FrameBlock.class), - Writer.compression(SequenceFile.CompressionType.NONE), + Writer.compression(getCompressionEncodingType(), getCompressionCodec()), Writer.replication((short) (replication > 0 ? replication : 1))); } public static Writer getSeqWriterTensor(Path path, Configuration job, int replication) throws IOException { return SequenceFile.createWriter(job, Writer.file(path), Writer.bufferSize(4096), Writer.replication((short) (replication > 0 ? 
replication : 1)), - Writer.compression(SequenceFile.CompressionType.NONE), Writer.keyClass(TensorIndexes.class), + Writer.compression(getCompressionEncodingType(),getCompressionCodec()), Writer.keyClass(TensorIndexes.class), Writer.valueClass(TensorBlock.class)); } public static Writer getSeqWriterCell(Path path, Configuration job, int replication) throws IOException { return SequenceFile.createWriter(job, Writer.file(path), Writer.bufferSize(4096), Writer.replication((short) (replication > 0 ? replication : 1)), - Writer.compression(SequenceFile.CompressionType.NONE), Writer.keyClass(MatrixIndexes.class), + Writer.compression(getCompressionEncodingType(), getCompressionCodec()), + Writer.keyClass(MatrixIndexes.class), Writer.valueClass(MatrixCell.class)); } + + public static SequenceFile.CompressionType getCompressionEncodingType() { + String v = ConfigurationManager.getDMLConfig().getTextValue(DMLConfig.IO_COMPRESSION_CODEC); + if(v.equals("none")) + return SequenceFile.CompressionType.NONE; + else + return SequenceFile.CompressionType.RECORD; + } + + public static CompressionCodec getCompressionCodec() { + String v = ConfigurationManager.getDMLConfig().getTextValue(DMLConfig.IO_COMPRESSION_CODEC); + + switch(v) { + case "Lz4": + return new Lz4Codec(); + case "Lzo": + return new LzoCodec(); + case "Lzop": + return new LzopCodec(); + case "Snappy": + return new SnappyCodec(); + case "BZip2": + return new BZip2Codec(); + case "deflate": + return new DeflateCodec(); + case "Gzip": + return new GzipCodec(); + case "pass": + return new PassthroughCodec(); + case "Zstd": + return new ZStandardCodec(); + case "none": + return null; + case "default": + default: + return new DefaultCodec(); + } + + } + } diff --git a/src/main/java/org/apache/sysds/runtime/io/ReaderBinaryBlockParallel.java b/src/main/java/org/apache/sysds/runtime/io/ReaderBinaryBlockParallel.java index 55127ef2c44..cabcad1b8c9 100644 --- 
a/src/main/java/org/apache/sysds/runtime/io/ReaderBinaryBlockParallel.java +++ b/src/main/java/org/apache/sysds/runtime/io/ReaderBinaryBlockParallel.java @@ -93,10 +93,10 @@ private static void readBinaryBlockMatrixFromHDFS( Path path, JobConf job, FileS if( HDFSTool.USE_BINARYBLOCK_SERIALIZATION ) HDFSTool.addBinaryBlockSerializationFramework( job ); + final ExecutorService pool = CommonThreadPool.get(_numThreads); try { //create read tasks for all files - ExecutorService pool = CommonThreadPool.get(_numThreads); ArrayList tasks = new ArrayList<>(); for( Path lpath : IOUtilFunctions.getSequenceFilePaths(fs, path) ){ ReadFileTask t = new ReadFileTask(lpath, job, dest, rlen, clen, blen, syncBlock); @@ -116,11 +116,13 @@ private static void readBinaryBlockMatrixFromHDFS( Path path, JobConf job, FileS if( dest.isInSparseFormat() && clen>blen ) sortSparseRowsParallel(dest, rlen, _numThreads, pool); - pool.shutdown(); } catch (Exception e) { throw new IOException("Failed parallel read of binary block input.", e); } + finally{ + pool.shutdown(); + } } private static class ReadFileTask implements Callable diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java index 2952c2b4c85..b5249b28b03 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java @@ -245,36 +245,31 @@ public static MatrixBlock transpose(MatrixBlock in, MatrixBlock out, int k, bool // CSR is only allowed in the transposed output if the number of non zeros is counted in the columns allowCSR = allowCSR && (in.clen <= 4096 || out.nonZeros < 10000000); - if(out.sparse && allowCSR) { - int size = (int) out.nonZeros; - out.sparseBlock = new SparseBlockCSR(in.getNumColumns(), size, size); - } - else if(out.sparse) - out.allocateSparseRowsBlock(false); - else - out.allocateDenseBlock(false); - - // core 
multi-threaded transpose + int[] cnt = null; try { final ExecutorService pool = CommonThreadPool.get(k); - // pre-processing (compute nnz per column once for sparse) - int[] cnt = null; - // filter matrices with many columns since the CountNnzTask would return - // null if the number of columns is larger than threshold - if(allowCSR) { + if(out.sparse && allowCSR) { + final int size = (int) out.nonZeros; + final Future f = countNNZColumns(in, k, pool); + out.sparseBlock = new SparseBlockCSR(in.getNumColumns(), size, size); + final int[] outPtr = ((SparseBlockCSR) out.sparseBlock).rowPointers(); - cnt = countNNZColumns(in, k, pool); - - if(allowCSR) { - int[] outPtr = ((SparseBlockCSR) out.sparseBlock).rowPointers(); - for(int i = 0; i < cnt.length; i++) { - // set out pointers to correct start of rows. - outPtr[i + 1] = outPtr[i] + cnt[i]; - // set the cnt value to the new pointer to start of row in CSR - cnt[i] = outPtr[i]; - } + // pre-processing (compute nnz per column once for sparse) + // filter matrices with many columns since the CountNnzTask would return + // null if the number of columns is larger than threshold + cnt = f.get(); + for(int i = 0; i < cnt.length; i++) { + // set out pointers to correct start of rows. 
+ outPtr[i + 1] = outPtr[i] + cnt[i]; + // set the cnt value to the new pointer to start of row in CSR + cnt[i] = outPtr[i]; } } + else if(out.sparse) + out.allocateSparseRowsBlock(false); + else + out.allocateDenseBlock(false); + // compute actual transpose and check for errors ArrayList tasks = new ArrayList<>(); @@ -299,13 +294,15 @@ else if(out.sparse) return out; } - public static int[] countNNZColumns(MatrixBlock in, int k, ExecutorService pool) + public static Future countNNZColumns(MatrixBlock in, int k, ExecutorService pool) throws InterruptedException, ExecutionException { - int[] cnt = null; - List> rtasks = countNNZColumnsFuture(in, k, pool); - for(Future rtask : rtasks) - cnt = mergeNnzCounts(cnt, rtask.get()); - return cnt; + final List> rtasks = countNNZColumnsFuture(in, k, pool); + return pool.submit(() -> { + int[] cnt = null; + for(Future rtask : rtasks) + cnt = mergeNnzCounts(cnt, rtask.get()); + return cnt; + }); } public static List> countNNZColumnsFuture(MatrixBlock in, int k, ExecutorService pool) throws InterruptedException { @@ -1140,10 +1137,85 @@ private static void transposeSparseToSparse(MatrixBlock in, MatrixBlock out, int private static void transposeSparseToSparseCSR(MatrixBlock in, MatrixBlock out, int rl, int ru, int cl, int cu, int[] cnt) { - // NOTE: called only in sequential or column-wise parallel execution if(rl > 0 || ru < in.rlen) throw new RuntimeException("Unsupported row-parallel transposeSparseToSparse: " + rl + ", " + ru); + if(cu - cl == 1) + transposeSparseToSparseCSRSingleCol(in, out, rl, ru, cl, cu, cnt); + else if(in.getSparseBlock() instanceof SparseBlockCSR) + transposeSparseCSRToSparseCSRMultiCol(in, out, cl, cu, cnt); + else + transposeSparseToSparseCSRMultiCol(in, out, cl, cu, cnt); + } + + private final static void transposeSparseCSRToSparseCSRMultiCol(final MatrixBlock in, final MatrixBlock out, + final int cl, final int cu, final int[] cnt) { + final int rlen = in.rlen; + + final SparseBlockCSR a = 
(SparseBlockCSR) in.getSparseBlock(); + final SparseBlockCSR c = (SparseBlockCSR) out.getSparseBlock(); + + final long xsp = (long) rlen * in.clen / in.nonZeros; + final int blocksizeI = Math.min(Math.max(128, (int) (8 * xsp)), 512); + + // temporary array for block boundaries (for preventing binary search) + final int[] ix = new int[Math.min(blocksizeI, rlen)]; + + // blocked execution + for(int bi = 0; bi < rlen; bi += blocksizeI) + transposeSparseCSRToSparseCSRMultiColBlock(bi, blocksizeI, rlen, cl, cu, ix, a, cnt, c); + + } + + private final static void transposeSparseCSRToSparseCSRMultiColBlock(int bi, int blocksizeI, int rlen, int cl, + int cu, int[] ix, SparseBlockCSR a, final int[] cnt, SparseBlockCSR c) { + + final int[] aix = a.indexes(); + final double[] avals = a.values(); + final int[] outIndexes = c.indexes(); + final double[] outValues = c.values(); + + // find column starting positions + final int bimin = Math.min(bi + blocksizeI, rlen); + if(cl > 0) + fillSkip(bi, bimin, a, cl, ix); + else + Arrays.fill(ix, 0); + + for(int bj = cl; bj < cu; bj += blocksizeI) + transposeSparseCSRToSparseCSRMultiColBlockBlock(bi, bj, bimin, cu, blocksizeI, a, ix, aix, avals, outIndexes, + outValues, cnt); + } + + private final static void transposeSparseCSRToSparseCSRMultiColBlockBlock(int bi, int bj, int bimin, int cu, + int blocksizeI, SparseBlockCSR a, int[] ix, int[] aix, double[] avals, int[] outIndexes, double[] outValues, int[] cnt) { + final int bjmin = Math.min(bj + blocksizeI, cu); + // core block transpose operation + for(int i = bi; i < bimin; i++) { + final int apos = a.pos(i); + final int alen = a.size(i); + int j = ix[i - bi] + apos; // last block boundary + for(; j < apos + alen && aix[j] < bjmin; j++) { + int pointer = cnt[aix[j]]; + cnt[aix[j]]++; + outIndexes[pointer] = i; + outValues[pointer] = avals[j]; + } + ix[i - bi] = j - apos; // keep block boundary + } + } + + private final static void fillSkip(int bi, int bimin, SparseBlockCSR a, int cl, 
int[] ix){ + // fill the skip boundaries. + for(int i = bi; i < bimin; i++) { + int j = a.posFIndexGTE(i, cl); + ix[i - bi] = (j >= 0) ? j : a.size(i); + } + } + + private final static void transposeSparseToSparseCSRMultiCol(final MatrixBlock in, final MatrixBlock out, + final int cl, final int cu, final int[] cnt) { + final int rlen = in.rlen; final SparseBlock a = in.getSparseBlock(); final SparseBlockCSR c = (SparseBlockCSR) out.getSparseBlock(); @@ -1151,70 +1223,76 @@ private static void transposeSparseToSparseCSR(MatrixBlock in, MatrixBlock out, final int[] outIndexes = c.indexes(); final double[] outValues = c.values(); - if(cu - cl == 1) { - int i = 0; - final int end = c.size(cl) + c.pos(cl); - int outPointer = cnt[cl]; - while(outPointer < end) { - if(!a.isEmpty(i)) { + final long xsp = (long) rlen * in.clen / in.nonZeros; + final int blocksizeI = Math.min(Math.max(128, (int) (8 * xsp)), 512); + + // temporary array for block boundaries (for preventing binary search) + final int[] ix = new int[Math.min(blocksizeI, rlen)]; + + // blocked execution + for(int bi = 0; bi < rlen; bi += blocksizeI) { + Arrays.fill(ix, 0); + // find column starting positions + int bimin = Math.min(bi + blocksizeI, rlen); + if(cl > 0) { + for(int i = bi; i < bimin; i++) { + if(a.isEmpty(i)) + continue; + int j = a.posFIndexGTE(i, cl); + ix[i - bi] = (j >= 0) ? 
j : a.size(i); + } + } + + for(int bj = cl; bj < cu; bj += blocksizeI) { + final int bjmin = Math.min(bj + blocksizeI, cu); + // core block transpose operation + for(int i = bi; i < bimin; i++) { + if(a.isEmpty(i)) + continue; final int apos = a.pos(i); final int alen = a.size(i); final int[] aix = a.indexes(i); final double[] avals = a.values(i); - for(int j = apos; j < apos + alen && aix[j] <= cl; j++) - if(aix[j] == cl) { - outIndexes[outPointer] = i; - outValues[outPointer] = avals[j]; - outPointer++; - } + int j = ix[i - bi] + apos; // last block boundary + for(; j < apos + alen && aix[j] < bjmin; j++) { + int pointer = cnt[aix[j]]; + cnt[aix[j]]++; + outIndexes[pointer] = i; + outValues[pointer] = avals[j]; + } + ix[i - bi] = j - apos; // keep block boundary } - i++; } } - else { - final long xsp = (long) in.rlen * in.clen / in.nonZeros; - final int blocksizeI = Math.max(128, (int) (8 * xsp)); - final int blocksizeJ = Math.max(128, (int) (8 * xsp)); + } - // temporary array for block boundaries (for preventing binary search) - int[] ix = new int[Math.min(blocksizeI, ru - rl)]; + private static void transposeSparseToSparseCSRSingleCol(MatrixBlock in, MatrixBlock out, int rl, int ru, int cl, + int cu, int[] cnt) { + final SparseBlock a = in.getSparseBlock(); + final SparseBlockCSR c = (SparseBlockCSR) out.getSparseBlock(); - // blocked execution - for(int bi = rl; bi < ru; bi += blocksizeI) { - Arrays.fill(ix, 0); - // find column starting positions - int bimin = Math.min(bi + blocksizeI, ru); - if(cl > 0) { - for(int i = bi; i < bimin; i++) { - if(a.isEmpty(i)) - continue; - int j = a.posFIndexGTE(i, cl); - ix[i - bi] = (j >= 0) ? 
j : a.size(i); - } - } + final int[] outIndexes = c.indexes(); + final double[] outValues = c.values(); - for(int bj = cl; bj < cu; bj += blocksizeJ) { - int bjmin = Math.min(bj + blocksizeJ, cu); - // core block transpose operation - for(int i = bi; i < bimin; i++) { - if(a.isEmpty(i)) - continue; - int apos = a.pos(i); - int alen = a.size(i); - int[] aix = a.indexes(i); - double[] avals = a.values(i); - int j = ix[i - bi] + apos; // last block boundary - for(; j < apos + alen && aix[j] < bjmin; j++) { - int pointer = cnt[aix[j]]; - cnt[aix[j]]++; - outIndexes[pointer] = i; - outValues[pointer] = avals[j]; - } - ix[i - bi] = j - apos; // keep block boundary + int i = 0; + final int end = c.size(cl) + c.pos(cl); + int outPointer = cnt[cl]; + while(outPointer < end) { + if(!a.isEmpty(i)) { + final int apos = a.pos(i); + final int alen = a.size(i); + final int[] aix = a.indexes(i); + final double[] avals = a.values(i); + for(int j = apos; j < apos + alen && aix[j] <= cl; j++) + if(aix[j] == cl) { + outIndexes[outPointer] = i; + outValues[outPointer] = avals[j]; + outPointer++; } - } } + i++; } + } private static void transposeSparseToDense(MatrixBlock in, MatrixBlock out, int rl, int ru, int cl, int cu) { diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSparseToDense.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSparseToDense.java new file mode 100644 index 00000000000..41b1048ad97 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSparseToDense.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.matrix.data; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; + +import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; +import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.data.SparseBlockCSR; +import org.apache.sysds.runtime.data.SparseBlockMCSR; +import org.apache.sysds.runtime.data.SparseRowVector; +import org.apache.sysds.runtime.util.CommonThreadPool; +import org.apache.sysds.runtime.util.UtilFunctions; + +public interface LibMatrixSparseToDense { + /** + * Convert the given matrix block to a sparse allocation. + * + * @param r The matrix block to modify, and return the sparse block in. + * @param allowCSR If CSR is allowed. 
+ */ + public static void denseToSparse(MatrixBlock r, boolean allowCSR) { + final DenseBlock a = r.getDenseBlock(); + + // set target representation, early abort on empty blocks + r.sparse = true; + if(a == null) + return; + + int k = InfrastructureAnalyzer.getLocalParallelism(); + + if(k > 1) + denseToSparseParallel(r, k, allowCSR); + else if(allowCSR && r.nonZeros <= Integer.MAX_VALUE) + denseToSparseCSR(r); + else + denseToSparseMCSR(r); + + // cleanup dense block + r.denseBlock = null; + } + + private static void denseToSparseCSR(MatrixBlock r) { + final DenseBlock a = r.getDenseBlock(); + final int m = r.rlen; + final int n = r.clen; + try { + // allocate target in memory-efficient CSR format + int lnnz = (int) r.nonZeros; + int[] rptr = new int[m + 1]; + int[] indexes = new int[lnnz]; + double[] values = new double[lnnz]; + for(int i = 0, pos = 0; i < m; i++) { + double[] avals = a.values(i); + int aix = a.pos(i); + for(int j = 0; j < n; j++) { + double aval = avals[aix + j]; + if(aval != 0) { + indexes[pos] = j; + values[pos] = aval; + pos++; + } + } + rptr[i + 1] = pos; + } + r.sparseBlock = new SparseBlockCSR(rptr, indexes, values, lnnz); + } + catch(ArrayIndexOutOfBoundsException ioobe) { + r.sparse = false; + // this means something was wrong with the sparse count. + final long nnzBefore = r.nonZeros; + final long nnzNew = r.recomputeNonZeros(); + + // try again. + if(nnzBefore != nnzNew) + denseToSparse(r, true); + else + denseToSparse(r, false); + + } + } + + private static void denseToSparseMCSR(MatrixBlock r) { + final DenseBlock a = r.getDenseBlock(); + + final int m = r.rlen; + final int n = r.clen; + // remember number non zeros. 
+ long nnzTemp = r.getNonZeros(); + // fallback to less-memory efficient MCSR format, + // which however allows much larger sparse matrices + if(!r.allocateSparseRowsBlock()) + r.reset(); // reset if not allocated + SparseBlockMCSR sblock = (SparseBlockMCSR) r.sparseBlock; + toSparseMCSRRange(a, sblock, n, 0, m); + r.nonZeros = nnzTemp; + } + + private static void toSparseMCSRRange(DenseBlock a, SparseBlockMCSR b, int n, int rl, int ru) { + for(int i = rl; i < ru; i++) + toSparseMCSRRow(a, b, n, i); + } + + private static void toSparseMCSRRow(DenseBlock a, SparseBlockMCSR b, int n, int i) { + final double[] avals = a.values(i); + final int aix = a.pos(i); + // compute nnz per row (not via recomputeNonZeros as sparse allocated) + final int lnnz = UtilFunctions.computeNnz(avals, aix, n); + if(lnnz <= 0) + return; + + final double[] vals = new double[lnnz]; + final int[] idx = new int[lnnz]; + // allocate sparse row and append non-zero values + // b.allocate(i, lnnz); + + for(int j = 0, o = 0; j < n; j++) { + double v = avals[aix + j]; + if(v != 0.0) { + vals[o] = v; + idx[o] = j; + o++; + } + } + b.set(i, new SparseRowVector(vals, idx), false); + } + + private static void denseToSparseParallel(MatrixBlock r, int k, boolean allowCSR) { + final DenseBlock a = r.getDenseBlock(); + + final int m = r.rlen; + final int n = r.clen; + // remember number non zeros. 
+ final long nnzTemp = r.getNonZeros(); + // fallback to less-memory efficient MCSR format, + // which however allows much larger sparse matrices + if(!r.allocateSparseRowsBlock()) + r.reset(); // reset if not allocated + final SparseBlockMCSR b = (SparseBlockMCSR) r.sparseBlock; + final int blockSize = Math.max(250, m / k); + + ExecutorService pool = CommonThreadPool.get(k); + try { + + List> tasks = new ArrayList<>(); + for(int i = 0; i < m; i += blockSize) { + final int start = i; + final int end = Math.min(m, i + blockSize); + tasks.add(pool.submit(() -> toSparseMCSRRange(a, b, n, start, end))); + } + + for(Future t : tasks) + t.get(); + } + catch(Exception e) { + throw new RuntimeException(e); + } + finally { + pool.shutdown(); + } + + r.nonZeros = nnzTemp; + + } +} diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java index f7c7e96385d..724af1be630 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java @@ -24,6 +24,8 @@ import java.io.ObjectInput; import java.io.ObjectOutput; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; @@ -34,8 +36,8 @@ * interface for decoding matrices to frames. * */ -public abstract class Decoder implements Externalizable -{ +public abstract class Decoder implements Externalizable{ + protected static final Log LOG = LogFactory.getLog(Decoder.class.getName()); private static final long serialVersionUID = -1732411001366177787L; protected ValueType[] _schema; @@ -61,12 +63,33 @@ public String[] getColnames() { /** * Block decode API converting a matrix block into a frame block. 
* - * @param in input matrix block - * @param out output frame block - * + * @param in Input matrix block + * @param out Output frame block * @return returns given output frame block for convenience */ public abstract FrameBlock decode(MatrixBlock in, FrameBlock out); + + /** + * Block decode API converting a matrix block into a frame block in parallel. + * + * @param in Input matrix block + * @param out Output frame block + * @param k Parallelization degree + * @return returns the given output frame block for convenience + */ + public FrameBlock decode(MatrixBlock in, FrameBlock out, int k) { + return decode(in, out); + } + + /** + * Block decode row block + * + * @param in input Matrix Block + * @param out output FrameBlock + * @param rl row start to decode + * @param ru row end to decode (not inclusive) + */ + public abstract void decode(MatrixBlock in, FrameBlock out, int rl, int ru); /** * Returns a new Decoder that only handles a sub range of columns. The sub-range refers to the columns after diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java index 5eb43e8bde4..4565d2f836d 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java @@ -27,6 +27,7 @@ import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.Array; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.UtilFunctions; @@ -36,8 +37,8 @@ * constant time for incoming values and * */ -public class DecoderBin extends Decoder -{ +public class DecoderBin extends Decoder { + private static final long serialVersionUID = -3784249774608228805L; // a) column bin boundaries @@ -56,18 +57,28 @@ protected 
DecoderBin(ValueType[] schema, int[] binCols) { @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { out.ensureAllocatedColumns(in.getNumRows()); - for( int i=0; i a = out.getColumn(_colList[j] - 1); + final double val = in.quickGetValue(i, _colList[j] - 1); + if(!Double.isNaN(val)){ + final int key = (int) Math.round(val); + double bmin = _binMins[j][key - 1]; + double bmax = _binMaxs[j][key - 1]; + double oval = bmin + (bmax - bmin) / 2 // bin center + + (val - key) * (bmax - bmin); // bin fractions + a.set(i, oval); + } + else + a.set(i, val); // NaN } } - return out; } @Override diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java index 52367b34613..f4bc9f8b216 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java @@ -25,10 +25,13 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.CommonThreadPool; /** * Simple composite decoder that applies a list of decoders @@ -56,6 +59,40 @@ public FrameBlock decode(MatrixBlock in, FrameBlock out) { return out; } + + @Override + public FrameBlock decode(final MatrixBlock in, final FrameBlock out, final int k) { + final ExecutorService pool = CommonThreadPool.get(k); + out.ensureAllocatedColumns(in.getNumRows()); + try { + final List> tasks = new ArrayList<>(); + int blz = Math.max(in.getNumRows() / k, 1000); + for(Decoder decoder : _decoders){ + for(int i = 0; i < in.getNumRows(); i += blz){ + final int start = i; + final int end = Math.min(in.getNumRows(), i + blz); + 
tasks.add(pool.submit(() -> decoder.decode(in, out, start, end))); + } + } + for(Future f : tasks) + f.get(); + return out; + } + catch(Exception e) { + throw new RuntimeException(e); + } + finally { + pool.shutdown(); + } + } + + @Override + public void decode(MatrixBlock in, FrameBlock out, int rl, int ru){ + for( Decoder decoder : _decoders ) + decoder.decode(in, out, rl, ru); + } + + @Override public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) { List subRangeDecoders = new ArrayList<>(); diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java index dec1486bebc..114f1238f8e 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java @@ -53,7 +53,15 @@ protected DecoderDummycode(ValueType[] schema, int[] dcCols) { public FrameBlock decode(MatrixBlock in, FrameBlock out) { //TODO perf (exploit sparse representation for better asymptotic behavior) out.ensureAllocatedColumns(in.getNumRows()); - for( int i=0; i colList = new ArrayList<>(); diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java index c5009172836..f9218456551 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java @@ -65,8 +65,14 @@ public Object getRcMapValue(int i, long key) { @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { + decode(in, out, 0, in.getNumRows()); + return out; + } + + @Override + public void decode(MatrixBlock in, FrameBlock out, int rl, int ru) { if( _onOut ) { //recode on output (after dummy) - for( int i=0; i in, int startInd, int endInd, final double[] codes = tmp != null && tmp.length == 
endLength ? tmp : new double[endLength]; if (_binMins == null || _binMins.length == 0 || _binMaxs.length == 0) { LOG.warn("ColumnEncoderBin: applyValue without bucket boundaries, assign 1"); - Arrays.fill(codes, startInd, endInd, 1.0); + Arrays.fill(codes, 0, endLength, 1.0); return codes; } @@ -187,39 +188,34 @@ protected final void getCodeColFrame(FrameBlock in, int startInd, int endInd, do final Array c = in.getColumn(_colID - 1); final double mi = _binMins[0]; final double mx = _binMaxs[_binMaxs.length-1]; - if(!c.containsNull()) + if(!(c instanceof StringArray) && !c.containsNull()) for(int i = startInd; i < endInd; i++) - codes[i - startInd] = getCodeIndex(c.getAsDouble(i), mi,mx); + codes[i - startInd] = getCodeIndex(c.getAsDouble(i), mi, mx); else for(int i = startInd; i < endInd; i++) - codes[i - startInd] = getCodeIndex(c.getAsNaNDouble(i),mi,mx); + codes[i - startInd] = getCodeIndex(c.getAsNaNDouble(i),mi, mx); } protected final double getCodeIndex(double inVal){ - return getCodeIndex(inVal, _binMins[0],_binMaxs[_binMaxs.length-1]); + return getCodeIndex(inVal, _binMins[0], _binMaxs[_binMaxs.length-1]); } - protected final double getCodeIndex(double inVal, double mi, double mx){ - final boolean nan = Double.isNaN(inVal); - if(nan || (_binMethod != BinMethod.EQUI_HEIGHT_APPROX && (inVal < mi || inVal > mx))) + protected final double getCodeIndex(double inVal, double min, double max){ + if(Double.isNaN(inVal)) return Double.NaN; else if(_binMethod == BinMethod.EQUI_WIDTH) - return getEqWidth(inVal); + return getEqWidth(inVal, min, max); else // if (_binMethod == BinMethod.EQUI_HEIGHT || _binMethod == BinMethod.EQUI_HEIGHT_APPROX) return getCodeIndexEQHeight(inVal); } - private final double getEqWidth(double inVal) { - final double max = _binMaxs[_binMaxs.length - 1]; - final double min = _binMins[0]; - + private final double getEqWidth(double inVal, double min, double max) { if(max == min) return 1; - - // TODO: Skip computing bin boundaries for equi-width 
- double binWidth = (max - min) / _numBin; - double code = Math.ceil((inVal - min) / binWidth); - return (code == 0) ? code + 1 : code; + if(_numBin <= 0) + throw new RuntimeException("Invalid num bins"); + final int code = (int)(Math.ceil((inVal - min) / (max - min) * _numBin) ); + return code > _numBin ? _numBin : code < 1 ? 1 : code; } private final double getCodeIndexEQHeight(double inVal){ @@ -241,7 +237,7 @@ private final double getCodeIndexEQHeightNormal(double inVal) { if(ix < 0) // somewhere in between values // +2 because negative values are found from binary search. // plus 2 to correct for the absolute value of that. - return Math.abs(ix + 1) + 1; + return Math.min(Math.abs(ix + 1) + 1, _binMaxs.length); else if(ix == 0) // If first bucket boundary add it there. return 1; else diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java index 225f2db54c7..6fda66113dd 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java @@ -216,8 +216,7 @@ public MatrixBlock apply(CacheBlock in, MatrixBlock out, int outputCol, int r } } catch(Exception ex) { - LOG.error("Failed to transform-apply frame with \n" + this); - throw ex; + throw new DMLRuntimeException("Failed to transform-apply frame with \n" + this, ex); } return out; } diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java index 4b24652750f..a9b00a0767a 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java @@ -41,7 +41,8 @@ public class ColumnEncoderDummycode extends ColumnEncoder { private static 
final long serialVersionUID = 5832130477659116489L; - public int _domainSize = -1; // length = #of dummycoded columns + /** The number of columns outputted from this column group. */ + public int _domainSize = -1; public ColumnEncoderDummycode() { super(-1); @@ -230,8 +231,11 @@ else if(columnEncoder instanceof ColumnEncoderFeatureHash){ } if(distinct != -1) { - _domainSize = distinct; - LOG.debug("DummyCoder for column: " + _colID + " has domain size: " + _domainSize); + _domainSize = Math.max(1, distinct); + if(LOG.isDebugEnabled()){ + + LOG.debug("DummyCoder for column: " + _colID + " has domain size: " + _domainSize); + } } } } @@ -249,7 +253,7 @@ public FrameBlock getMetaData(FrameBlock meta) { @Override public void initMetaData(FrameBlock meta) { // initialize domain sizes and output num columns - _domainSize = (int) meta.getColumnMetadata()[_colID - 1].getNumDistinct(); + _domainSize = Math.max(1, (int) meta.getColumnMetadata()[_colID - 1].getNumDistinct()); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java index d40359c4842..12077221c05 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java @@ -68,12 +68,17 @@ protected double getCode(CacheBlock in, int row) { @Override protected double[] getCodeCol(CacheBlock in, int startInd, int endInd, double[] tmp) { - final int endLength = endInd - startInd; - final double[] codes = tmp != null && tmp.length == endLength ? 
tmp : new double[endLength]; - for (int i=startInd; i" + endInd, e); } - return codes; } protected void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk){ diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index 8c359aaa83d..8ca8b6d9fc2 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -156,6 +156,8 @@ private AColGroup recodeToDummy(ColumnEncoderComposite c) { Array a = in.getColumn(colId - 1); boolean containsNull = a.containsNull(); Map map = a.getRecodeMap(); + List r = c.getEncoders(); + r.set(0, new ColumnEncoderRecode(colId, (HashMap) map)); int domain = map.size(); if(containsNull && domain == 0) return new ColGroupEmpty(ColIndexFactory.create(1)); @@ -164,8 +166,6 @@ private AColGroup recodeToDummy(ColumnEncoderComposite c) { return ColGroupConst.create(colIndexes, new double[] {1}); ADictionary d = new IdentityDictionary(colIndexes.size(), containsNull); AMapToData m = createMappingAMapToData(a, map, containsNull); - List r = c.getEncoders(); - r.set(0, new ColumnEncoderRecode(colId, (HashMap) map)); return ColGroupDDC.create(colIndexes, d, m, null); } @@ -189,20 +189,40 @@ private AMapToData binEncode(Array a, ColumnEncoderBin b, boolean containsNul AMapToData m = MapToFactory.create(a.size(), b._numBin + (containsNull ? 
1 : 0)); if(containsNull) { for(int i = 0; i < a.size(); i++) { - double v = a.getAsNaNDouble(i); - if(Double.isNaN(v)) - m.set(i, b._numBin); - else - m.set(i, (int) b.getCodeIndex(v) - 1); + final double v = a.getAsNaNDouble(i); + try { + + if(Double.isNaN(v)) + m.set(i, b._numBin); + else { + int idx = (int) b.getCodeIndex(v) - 1; + if(idx < 0) + idx = 0; + m.set(i, idx); + } + } + catch(Exception e) { + + m.set(i, (int) b.getCodeIndex(v - 0.00001) - 1); + } } } else { - - for(int i = 0; i < a.size(); i++){ - int idx = (int) b.getCodeIndex(a.getAsDouble(i)) - 1; - if(idx < 0) - throw new RuntimeException(a.getAsDouble(i) + " is invalid value for " + b + "\n" + idx); - m.set(i, idx); + + for(int i = 0; i < a.size(); i++) { + try { + + int idx = (int) b.getCodeIndex(a.getAsDouble(i)) - 1; + if(idx < 0) + idx = 0; + // throw new RuntimeException(a.getAsDouble(i) + " is invalid value for " + b + "\n" + idx); + m.set(i, idx); + } + catch(Exception e) { + + int idx = (int) b.getCodeIndex(a.getAsDouble(i) - 0.00001) - 1; + m.set(i, idx); + } } } return m; @@ -264,6 +284,7 @@ private AColGroup recode(ColumnEncoderComposite c) { @SuppressWarnings("unchecked") private AColGroup passThrough(ColumnEncoderComposite c) { + // TODO optimize to not construct full map but only some of it until aborting compression. 
IColIndex colIndexes = ColIndexFactory.create(1); int colId = c._colID; Array a = in.getColumn(colId - 1); @@ -282,7 +303,7 @@ private AColGroup passThrough(ColumnEncoderComposite c) { if(containsNull) vals[map.size()] = Double.NaN; ValueType t = a.getValueType(); - map.forEach((k, v) -> vals[v.intValue()] = UtilFunctions.objectToDouble(t, k)); + map.forEach((k, v) -> vals[v.intValue()-1] = UtilFunctions.objectToDouble(t, k)); ADictionary d = Dictionary.create(vals); AMapToData m = createMappingAMapToData(a, map, containsNull); return ColGroupDDC.create(colIndexes, d, m, null); @@ -291,33 +312,44 @@ private AColGroup passThrough(ColumnEncoderComposite c) { } private AMapToData createMappingAMapToData(Array a, Map map, boolean containsNull) { - final int si = map.size(); - AMapToData m = MapToFactory.create(in.getNumRows(), si + (containsNull ? 1 : 0)); - Array.ArrayIterator it = a.getIterator(); - if(containsNull) { + try { - while(it.hasNext()) { - Object v = it.next(); - if(v != null) - m.set(it.getIndex(), map.get(v).intValue()); - else - m.set(it.getIndex(), si); + final int si = map.size(); + AMapToData m = MapToFactory.create(in.getNumRows(), si + (containsNull ? 
1 : 0)); + Array.ArrayIterator it = a.getIterator(); + if(containsNull) { + + while(it.hasNext()) { + Object v = it.next(); + try{ + if(v != null) + m.set(it.getIndex(), map.get(v).intValue() -1); + else + m.set(it.getIndex(), si); + } + catch(Exception e){ + throw new RuntimeException("failed on " + v +" " + a.getValueType(), e); + } + } } - } - else { - while(it.hasNext()) { - Object v = it.next(); - m.set(it.getIndex(), map.get(v).intValue()); + else { + while(it.hasNext()) { + Object v = it.next(); + m.set(it.getIndex(), map.get(v).intValue() -1); + } } + return m; + } + catch(Exception e) { + throw new RuntimeException("failed constructing map: " + map, e); } - return m; } private AMapToData createHashMappingAMapToData(Array a, int k, boolean nulls) { AMapToData m = MapToFactory.create(a.size(), k + (nulls ? 1 : 0)); if(nulls) { for(int i = 0; i < a.size(); i++) { - double h = a.hashDouble(i); + double h = Math.abs(a.hashDouble(i)); if(Double.isNaN(h)) { m.set(i, k); } @@ -328,7 +360,7 @@ private AMapToData createHashMappingAMapToData(Array a, int k, boolean nulls) } else { for(int i = 0; i < a.size(); i++) { - double h = a.hashDouble(i); + double h = Math.abs(a.hashDouble(i)); m.set(i, (int) h % k); } } @@ -342,7 +374,7 @@ private AColGroup hash(ColumnEncoderComposite c) { int domain = (int) CEHash.getK(); boolean nulls = a.containsNull(); IColIndex colIndexes = ColIndexFactory.create(0, 1); - if(domain == 1 && ! nulls) + if(domain == 1 && !nulls) return ColGroupConst.create(colIndexes, new double[] {1}); MatrixBlock incrementing = new MatrixBlock(domain + (nulls ? 
1 : 0), 1, false); diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java index c32cc4b220b..f1813e29a77 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java @@ -102,7 +102,6 @@ public MatrixBlock encode(CacheBlock in, boolean compressedOut) { } public MatrixBlock encode(CacheBlock in, int k, boolean compressedOut){ - deriveNumRowPartitions(in, k); try { if(isCompressedTransformEncode(in, compressedOut)) @@ -112,7 +111,8 @@ else if(k > 1 && !MULTI_THREADED_STAGES && !hasLegacyEncoder()) { DependencyThreadPool pool = new DependencyThreadPool(k); LOG.debug("Encoding with full DAG on " + k + " Threads"); try { - pool.submitAllAndWait(getEncodeTasks(in, out, pool)); + List> tasks = getEncodeTasks(in, out, pool); + pool.submitAllAndWait(tasks); } finally{ pool.shutdown(); @@ -296,10 +296,11 @@ private void buildMT(CacheBlock in, int k) { pool.submitAllAndWait(getBuildTasks(in)); } catch(ExecutionException | InterruptedException e) { - LOG.error("MT Column build failed"); - e.printStackTrace(); + throw new RuntimeException(e); + } + finally{ + pool.shutdown(); } - pool.shutdown(); } public void legacyBuild(FrameBlock in) { @@ -412,10 +413,11 @@ private void applyMT(CacheBlock in, MatrixBlock out, int outputCol, int k) { pool.submitAllAndWait(getApplyTasks(in, out, outputCol)); } catch(ExecutionException | InterruptedException e) { - LOG.error("MT Column apply failed"); - e.printStackTrace(); + throw new DMLRuntimeException(e); + } + finally{ + pool.shutdown(); } - pool.shutdown(); } private void deriveNumRowPartitions(CacheBlock in, int k) { @@ -679,8 +681,7 @@ public void allocateMetaData(FrameBlock meta) { @Override public FrameBlock getMetaData(FrameBlock meta) { - getMetaData(meta, 1); - return meta; + return 
getMetaData(meta, 1); } public FrameBlock getMetaData(FrameBlock meta, int k) { @@ -691,19 +692,21 @@ public FrameBlock getMetaData(FrameBlock meta, int k) { meta = new FrameBlock(_columnEncoders.size(), ValueType.STRING); this.allocateMetaData(meta); if (k > 1) { + ExecutorService pool = CommonThreadPool.get(k); try { - ExecutorService pool = CommonThreadPool.get(k); ArrayList> tasks = new ArrayList<>(); for(ColumnEncoder columnEncoder : _columnEncoders) tasks.add(new ColumnMetaDataTask<>(columnEncoder, meta)); List> taskret = pool.invokeAll(tasks); - pool.shutdown(); for (Future task : taskret) - task.get(); + task.get(); } catch(Exception ex) { throw new DMLRuntimeException(ex); } + finally{ + pool.shutdown(); + } } else { for(ColumnEncoder columnEncoder : _columnEncoders) @@ -1167,8 +1170,8 @@ private static class ApplyTasksWrapperTask extends DependencyWrapperTask private final ColumnEncoder _encoder; private final MatrixBlock _out; private final CacheBlock _in; - private int _offset = -1; // offset dude to dummycoding in - // previous columns needs to be updated by external task! + /** Offset because of dummmy coding such that the column id is correct. */ + private int _offset = -1; private ApplyTasksWrapperTask(ColumnEncoder encoder, CacheBlock in, MatrixBlock out, DependencyThreadPool pool) { @@ -1189,7 +1192,7 @@ public Object call() throws Exception { // and _outputCol has been updated! 
if(_offset == -1) throw new DMLRuntimeException( - "OutputCol for apply task wrapper has not been updated!, Most likely some " + "concurrency issues"); + "OutputCol for apply task wrapper has not been updated!, Most likely some concurrency issues\n " + this); return super.call(); } diff --git a/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java b/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java index 4fe6247e561..c5c3b299735 100644 --- a/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java +++ b/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java @@ -21,12 +21,14 @@ import java.io.BufferedReader; import java.io.BufferedWriter; +import java.io.File; import java.io.FileNotFoundException; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.OutputStream; import java.io.OutputStreamWriter; +import java.nio.file.Files; import java.text.SimpleDateFormat; import java.util.Arrays; import java.util.Date; @@ -175,26 +177,47 @@ public static boolean isHDFSFileEmpty(String dir) throws IOException { } public static boolean isFileEmpty(FileSystem fs, Path dir) throws IOException { - FileStatus fstat = fs.getFileStatus(dir); - - if( fstat.isDirectory() - || IOUtilFunctions.isObjectStoreFileScheme(dir) ) - { - // it is a directory - FileStatus[] stats = fs.listStatus(dir); - if (stats != null) { - for (FileStatus stat : stats) { - if (stat.getLen() > 0) + if(fs instanceof LocalFileSystem) { + // use local Java filesystem, this is much faster. 
+ java.io.File f = new java.io.File(dir.toString()); + if(f.exists()){ + + if(f.isDirectory()) { + java.io.File[] fff = f.listFiles(); + if(fff.length == 0) return false; + for(File ff : fff) { + if(Files.size(ff.toPath()) > 0) + return false; + } + return true; } - return true; - } else { - return true; + else + return Files.size(f.toPath()) <= 0; + } + else return false; + } + else{ + FileStatus fstat = fs.getFileStatus(dir); + + if(fstat.isDirectory() || IOUtilFunctions.isObjectStoreFileScheme(dir)) { + // it is a directory + FileStatus[] stats = fs.listStatus(dir); + if(stats != null) { + for(FileStatus stat : stats) { + if(stat.getLen() > 0) + return false; + } + return true; + } + else { + return true; + } + } + else { + // it is a regular file + return fstat.getLen() == 0; } - } - else { - // it is a regular file - return (fstat.getLen() == 0); } } diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java index 74cf1d45ef7..967855814fa 100644 --- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java +++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java @@ -514,13 +514,19 @@ public static double objectToDouble(ValueType vt, Object in) { case BOOLEAN: return ((Boolean)in) ? 1 : 0; case CHARACTER: return (Character)in; case STRING: + String inStr = (String) in; try { - return !((String) in).isEmpty() ? Double.parseDouble((String) in) : 0; + return !(inStr).isEmpty() ? 
Double.parseDouble(inStr) : 0; } catch(NumberFormatException e) { - if(in.equals("true")) + final int len = inStr.length(); + if(len == 1 && inStr.equalsIgnoreCase("T")) return 1.0; - else if(in.equals("false")) + else if (len == 1 && inStr.equalsIgnoreCase("F")) + return 0.0; + else if(inStr.equalsIgnoreCase("true")) + return 1.0; + else if(inStr.equalsIgnoreCase("false")) return 0.0; else throw new DMLRuntimeException("failed parsing object to double",e); diff --git a/src/test/java/org/apache/sysds/performance/micro/FrameCompressedTransform.java b/src/test/java/org/apache/sysds/performance/micro/FrameCompressedTransform.java new file mode 100644 index 00000000000..3536b282bae --- /dev/null +++ b/src/test/java/org/apache/sysds/performance/micro/FrameCompressedTransform.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.performance.micro; + +import org.apache.sysds.runtime.data.SparseBlock.Type; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.columns.IntegerArray; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.transform.encode.EncoderFactory; +import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; + +public class FrameCompressedTransform { + + static int specCols = 5; + static String spec = "{ids:true,dummycode:[1,2,3,4,5]}"; + + public static void main(String[] args) { + + // scaleRows(); + // scaleDistinct(); + scaleCols(); + } + + public static void scaleRows() { + + System.out.println("Rows,Comp,MCSR,CSR,COO,Dense"); + for(int i = 1; i < 300; i += 1) { + System.out.print(i + ","); + System.out.println(getSize(genFrameBlock(i, 1000))); + } + + for(int i = 300; i < 16000; i += 100) { + System.out.print(i + ","); + System.out.println(getSize(genFrameBlock(i, 1000))); + } + + for(int i = 16000; i < 160000; i += 1000) { + System.out.print(i + ","); + System.out.println(getSize(genFrameBlock(i, 1000))); + } + } + + public static void scaleDistinct() { + + System.out.println("Distinct,Comp,MCSR,CSR,COO,Dense"); + for(int i = 1; i < 10; i += 1) { + System.out.print(i + ","); + System.out.println(getSize(genFrameBlock(100000, i))); + } + + for(int i = 10; i < 100; i += 10) { + System.out.print(i + ","); + System.out.println(getSize(genFrameBlock(100000, i))); + } + + for(int i = 100; i < 100000; i += 100) { + System.out.print(i + ","); + System.out.println(getSize(genFrameBlock(100000, i))); + } + + } + + public static void scaleCols() { + + System.out.println("Cols,Comp,MCSR,CSR,COO,Dense"); + for(int i = 1; i < 10; i += 1) { + System.out.print(i + ","); + System.out.println(getSize(genFrameBlock(100000, i, 1000))); + } + + for(int i = 10; i < 100; i += 10) { + System.out.print(i + 
","); + System.out.println(getSize(genFrameBlock(100000, i, 1000))); + } + + for(int i = 100; i < 1000; i += 100) { + System.out.print(i + ","); + System.out.println(getSize(genFrameBlock(100000, i, 1000))); + } + + for(int i = 1000; i < 10000; i += 1000) { + System.out.print(i + ","); + System.out.println(getSize(genFrameBlock(100000, i, 1000))); + } + + for(int i = 10000; i < 20000; i += 10000) { + System.out.print(i + ","); + System.out.println(getSize(genFrameBlock(100000, i, 1000))); + } + + } + + private static String getSize(FrameBlock e) { + if(specCols != e.getNumColumns()) + createSpec(e.getNumColumns()); + MultiColumnEncoder encoderCompressed = // + EncoderFactory.createEncoder(spec, e.getColumnNames(), e.getNumColumns(), null); + MatrixBlock outCompressed = encoderCompressed.encode(e, 16, true); + + long compSize = outCompressed.getInMemorySize(); + + long denseSize = outCompressed.estimateSizeDenseInMemory(); + long csr = outCompressed.estimateSizeSparseInMemory(Type.CSR); + long mcsr = outCompressed.estimateSizeSparseInMemory(Type.MCSR); + long coo = outCompressed.estimateSizeSparseInMemory(Type.COO); + + // MatrixBlock uc = CompressedMatrixBlock.getUncompressed(outCompressed); + // uc.denseToSparse(true); + + // SparseBlock sb = uc.getSparseBlock(); + // if(sb == null) { + // System.out.println(uc); + // System.exit(-1); + // } + + // long csr = new MatrixBlock(uc.getNumRows(), uc.getNumColumns(), uc.getNonZeros(), new SparseBlockCSR(sb)) + // .getInMemorySize(); + // long mcsr = new MatrixBlock(uc.getNumRows(), uc.getNumColumns(), uc.getNonZeros(), new SparseBlockMCSR(sb)) + // .getInMemorySize(); + // long coo = new MatrixBlock(uc.getNumRows(), uc.getNumColumns(), uc.getNonZeros(), new SparseBlockCOO(sb)) + // .getInMemorySize(); + + // long denseSize = uc.estimateSizeDenseInMemory(); + + return compSize + "," + mcsr + "," + csr + "," + coo + "," + denseSize; + } + + private static void createSpec(int nCol) { + StringBuilder sb = new 
StringBuilder(); + sb.append("{ids:true,dummycode:[1"); + for(int i = 1; i < nCol; i++) { + sb.append(","); + sb.append(i + 1); + } + sb.append("]}"); + spec = sb.toString(); + specCols = nCol; + } + + private static FrameBlock genFrameBlock(int nRow, int nDistinct) { + IntegerArray a = new IntegerArray(new int[nRow]); + for(int i = 0; i < nRow; i++) + a.set(i, i % nDistinct); + return new FrameBlock(new Array[] {a, a, a, a, a}); + } + + private static FrameBlock genFrameBlock(int nRow, int cols, int nDistinct) { + IntegerArray a = new IntegerArray(new int[nRow]); + for(int i = 0; i < nRow; i++) + a.set(i, i % nDistinct); + Array[] r = new Array[cols]; + for(int i = 0; i < cols; i++) + r[i] = a; + + return new FrameBlock(r); + } +} diff --git a/src/test/java/org/apache/sysds/performance/micro/InformationLoss.java b/src/test/java/org/apache/sysds/performance/micro/InformationLoss.java new file mode 100644 index 00000000000..88fd767e76d --- /dev/null +++ b/src/test/java/org/apache/sysds/performance/micro/InformationLoss.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.performance.micro; + +import java.io.File; +import java.io.IOException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; + +import org.apache.sysds.performance.generators.FrameFile; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.functionobjects.Builtin; +import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode; +import org.apache.sysds.runtime.functionobjects.Divide; +import org.apache.sysds.runtime.functionobjects.Minus; +import org.apache.sysds.runtime.io.FileFormatPropertiesCSV; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.WriterTextCSV; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.Pair; +import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.UnaryOperator; +import org.apache.sysds.runtime.transform.decode.Decoder; +import org.apache.sysds.runtime.transform.decode.DecoderFactory; +import org.apache.sysds.runtime.transform.encode.EncoderFactory; +import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; +import org.apache.sysds.runtime.util.CommonThreadPool; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.test.TestUtils; + +/** + * evaluate the different loss in accuracy on different number of distinct values on transform specifications. + */ +public class InformationLoss { + + public static void main(String[] args) throws Exception { + + String frame_path = args[0]; + String binningTechnique = args[1]; + + writeRandomMatrix(frame_path); + + final Pair p = readFrame(frame_path); + final FrameBlock f = p.getKey(); + final MatrixBlock org = p.getValue(); + + System.gc(); // indicate to do garbage collection here. 
+ + for(int i = 1; i < 20; i++) { + String spec = generateSpec(i, f.getNumColumns(), binningTechnique); + System.out.print(i + ","); + calculateLoss(f, org, spec); + } + + for(int i = 20; i < 200; i += 10) { + String spec = generateSpec(i, f.getNumColumns(), binningTechnique); + System.out.print(i + ","); + calculateLoss(f, org, spec); + } + + for(int i = 200; i <= 2000; i += 100) { + String spec = generateSpec(i, f.getNumColumns(), binningTechnique); + System.out.print(i + ","); + calculateLoss(f, org, spec); + } + + } + + private static Pair readFrame(String path) throws Exception { + FrameBlock f = FrameFile.create(path).take(); + f = f.applySchema(f.detectSchema(16), 16); // apply scheme + + MatrixBlock org = DataConverter.convertToMatrixBlock(f); + f = null; + System.gc(); // cleanup original frame. + + // normalize org. + org = org// + .replaceOperations(null, Double.NaN, 0)// + .replaceOperations(null, Double.POSITIVE_INFINITY, 0).replaceOperations(null, Double.NEGATIVE_INFINITY, 0); + + Pair mm = getMinMax(org); + + // normalize org to 0-1 range + org = org // + .binaryOperations(new BinaryOperator(Minus.getMinusFnObject()), mm.getKey()) + .binaryOperations(new BinaryOperator(Divide.getDivideFnObject()), + mm.getValue().binaryOperations(new BinaryOperator(Minus.getMinusFnObject()), mm.getKey())) + .replaceOperations(null, Double.NaN, 0); + + f = DataConverter.convertToFrameBlock(org, 16); + + return new Pair<>(f, org); + } + + private static Pair getMinMax(final MatrixBlock org) throws Exception { + ExecutorService pool = CommonThreadPool.get(16); + + Future minF = pool.submit(() -> (MatrixBlock) org.colMin(16)); + Future maxF = pool.submit(() -> (MatrixBlock) org.colMax(16)); + + MatrixBlock min = minF.get(); + MatrixBlock max = maxF.get(); + + return new Pair<>(min, max); + } + + private static void writeRandomMatrix(String path) throws IOException { + if(!new File(path).exists()) { + MatrixWriter w = new WriterTextCSV(new FileFormatPropertiesCSV(false, 
",", false)); + MatrixBlock mb = TestUtils.generateTestMatrixBlock(1000, 10, 0, 1, 0.5, 23); + w.writeMatrixToHDFS(mb, path, mb.getNumRows(), mb.getNumColumns(), 1000, mb.getNonZeros(), false); + } + } + + private static FrameBlock encodeAndDecode(FrameBlock f, MatrixBlock org, String spec) { + + MultiColumnEncoder encoder = // + EncoderFactory.createEncoder(spec, f.getColumnNames(), f.getNumColumns(), null); + MatrixBlock binned = encoder.encode(f, 16, true); + Decoder d = DecoderFactory.createDecoder(spec, f.getColumnNames(false), f.getSchema(), encoder.getMetaData(null), + binned.getNumColumns()); + FrameBlock dr = new FrameBlock(f.getSchema()); + d.decode(binned, dr, 16); + return dr; + } + + private static MatrixBlock delta(FrameBlock f, MatrixBlock org, String spec) { + return DataConverter// + .convertToMatrixBlock(encodeAndDecode(f, org, spec)) + .binaryOperations(new BinaryOperator(Minus.getMinusFnObject(), 16), org) + // .binaryOperations(new BinaryOperator(Divide.getDivideFnObject(), 16), org) + .unaryOperations(new UnaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.ABS), 16), null) + .replaceOperations(null, Double.NaN, 0); + + } + + private static void calculateLoss(FrameBlock f, MatrixBlock org, String spec) throws Exception { + + final MatrixBlock delta = delta(f, org, spec); + ExecutorService pool = CommonThreadPool.get(16); + + Future minF = pool.submit(() -> delta.min(16)); + Future maxF = pool.submit(() -> delta.max(16).quickGetValue(0, 0)); + Future meanF = pool + .submit(() -> delta.sum(16).quickGetValue(0, 0) / (delta.getNumRows() * delta.getNumColumns())); + + double min = minF.get(); + double max = maxF.get(); + double mean = meanF.get(); + + pool.shutdown(); + System.out.println(String.format("%e, %e, %e", min, max, mean)); + } + + private static String generateSpec(int bins, int cols, String technique) { + StringBuilder sb = new StringBuilder(); + sb.append("{\"ids\":true,\"bin\":["); + for(int i = 0; i < cols; i++) { + 
sb.append(String.format("{\"id\":%d,\"method\":\"%s\",\"numbins\":%d}", i + 1, technique, bins)); + if(i + 1 < cols) + sb.append(','); + } + sb.append("]}"); + return sb.toString(); + + } + +} diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java index e89d41295ff..dd5a65a0f77 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java +++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java @@ -353,7 +353,7 @@ public AColGroup append(AColGroup g) { } @Override - protected AColGroup appendNInternal(AColGroup[] groups) { + protected AColGroup appendNInternal(AColGroup[] groups, int blen, int rlen) { // TODO Auto-generated method stub return null; } @@ -591,7 +591,7 @@ public AColGroup append(AColGroup g) { } @Override - protected AColGroup appendNInternal(AColGroup[] groups) { + protected AColGroup appendNInternal(AColGroup[] groups, int blen, int rlen) { // TODO Auto-generated method stub return null; } diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java index 3941ba7463b..14f4a56c180 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupTest.java @@ -2282,8 +2282,7 @@ private void appendSelfVerification(AColGroup g) { try { AColGroup g2 = g.append(g); - AColGroup g2n = AColGroup.appendN(new AColGroup[] {g, g}); - + AColGroup g2n = AColGroup.appendN(new AColGroup[] {g, g}, nRow, nRow*2); if(g2 != null && g2n != null) { double s2 = g2.getSum(nRow * 2); double s = g.getSum(nRow) * 2; @@ -2294,6 +2293,9 @@ private void appendSelfVerification(AColGroup g) { 
UA_ROW(InstructionUtils.parseBasicAggregateUnaryOperator("uar+", 1), 0, nRow * 2, g2, g2n, nRow * 2); } } + catch(NotImplementedException e){ + // okay + } catch(Exception e) { e.printStackTrace(); fail(e.getMessage()); diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/CustomColGroupTest.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/CustomColGroupTest.java new file mode 100644 index 00000000000..b4be0a4a486 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/CustomColGroupTest.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.component.compress.colgroup; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; +import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCSingleZeros; +import org.apache.sysds.runtime.compress.colgroup.dictionary.PlaceHolderDict; +import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; +import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; +import org.junit.Test; + +public class CustomColGroupTest { + protected static final Log LOG = LogFactory.getLog(CustomColGroupTest.class.getName()); + + @Test + public void appendEmptyToSDCZero() { + IColIndex i = ColIndexFactory.createI(3); + AColGroup e = new ColGroupEmpty(i); + AColGroup s = ColGroupSDCSingleZeros.create(i, 10, new PlaceHolderDict(1), + OffsetFactory.createOffset(new int[] {5, 10}), null); + + AColGroup r = AColGroup.appendN(new AColGroup[] {e, s}, 20, 40); + + assertTrue(r instanceof ColGroupSDCSingleZeros); + assertEquals(r.getColIndices(),i); + assertEquals(((ColGroupSDCSingleZeros)r).getNumRows(), 40); + + } + + + @Test + public void appendEmptyToSDCZero2() { + IColIndex i = ColIndexFactory.createI(3); + AColGroup e = new ColGroupEmpty(i); + AColGroup s = ColGroupSDCSingleZeros.create(i, 10, new PlaceHolderDict(1), + OffsetFactory.createOffset(new int[] {5, 10}), null); + + AColGroup r = AColGroup.appendN(new AColGroup[] {e, s, e, e, s, s, e}, 20, 7*20); + LOG.error(r); + assertTrue(r instanceof ColGroupSDCSingleZeros); + assertEquals(r.getColIndices(),i); + assertEquals(((ColGroupSDCSingleZeros)r).getNumRows(), 7 * 20); + + } +} diff --git 
a/src/test/java/org/apache/sysds/test/component/compress/indexes/IndexesTest.java b/src/test/java/org/apache/sysds/test/component/compress/indexes/IndexesTest.java index 088c64798b0..871636ed477 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/indexes/IndexesTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/indexes/IndexesTest.java @@ -129,7 +129,11 @@ public static Collection data() { tests.add(createWithArray(144, 32)); tests.add(createWithArray(13, 23)); tests.add(createWithArray(145, 14)); + tests.add(createWithArray(300, 14)); tests.add(createWithArray(23, 51515)); + tests.add(createWithArray(128, 321)); + tests.add(createWithArray(129, 1324)); + tests.add(createWithArray(127, 1323)); tests.add(createWithArray(66, 132)); tests.add(createRangeWithArray(66, 132)); tests.add(createRangeWithArray(32, 132)); @@ -170,13 +174,16 @@ public void testSerialize() { DataOutputStream fos = new DataOutputStream(bos); actual.write(fos); + long actualSize = bos.size(); // Serialize in ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray()); DataInputStream fis = new DataInputStream(bis); IColIndex n = ColIndexFactory.read(fis); + long expectedSize = actual.getExactSizeOnDisk(); compare(actual, n); + assertEquals(actual.toString(), expectedSize, actualSize); } catch(IOException e) { throw new RuntimeException("Error in io " + actual, e); @@ -198,7 +205,7 @@ public void testSerializeSize() { long actualSize = bos.size(); long expectedSize = actual.getExactSizeOnDisk(); - assertEquals(expectedSize, actualSize); + assertEquals(actual.toString(), expectedSize, actualSize); } catch(IOException e) { throw new RuntimeException("Error in io", e); @@ -606,7 +613,6 @@ private static void compare(IColIndex expected, IColIndex actual) { } private static void compare(int[] expected, IIterate actual) { - // LOG.error(expected); for(int i = 0; i < expected.length; i++) { assertTrue(actual.hasNext()); assertEquals(i, actual.i()); diff 
--git a/src/test/java/org/apache/sysds/test/component/compress/io/IOCompressionTestUtils.java b/src/test/java/org/apache/sysds/test/component/compress/io/IOCompressionTestUtils.java index 0eb66e619ef..e2f24e00549 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/io/IOCompressionTestUtils.java +++ b/src/test/java/org/apache/sysds/test/component/compress/io/IOCompressionTestUtils.java @@ -61,8 +61,8 @@ protected static void verifyEquivalence(MatrixBlock a, MatrixBlock b) { // assertTrue("Disk size is not equivalent", a.getExactSizeOnDisk() > b.getExactSizeOnDisk()); } - public synchronized static MatrixBlock read(String path) throws Exception { - return ReaderCompressed.readCompressedMatrixFromHDFS(path); + public synchronized static MatrixBlock read(String path, long rlen, long clen, int blen) throws Exception { + return ReaderCompressed.readCompressedMatrixFromHDFS(path, rlen, clen, blen); } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/io/IOEmpty.java b/src/test/java/org/apache/sysds/test/component/compress/io/IOEmpty.java index 53a0a8c7078..230c13b7268 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/io/IOEmpty.java +++ b/src/test/java/org/apache/sysds/test/component/compress/io/IOEmpty.java @@ -69,7 +69,7 @@ public void writeEmpty() { public void writeEmptyAndRead() throws Exception { String n = getName(); write(n, 10, 10, 1000); - MatrixBlock mb = IOCompressionTestUtils.read(n); + MatrixBlock mb = IOCompressionTestUtils.read(n, 10, 10, 1000); IOCompressionTestUtils.verifyEquivalence(mb, new MatrixBlock(10, 10, 0.0)); } @@ -87,7 +87,7 @@ public void writeEmptyAndReadMultiBlock() throws Exception { write(n, 1000, 10, 100); File f = new File(n); assertTrue(f.isDirectory() || f.isFile()); - MatrixBlock mb = IOCompressionTestUtils.read(n); + MatrixBlock mb = IOCompressionTestUtils.read(n, 1000, 10, 100); IOCompressionTestUtils.verifyEquivalence(mb, new MatrixBlock(1000, 10, 0.0)); } diff --git 
a/src/test/java/org/apache/sysds/test/component/compress/io/IOSpark.java b/src/test/java/org/apache/sysds/test/component/compress/io/IOSpark.java index 5bfdacf42f5..4a49c2b33c3 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/io/IOSpark.java +++ b/src/test/java/org/apache/sysds/test/component/compress/io/IOSpark.java @@ -138,75 +138,75 @@ public void readMultiBlockRowsAndColsIncompressable() { } @Test - public void writeSparkReadCPMultiColBlock() { + public void writeSparkReadMultiColBlock() { MatrixBlock mb = TestUtils.ceil(TestUtils.generateTestMatrixBlock(50, 124, 1, 3, 1.0, 2514)); - testWriteSparkReadCP(mb, 100, 100); + testWriteSparkRead(mb, 100, 100); } @Test - public void writeSparkReadCPMultiRowBlock() { + public void writeSparkReadMultiRowBlock() { MatrixBlock mb = TestUtils.ceil(TestUtils.generateTestMatrixBlock(1322, 33, 1, 3, 1.0, 2514)); - testWriteSparkReadCP(mb, 100, 100); + testWriteSparkRead(mb, 100, 100); } @Test - public void writeSparkReadCPSingleBlock() { + public void writeSparkReadSingleBlock() { MatrixBlock mb = TestUtils.ceil(TestUtils.generateTestMatrixBlock(50, 99, 1, 3, 1.0, 33)); - testWriteSparkReadCP(mb, 100, 100); + testWriteSparkRead(mb, 100, 100); } @Test - public void writeSparkReadCPMultiBlock() { + public void writeSparkReadMultiBlock() { MatrixBlock mb = TestUtils.ceil(TestUtils.generateTestMatrixBlock(580, 244, 1, 3, 1.0, 33)); - testWriteSparkReadCP(mb, 100, 100); + testWriteSparkRead(mb, 100, 100); } @Test - public void writeSparkReadCPMultiColBlockReblockUp() { + public void writeSparkReadMultiColBlockReblockUp() { MatrixBlock mb = TestUtils.ceil(TestUtils.generateTestMatrixBlock(50, 124, 1, 3, 1.0, 2514)); - testWriteSparkReadCP(mb, 100, 150); + testWriteSparkRead(mb, 100, 150); } @Test - public void writeSparkReadCPMultiRowBlockReblockUp() { + public void writeSparkReadMultiRowBlockReblockUp() { MatrixBlock mb = TestUtils.ceil(TestUtils.generateTestMatrixBlock(1322, 33, 1, 3, 1.0, 2514)); - 
testWriteSparkReadCP(mb, 100, 150); + testWriteSparkRead(mb, 100, 150); } @Test - public void writeSparkReadCPSingleBlockReblockUp() { + public void writeSparkReadSingleBlockReblockUp() { MatrixBlock mb = TestUtils.ceil(TestUtils.generateTestMatrixBlock(50, 99, 1, 3, 1.0, 33)); - testWriteSparkReadCP(mb, 100, 150); + testWriteSparkRead(mb, 100, 150); } @Test - public void writeSparkReadCPMultiBlockReblockUp() { + public void writeSparkReadMultiBlockReblockUp() { MatrixBlock mb = TestUtils.ceil(TestUtils.generateTestMatrixBlock(580, 244, 1, 3, 1.0, 33)); - testWriteSparkReadCP(mb, 100, 150); + testWriteSparkRead(mb, 100, 150); } @Test - public void writeSparkReadCPMultiColBlockReblockDown() { + public void writeSparkReadMultiColBlockReblockDown() { MatrixBlock mb = TestUtils.ceil(TestUtils.generateTestMatrixBlock(50, 124, 1, 3, 1.0, 2514)); - testWriteSparkReadCP(mb, 100, 80); + testWriteSparkRead(mb, 100, 80); } @Test - public void writeSparkReadCPMultiRowBlockReblockDown() { + public void writeSparkReadMultiRowBlockReblockDown() { MatrixBlock mb = TestUtils.ceil(TestUtils.generateTestMatrixBlock(1322, 33, 1, 3, 1.0, 2514)); - testWriteSparkReadCP(mb, 100, 80); + testWriteSparkRead(mb, 100, 80); } @Test - public void writeSparkReadCPSingleBlockReblockDown() { + public void writeSparkReadSingleBlockReblockDown() { MatrixBlock mb = TestUtils.ceil(TestUtils.generateTestMatrixBlock(50, 99, 1, 3, 1.0, 33)); - testWriteSparkReadCP(mb, 100, 80); + testWriteSparkRead(mb, 100, 80); } @Test - public void writeSparkReadCPMultiBlockReblockDown() { + public void writeSparkReadMultiBlockReblockDown() { MatrixBlock mb = TestUtils.ceil(TestUtils.generateTestMatrixBlock(580, 244, 1, 3, 1.0, 33)); - testWriteSparkReadCP(mb, 100, 80); + testWriteSparkRead(mb, 100, 80); } @Test @@ -269,11 +269,11 @@ public void testReblock_down_5() { testReblock(mb, 100, 25); } - private void testWriteSparkReadCP(MatrixBlock mb, int blen1, int blen2) { - testWriteSparkReadCP(mb, blen1, blen2, 1); + 
private void testWriteSparkRead(MatrixBlock mb, int blen1, int blen2) { + testWriteSparkRead(mb, blen1, blen2, 1); } - private void testWriteSparkReadCP(MatrixBlock mb, int blen1, int blen2, int rep) { + private void testWriteSparkRead(MatrixBlock mb, int blen1, int blen2, int rep) { try { CompressedMatrixBlock.debug = true; @@ -300,7 +300,7 @@ private void testWriteSparkReadCP(MatrixBlock mb, int blen1, int blen2, int rep) Thread.sleep(100); // Read locally the spark block written. - MatrixBlock mbr = IOCompressionTestUtils.read(f2); + MatrixBlock mbr = IOCompressionTestUtils.read(f2, mb.getNumRows(), mb.getNumColumns(), blen2); IOCompressionTestUtils.verifyEquivalence(mb, mbr); LOG.warn("IOSpark Writer Read: " + t.stop()); } @@ -310,7 +310,7 @@ private void testWriteSparkReadCP(MatrixBlock mb, int blen1, int blen2, int rep) if(rep < 3) { Thread.sleep(1000); - testWriteSparkReadCP(mb, blen1, blen2, rep + 1); + testWriteSparkRead(mb, blen1, blen2, rep + 1); return; } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/io/IOTest.java b/src/test/java/org/apache/sysds/test/component/compress/io/IOTest.java index 4493374c765..3708b52e7d5 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/io/IOTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/io/IOTest.java @@ -26,6 +26,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory; import org.apache.sysds.runtime.compress.io.WriterCompressed; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -139,7 +140,8 @@ protected static void writeAndReadR(MatrixBlock mb, int rep) throws Exception { WriterCompressed.writeCompressedMatrixToHDFS(mb, filename); File f = new File(filename); assertTrue(f.isFile() || f.isDirectory()); - MatrixBlock mbr = IOCompressionTestUtils.read(filename); + MatrixBlock mbr = 
IOCompressionTestUtils.read(filename, mb.getNumRows(), mb.getNumColumns(), + OptimizerUtils.DEFAULT_BLOCKSIZE); IOCompressionTestUtils.verifyEquivalence(mb, mbr); } catch(Exception e) { @@ -180,7 +182,7 @@ protected static void writeAndReadR(MatrixBlock mb, int blen, int rep) throws Ex WriterCompressed.writeCompressedMatrixToHDFS(mb, filename, blen); File f = new File(filename); assertTrue(f.isFile() || f.isDirectory()); - MatrixBlock mbr = IOCompressionTestUtils.read(filename); + MatrixBlock mbr = IOCompressionTestUtils.read(filename, mb.getNumRows(), mb.getNumColumns(), blen); IOCompressionTestUtils.verifyEquivalence(mb, mbr); } catch(Exception e) { diff --git a/src/test/java/org/apache/sysds/test/component/compress/offset/CustomOffsetTest.java b/src/test/java/org/apache/sysds/test/component/compress/offset/CustomOffsetTest.java new file mode 100644 index 00000000000..d37373a7589 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/compress/offset/CustomOffsetTest.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.component.compress.offset; + +import static org.junit.Assert.assertEquals; + +import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; +import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo; +import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; +import org.junit.Test; + +public class CustomOffsetTest { + + @Test + public void sliceE() { + AOffset a = OffsetFactory.createOffset(new int[] {441, 1299, 14612, 16110, 18033, 18643, 18768, 25798, 32315}); + + OffsetSliceInfo i = a.slice(1000, 2000); + System.out.println(a); + assertEquals(OffsetFactory.createOffset(new int[] {299}), i.offsetSlice); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/compress/util/ArrCountMapTest.java b/src/test/java/org/apache/sysds/test/component/compress/util/ArrCountMapTest.java index 90b218243ba..096739b13a6 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/util/ArrCountMapTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/util/ArrCountMapTest.java @@ -195,7 +195,7 @@ public void getId2() { assertEquals(19, m.getId(I(19.0))); } - @Test() + @Test public void sortBucketsSmall() { for(int i = 0; i < 9; i++) m.increment(I((double) i)); diff --git a/src/test/java/org/apache/sysds/test/component/compress/util/ArrayListTest.java b/src/test/java/org/apache/sysds/test/component/compress/util/ArrayListTest.java new file mode 100644 index 00000000000..50ae5097cf9 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/compress/util/ArrayListTest.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.compress.util; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.util.Arrays; + +import org.apache.sysds.runtime.compress.DMLCompressionException; +import org.apache.sysds.runtime.compress.utils.IntArrayList; +import org.junit.Test; + +public class ArrayListTest { + + @Test + public void allocate() { + assertEquals(4, new IntArrayList(4).extractValues().length); + } + + @Test + public void allocate2() { + assertEquals(16, new IntArrayList(16).extractValues().length); + } + + @Test + public void sizeEmpty() { + assertEquals(0, new IntArrayList(16).size()); + } + + @Test + public void sizeEmpty2() { + assertEquals(0, new IntArrayList(32).size()); + } + + @Test + public void sizeEmpty3() { + assertEquals(0, new IntArrayList().size()); + } + + @Test(expected = DMLCompressionException.class) + public void directError() { + new IntArrayList(null); + } + + @Test + public void directAllocation() { + IntArrayList a = new IntArrayList(new int[] {1, 2, 3}); + + assertEquals(3, a.size()); + assertEquals(1, a.get(0)); + assertEquals(2, a.get(1)); + assertEquals(3, a.get(2)); + } + + @Test + public void appendValue() { + IntArrayList a = new IntArrayList(); + + for(int i = 0; i < 10; i++) { + a.appendValue(i); + } + + assertEquals(10, a.size()); + assertEquals(6, a.get(6)); + } + + @Test + public 
void appendValue2() { + IntArrayList a = new IntArrayList(new int[] {1, 2, 3}); + + for(int i = 0; i < 10; i++) { + a.appendValue(i); + } + + assertEquals(13, a.size()); + assertEquals(6, a.get(6 + 3)); + } + + @Test + public void appendValueArray() { + IntArrayList a = new IntArrayList(new int[] {1, 2, 3}); + IntArrayList b = new IntArrayList(new int[] {4, 5, 6}); + a.appendValue(b); + + assertEquals(6, a.size()); + assertEquals(6, a.get(5)); + } + + @Test + public void appendValueArray2() { + IntArrayList a = new IntArrayList(new int[] {1, 2, 3}); + IntArrayList b = new IntArrayList(new int[] {4, 5, 6}); + a.appendValue(b); + a.appendValue(b); + a.appendValue(b); + + assertEquals(12, a.size()); + assertEquals(6, a.get(5)); + } + + @Test + public void appendValueArray3() { + IntArrayList a = new IntArrayList(new int[] {1, 2, 3}); + IntArrayList b = new IntArrayList(new int[] {4, 5, 6}); + a.appendValue(b); + a.appendValue(b); + a.appendValue(b); + int[] ex = a.extractValues(); + + assertTrue(ex.length >= a.size()); + assertEquals(6, a.get(5)); + assertEquals(6, ex[5]); + } + + + @Test + public void appendValueArray4() { + IntArrayList a = new IntArrayList(new int[] {1, 2, 3}); + IntArrayList b = new IntArrayList(new int[] {4, 5, 6}); + for(int i = 0; i < 10; i++){ + + a.appendValue(b); + a.appendValue(1); + } + int[] ex = a.extractValues(); + + assertTrue(ex.length >= a.size()); + assertEquals(10*3+3 + 10, a.size()); + assertEquals(6, a.get(5)); + assertEquals(6, ex[5]); + } + + @Test + public void extract() { + IntArrayList a = new IntArrayList(); + for(int i = 0; i < 2; i++) + a.appendValue(i); + + int[] ex = a.extractValues(); + assertTrue(ex.length > a.size()); + int[] et = a.extractValues(true); + assertTrue(et.length == a.size()); + assertEquals(1, a.get(1)); + assertEquals(1, ex[1]); + assertEquals(1, et[1]); + } + + @Test + public void toStringTest() { + IntArrayList a = new IntArrayList(); + for(int i = 0; i < 2; i++) + a.appendValue(i); + String as = 
a.toString(); + // int[] ex = a.extractValues(); + // assertTrue(ex.length > a.size()); + int[] et = a.extractValues(true); + String es = Arrays.toString(et); + assertEquals(es, as); + } + + @Test + public void toStringEmpty() { + IntArrayList a = new IntArrayList(); + // for(int i = 0; i < 2; i++) + // a.appendValue(i); + String as = a.toString(); + // int[] ex = a.extractValues(); + // assertTrue(ex.length > a.size()); + int[] et = a.extractValues(true); + String es = Arrays.toString(et); + assertEquals(es, as); + } + + @Test + public void extractExactEach() { + IntArrayList a = new IntArrayList(); + for(int i = 0; i < 10; i++) { + assertEquals(i, a.extractValues(true).length); + assertTrue(i <= a.extractValues(false).length); + a.appendValue(i); + } + assertTrue(10 <= a.extractValues(false).length); + assertEquals(10, a.extractValues(true).length); + } + + @Test + public void reset1(){ + IntArrayList a = new IntArrayList(); + a.appendValue(1); + a.appendValue(2); + assertEquals(2, a.size()); + a.reset(); + assertEquals(0, a.size()); + + } +} diff --git a/src/test/java/org/apache/sysds/test/component/compress/util/CountTest.java b/src/test/java/org/apache/sysds/test/component/compress/util/CountTest.java index f2eeee5b2a9..cff948d5f5e 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/util/CountTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/util/CountTest.java @@ -23,7 +23,11 @@ import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.fail; +import org.apache.commons.lang3.NotImplementedException; +import org.apache.sysds.runtime.compress.utils.ACount; +import org.apache.sysds.runtime.compress.utils.ACount.DArrCounts; import org.apache.sysds.runtime.compress.utils.ACount.DCounts; +import org.apache.sysds.runtime.compress.utils.DblArray; import org.junit.Test; public class CountTest { @@ -85,4 +89,16 @@ public void inc() { assertEquals(4 + 3, h.get(3.0).count); assertEquals(1 + 3, h.get(1.0).count); } 
+ + @Test(expected = NotImplementedException.class) + public void getDouble() { + ACount a = new DArrCounts(new DblArray(new double[] {1, 2}), 1); + a.get(0.0); + } + + @Test(expected = NotImplementedException.class) + public void incDouble() { + ACount a = new DArrCounts(new DblArray(new double[] {1, 2}), 1); + a.inc(0.0, 0, 1); + } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/util/ListHashMapTest.java b/src/test/java/org/apache/sysds/test/component/compress/util/ListHashMapTest.java new file mode 100644 index 00000000000..fa4d6f7e343 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/compress/util/ListHashMapTest.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.component.compress.util; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import org.apache.sysds.runtime.compress.utils.DblArray; +import org.apache.sysds.runtime.compress.utils.DblArrayIntListHashMap; +import org.apache.sysds.runtime.compress.utils.IntArrayList; +import org.junit.Test; + +public class ListHashMapTest { + + @Test + public void add() { + DblArrayIntListHashMap m = new DblArrayIntListHashMap(); + DblArray a = new DblArray(new double[] {1, 2, 3}); + final int rep = 100; + for(int i = 0; i < rep; i++) { + m.appendValue(a, 0); + } + IntArrayList l = m.get(a); + assertEquals(rep, l.size()); + + } + + @Test + public void add2() { + DblArrayIntListHashMap m = new DblArrayIntListHashMap(); + DblArray a = new DblArray(new double[] {1, 2, 3}); + DblArray b = new DblArray(new double[] {1, 2, 4}); + final int rep = 100; + for(int i = 0; i < rep; i++) { + m.appendValue(a, i); + m.appendValue(b, i); + } + IntArrayList l = m.get(a); + assertEquals(rep, l.size()); + + } + + @Test + public void add3() { + DblArrayIntListHashMap m = new DblArrayIntListHashMap(); + DblArray b = new DblArray(new double[] {1, 2, 4}); + final int rep = 100; + for(int i = 0; i < rep; i++) { + DblArray a = new DblArray(new double[] {1, i, i}); + m.appendValue(a, i); + m.appendValue(b, i); + } + IntArrayList l = m.get(b); + assertEquals(rep, l.size()); + + } + + @Test + public void add4() { + DblArrayIntListHashMap m = new DblArrayIntListHashMap(); + DblArray b = new DblArray(new double[] {1, 2, 4}); + final int rep = 100; + for(int i = 0; i < rep; i++) { + DblArray a = new DblArray(new double[] {1, i, i}); + m.appendValue(a, i); + } + for(int i = 0; i < rep; i++) { + m.appendValue(b, i); + } + IntArrayList l = m.get(b); + assertEquals(rep, l.size()); + + } + + @Test + public void extractAll() { + DblArrayIntListHashMap m = new DblArrayIntListHashMap(); + DblArray b = new DblArray(new double[] {1, 2, 4}); + 
final int rep = 100; + for(int i = 0; i < rep; i++) { + DblArray a = new DblArray(new double[] {1, i, i}); + m.appendValue(a, i); + } + for(int i = 0; i < rep; i++) { + m.appendValue(b, i); + } + + assertEquals(rep + 1, m.extractValues().size()); + + } + + @Test + public void toStringWorks() { + DblArrayIntListHashMap m = new DblArrayIntListHashMap(); + DblArray b = new DblArray(new double[] {1, 2, 4}); + final int rep = 100; + for(int i = 0; i < rep; i++) { + DblArray a = new DblArray(new double[] {1, i, i}); + m.appendValue(a, i); + } + for(int i = 0; i < rep; i++) { + m.appendValue(b, i); + } + m.toString(); + } + + @Test + public void size() { + DblArrayIntListHashMap m = new DblArrayIntListHashMap(); + DblArray b = new DblArray(new double[] {1, 2, 4}); + final int rep = 100; + for(int i = 0; i < rep; i++) { + DblArray a = new DblArray(new double[] {1, i, i}); + m.appendValue(a, i); + assertEquals(i + 1, m.size()); + } + for(int i = 0; i < rep; i++) { + m.appendValue(b, i); + assertEquals(rep + 1, m.size()); + } + m.toString(); + } + + @Test + public void get() { + DblArrayIntListHashMap m = new DblArrayIntListHashMap(); + DblArray b = new DblArray(new double[] {1, 2, 4}); + final int rep = 100; + for(int i = 0; i < rep; i++) { + DblArray a = new DblArray(new double[] {1, i, i}); + assertTrue(m.get(a) == null); + m.appendValue(a, i); + assertEquals(i + 1, m.size()); + assertTrue(m.get(a) != null); + + } + for(int i = 0; i < rep; i++) { + m.appendValue(b, i); + assertEquals(rep + 1, m.size()); + } + m.toString(); + } + + @Test + public void getCustom() { + DblArrayIntListHashMap m = new DblArrayIntListHashMap(25); + DblArray b = new DblArray(new double[] {1, 2, 4}); + final int rep = 100; + for(int i = 0; i < rep; i++) { + DblArray a = new DblArray(new double[] {1, i, i}); + assertTrue(m.get(a) == null); + m.appendValue(a, i); + assertEquals(i + 1, m.size()); + assertTrue(m.get(a) != null); + + } + for(int i = 0; i < rep; i++) { + m.appendValue(b, i); + 
assertEquals(rep + 1, m.size()); + } + + } +} diff --git a/src/test/java/org/apache/sysds/test/component/federated/FederatedTestUtils.java b/src/test/java/org/apache/sysds/test/component/federated/FederatedTestUtils.java index 595fa6799f1..4d3796892aa 100644 --- a/src/test/java/org/apache/sysds/test/component/federated/FederatedTestUtils.java +++ b/src/test/java/org/apache/sysds/test/component/federated/FederatedTestUtils.java @@ -99,7 +99,6 @@ public static long putMatrixBlock(MatrixBlock mb, InetSocketAddress addr, int ti final FederatedRequest frq = new FederatedRequest(RequestType.PUT_VAR, null, id, mb); final Future fr = FederatedData.executeFederatedOperation(addr, frq); final FederatedResponse r = fr.get(timeout, TimeUnit.MILLISECONDS); - LOG.error(r); if(r.isSuccessful()) return id; else diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestLogger.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestLogger.java index 6e2d0351986..1ea9ca344e2 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestLogger.java +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestLogger.java @@ -73,16 +73,28 @@ public void test(String spec) { MultiColumnEncoder encoderCompressed = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), meta); MatrixBlock outCompressed = encoderCompressed.encode(data, true); - FrameBlock outCompressedMD = encoderCompressed.getMetaData(null); + // FrameBlock outCompressedMD = encoderCompressed.getMetaData(null); MultiColumnEncoder encoderNormal = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), meta); MatrixBlock outNormal = encoderNormal.encode(data); - FrameBlock outNormalMD = encoderNormal.getMetaData(null); + // FrameBlock outNormalMD = encoderNormal.getMetaData(null); final List log = LoggingUtils.reinsert(appender); 
assertTrue(log.get(3).getMessage().toString().contains("Compression ratio")); TestUtils.compareMatrices(outNormal, outCompressed, 0, "Not Equal after apply"); - TestUtils.compareFrames(outNormalMD, outCompressedMD, true); + + MultiColumnEncoder ec = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), + encoderCompressed.getMetaData(null)); + + MatrixBlock outMeta1 = ec.apply(data, 1); + + TestUtils.compareMatrices(outNormal, outMeta1, 0, "Not Equal after apply"); + + MultiColumnEncoder ec2 = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), + encoderNormal.getMetaData(null)); + + MatrixBlock outMeta12 = ec2.apply(data, 1); + TestUtils.compareMatrices(outNormal, outMeta12, 0, "Not Equal after apply2"); } catch(Exception e) { e.printStackTrace(); diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java index a592fc74778..af81216412c 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java @@ -61,25 +61,25 @@ public static Collection data() { Arrays.fill(kPlusCols, ValueType.BOOLEAN); FrameBlock[] blocks = new FrameBlock[] {// - TestUtils.generateRandomFrameBlock(100, // + TestUtils.generateRandomFrameBlock(16, // new ValueType[] {ValueType.UINT4, ValueType.UINT8, ValueType.UINT4}, 231), // - TestUtils.generateRandomFrameBlock(100, // + TestUtils.generateRandomFrameBlock(10, // new ValueType[] {ValueType.BOOLEAN, ValueType.UINT8, ValueType.UINT4}, 231), // new FrameBlock(new ValueType[] {ValueType.BOOLEAN, ValueType.INT32, ValueType.INT32}, 100), // - TestUtils.generateRandomFrameBlock(100, // + TestUtils.generateRandomFrameBlock(11, // new ValueType[] {ValueType.UINT4, ValueType.BOOLEAN, 
ValueType.FP32}, 231, 0.2), TestUtils.generateRandomFrameBlock(432, // new ValueType[] {ValueType.UINT4, ValueType.BOOLEAN, ValueType.FP32}, 231, 0.2), - TestUtils.generateRandomFrameBlock(100, // + TestUtils.generateRandomFrameBlock(12, // new ValueType[] {ValueType.UINT4, ValueType.BOOLEAN, ValueType.FP32}, 231, 0.9), - TestUtils.generateRandomFrameBlock(100, // + TestUtils.generateRandomFrameBlock(14, // new ValueType[] {ValueType.UINT4, ValueType.BOOLEAN, ValueType.FP32}, 231, 0.99), TestUtils.generateRandomFrameBlock(5, kPlusCols, 322), TestUtils.generateRandomFrameBlock(1020, kPlusCols, 322), }; - blocks[2].ensureAllocatedColumns(100); + blocks[2].ensureAllocatedColumns(20); for(FrameBlock block : blocks) { for(int k : threads) { @@ -105,6 +105,11 @@ public void testDummyCode() { test("{dummycode:[C1,C2,C3]}"); } + @Test + public void testDummyCodeV2(){ + test("{ids:true, dummycode:[1,2,3]}"); + } + @Test public void testBin() { test( @@ -139,20 +144,29 @@ public void testHashToDummy() { public void test(String spec) { try { - FrameBlock meta = null; MultiColumnEncoder encoderCompressed = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), meta); MatrixBlock outCompressed = encoderCompressed.encode(data, k, true); - FrameBlock outCompressedMD = encoderCompressed.getMetaData(null); MultiColumnEncoder encoderNormal = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), meta); MatrixBlock outNormal = encoderNormal.encode(data, k); - FrameBlock outNormalMD = encoderNormal.getMetaData(null); TestUtils.compareMatrices(outNormal, outCompressed, 0, "Not Equal after apply"); - TestUtils.compareFrames(outNormalMD, outCompressedMD, true); + + MultiColumnEncoder ec2 = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), + encoderNormal.getMetaData(null)); + + MatrixBlock outMeta12 = ec2.apply(data, k); + TestUtils.compareMatrices(outNormal, outMeta12, 0, "Not Equal after apply2"); + + 
MultiColumnEncoder ec = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), + encoderCompressed.getMetaData(null)); + + MatrixBlock outMeta1 = ec.apply(data, k); + TestUtils.compareMatrices(outNormal, outMeta1, 0, "Not Equal after apply"); + } catch(Exception e) { e.printStackTrace(); diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java index 7b37ba1413a..14a552c9934 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java @@ -146,17 +146,30 @@ public void test(String spec) { MultiColumnEncoder encoderNormal = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), meta); MatrixBlock outNormal = encoderNormal.encode(data, k); - FrameBlock outNormalMD = encoderNormal.getMetaData(null); - + + meta = null; MultiColumnEncoder encoderCompressed = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), meta); MatrixBlock outCompressed = encoderCompressed.encode(data, k, true); - FrameBlock outCompressedMD = encoderCompressed.getMetaData(null); - // LOG.error(data.slice(0,10)); - // LOG.error(outNormal.slice(0,10)); - // LOG.error(outCompressed.slice(0,10)); - TestUtils.compareMatrices(outNormal, outCompressed, 0, "Not Equal after apply"); - TestUtils.compareFrames(outNormalMD, outCompressedMD, true); + + TestUtils.compareMatrices(outNormal, outCompressed, 0, "Not Equal after encode"); + + // meta data is allowed to be different but! + // when applied inversely should return the same. 
+ + MultiColumnEncoder ec = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), + encoderCompressed.getMetaData(null)); + + MatrixBlock outMeta1 = ec.apply(data, k); + + TestUtils.compareMatrices(outNormal, outMeta1, 0, "Not Equal after apply"); + + MultiColumnEncoder ec2 = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), + encoderNormal.getMetaData(null)); + + MatrixBlock outMeta12 = ec2.apply(data, k); + TestUtils.compareMatrices(outNormal, outMeta12, 0, "Not Equal after apply2"); + } catch(Exception e) { e.printStackTrace(); diff --git a/src/test/java/org/apache/sysds/test/component/matrix/SparseCSRTest.java b/src/test/java/org/apache/sysds/test/component/matrix/SparseCSRTest.java index b7070c55b86..bc7c14a3be4 100644 --- a/src/test/java/org/apache/sysds/test/component/matrix/SparseCSRTest.java +++ b/src/test/java/org/apache/sysds/test/component/matrix/SparseCSRTest.java @@ -53,7 +53,6 @@ public void testGTE2Rows() { int[] colInd = new int[] {10, 20, 30, 40, 50, 60, 80, 90, 100}; double[] val = new double[] {1, 1, 1, 1, 1, 1, 1, 1, 1}; SparseBlockCSR b = new SparseBlockCSR(rs, colInd, val, val.length); - LOG.error(b); assertEquals(0, b.posFIndexGTE(1, 0)); assertEquals(0, b.posFIndexGTE(1, 10)); @@ -71,7 +70,6 @@ public void testGTE2RowsNN() { int[] colInd = new int[] {100, 10, 20, 30, 40, 50, 60, 80, 90, 100}; double[] val = new double[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; SparseBlockCSR b = new SparseBlockCSR(rs, colInd, val, val.length); - LOG.error(b); assertEquals(0, b.posFIndexGTE(1, 0)); assertEquals(0, b.posFIndexGTE(1, 10)); diff --git a/src/test/java/org/apache/sysds/test/functions/io/compressed/WriteCompressedTest.java b/src/test/java/org/apache/sysds/test/functions/io/compressed/WriteCompressedTest.java index 53fe03c9f4c..59038afdcf9 100644 --- a/src/test/java/org/apache/sysds/test/functions/io/compressed/WriteCompressedTest.java +++ 
b/src/test/java/org/apache/sysds/test/functions/io/compressed/WriteCompressedTest.java @@ -20,6 +20,7 @@ package org.apache.sysds.test.functions.io.compressed; import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; @@ -84,7 +85,7 @@ private void runWriteTest(ExecMode platform, int rows, int cols, int min, int ma runTest(null); double sumDML = TestUtils.readDMLScalar(output("sum.scalar")); - MatrixBlock mbr = IOCompressionTestUtils.read(output("out.cla")); + MatrixBlock mbr = IOCompressionTestUtils.read(output("out.cla"), rows, cols, OptimizerUtils.DEFAULT_BLOCKSIZE); TestUtils.compareScalars(sumDML, mbr.sum(), eps); diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformApplyUnknownsTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformApplyUnknownsTest.java index d5952a2a6b1..9e50de68188 100644 --- a/src/test/java/org/apache/sysds/test/functions/transform/TransformApplyUnknownsTest.java +++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformApplyUnknownsTest.java @@ -19,20 +19,22 @@ package org.apache.sysds.test.functions.transform; -import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; -import org.junit.Assert; -import org.junit.Test; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.transform.encode.EncoderFactory; +import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; import org.apache.sysds.runtime.util.DataConverter; import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestUtils; 
+import org.junit.Assert; +import org.junit.Test; -public class TransformApplyUnknownsTest extends AutomatedTestBase -{ +public class TransformApplyUnknownsTest extends AutomatedTestBase { + protected static final Log LOG = LogFactory.getLog(TransformApplyUnknownsTest.class.getName()); private static final int rows = 70; @Override @@ -91,8 +93,10 @@ public void testTransformApplyBinning() { Assert.assertEquals(out.getNumRows(), data2.getNumRows()); Assert.assertEquals(out.getNumColumns(), data2.getNumColumns()); for(int i=-5; i<=rows+5; i++) { - if( i < 1 | i > rows ) - Assert.assertTrue(Double.isNaN(out.quickGetValue(i+5, 0))); + if( i < 1 ) + Assert.assertEquals(1, out.quickGetValue(i+5, 0), 0.0); + else if(i > rows) + Assert.assertEquals(out.quickGetValue(out.getNumRows()-1, 0), out.quickGetValue(i+5, 0), 0.0); else Assert.assertEquals(((i-1)/10+1), out.quickGetValue(i+5, 0), 1e-8); } From 1697c7ba16792831de05c2d6ae08e3aeb2f38ff3 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 24 Oct 2023 17:53:23 +0200 Subject: [PATCH 04/28] [MINOR] CSV frame reader refine csv parsing This commit adds a few shortcuts in the CSV parsing to: 1. reduce call time of trim by filtering strings not containing whitespace this is a trade off, that makes it slower for strings with whitespace, and faster for the common case of no white spaces. 2. Specialize the split CSV to a case with a single char delimiter, this simplify the splitting logic. But only implemented for the case of no quotation marks in the line input, since quotations make the rules change for csv parsing. 
Closes 1932 --- .../sysds/runtime/io/FrameReaderTextCSV.java | 59 ++++++--- .../sysds/runtime/io/IOUtilFunctions.java | 121 ++++++++++++++---- .../util/FastBufferedDataOutputStream.java | 2 +- 3 files changed, 140 insertions(+), 42 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSV.java b/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSV.java index d8de58f058d..cfe4a5e45ba 100644 --- a/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSV.java +++ b/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSV.java @@ -144,9 +144,8 @@ protected final int readCSVFrameFromInputSplit(InputSplit split, InputFormat naValues, + boolean isFill, double dfillValue, String sfillValue) { + if(!isFill && naValues == null) + return assignColumnsNoFillNoNan(row, nCol, dest, parts); + else + return assignColumnsGeneric(row, nCol, dest, parts, naValues, isFill, dfillValue, sfillValue); + } + + private boolean assignColumnsGeneric(int row, int nCol, FrameBlock dest, String[] parts, Set naValues, + boolean isFill, double dfillValue, String sfillValue) { + boolean emptyValuesFound = false; + for(int col = 0; col < nCol; col++) { + String part = IOUtilFunctions.trim(parts[col]); + if(part.isEmpty() || (naValues != null && naValues.contains(part))) { + if(isFill && dfillValue != 0) + dest.set(row, col, sfillValue); + emptyValuesFound = true; + } + else + dest.set(row, col, part); + } + + return emptyValuesFound; + } + + private boolean assignColumnsNoFillNoNan(int row, int nCol, FrameBlock dest, String[] parts){ + + boolean emptyValuesFound = false; + for(int col = 0; col < nCol; col++) { + String part = IOUtilFunctions.trim(parts[col]); + if(part.isEmpty()) + emptyValuesFound = true; + else + dest.set(row, col, part); + } + + return emptyValuesFound; + } + + protected Pair computeCSVSize(Path path, JobConf job, FileSystem fs) throws IOException { TextInputFormat informat = new TextInputFormat(); informat.configure(job); diff --git 
a/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java b/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java index c188928ae00..9ce18d11b88 100644 --- a/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java +++ b/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java @@ -222,7 +222,7 @@ public static String[] splitCSV(String str, String delim){ final ArrayList tokens = new ArrayList<>(); while(from < len) { // for all tokens - to = getTo(str, from, delim); + to = getTo(str, from, delim, len, delimLen); tokens.add(str.substring(from, to)); from = to + delimLen; } @@ -257,28 +257,60 @@ else if(empty) { return cache; } else - return splitCSVNonNullWithCache(str,delim,cache); + return splitCSVNonNullWithCache(str, delim, cache); } private static String[] splitCSVNonNullWithCache(final String str, final String delim, final String[] cache) { + final int len = str.length(); final int delimLen = delim.length(); - final boolean containsQuotationMarks = str.contains("\""); + + if(str.contains("\"")) + return splitCSVNonNullWithCacheWithQuote(str, delim, cache, len, delimLen); + else if(delimLen == 1) + return splitCSVNonNullCacheNoQuoteCharDelim(str, delim.charAt(0), cache, len); + else + return splitCSVNonNullCacheNoQuote(str, delim, cache, len, delimLen); + } + + private static String[] splitCSVNonNullWithCacheWithQuote(final String str, final String delim, + final String[] cache, final int len, final int delimLen) { int from = 0; int id = 0; - if(containsQuotationMarks){ - while(from < len) { // for all tokens - final int to = getTo(str, from, delim); - cache[id++] = str.substring(from, to); - from = to + delimLen; - } + while(from < len) { // for all tokens + final int to = getTo(str, from, delim, len, delimLen); + cache[id++] = str.substring(from, to); + from = to + delimLen; } - else{ - while(from < len) { // for all tokens - final int to = getToNoQuote(str, from, delim); - cache[id++] = str.substring(from, to); - from = to + 
delimLen; - } + + if(from == len) + cache[id] = ""; + return cache; + } + + private static String[] splitCSVNonNullCacheNoQuote(final String str, final String delim, final String[] cache,final int len, final int delimLen) { + int from = 0; + int id = 0; + + while(from < len) { // for all tokens + final int to = getToNoQuote(str, from, delim, len, delimLen); + cache[id++] = str.substring(from, to); + from = to + delimLen; + } + + if(from == len) + cache[id] = ""; + return cache; + } + + private static String[] splitCSVNonNullCacheNoQuoteCharDelim(final String str, final char delim, + final String[] cache, final int len) { + int from = 0; + int id = 0; + while(from < len) { // for all tokens + final int to = getToNoQuoteCharDelim(str, from, delim, len); + cache[id++] = str.substring(from, to); + from = to + 1; } if(from == len) @@ -296,9 +328,18 @@ private static boolean isEmptyMatch(final String str, final int from, final Stri return true; } - private static int getTo(final String str, final int from, final String delim) { - final int len = str.length(); - final int dLen = delim.length(); + /** + * Get next index of substring after delim, while the string can contain Quotation marks + * + * @param str The string to get the index from + * @param from The index to start searching from + * @param delim The delimiter to find + * @param len The length of the str string argument + * @param dLen The length of the delimiter string + * @return The next index. + */ + private static int getTo(final String str, final int from, final String delim, + final int len, final int dLen) { final char cq = CSV_QUOTE_CHAR; final int fromP1 = from + 1; int to; @@ -322,12 +363,21 @@ else if(isEmptyMatch(str, from, delim, dLen, len)) return to >= 0 ? 
to : len; } - private static int getToNoQuote(final String str, final int from, final String delim) { - final int len = str.length(); - final int dLen = delim.length(); - final int fromP1 = from + 1; + /** + * Get next index of substring after delim + * + * @param str The string to get the index from + * @param from The index to start searching from + * @param delim The delimiter to find + * @param len The length of the str string argument + * @param dLen The length of the delimiter string + * @return The next index. + */ + private static int getToNoQuote(final String str, final int from, final String delim, final int len, + final int dLen) { + int to; - + final int fromP1 = from + 1; if(isEmptyMatch(str, from, delim, dLen, len)) return to = from; // empty string else // default: unquoted non-empty @@ -335,10 +385,29 @@ private static int getToNoQuote(final String str, final int from, final String d // slice out token and advance position return to >= 0 ? to : len; + + } + + private static int getToNoQuoteCharDelim(final String str, final int from, final char delim, final int len){ + for(int i = from; i < len; i++) + if(str.charAt(i) == delim) + return i; + return len; } public static String trim(String str) { - return str.trim(); + try{ + final int len = str.length(); + if(len == 0) + return str; + // short the call to return input if not whitespace in ends. + else if(str.charAt(0) <= ' ' || str.charAt(len -1) <= ' ') + return str.trim(); + else + return str; + }catch(Exception e){ + throw new RuntimeException("failed trimming: " + str + " " + str.length(),e); + } } /** @@ -366,7 +435,7 @@ else if (naStrings == null) int from = 0; int pos = 0; while( from < len ) { // for all tokens - final int to = getTo(str, from, delim); + final int to = getTo(str, from, delim, len, dLen); final String curString = str.substring(from, to); tokens[pos++] = naStrings.contains(curString) ? 
null : curString; from = to + dLen; @@ -401,7 +470,7 @@ public static int countTokensCSV(String str, String delim) int numTokens = 0; int from = 0; while( from < len ) { // for all tokens - int to = getTo(str, from, delim); + int to = getTo(str, from, delim, len, dlen); from = to + dlen; numTokens++; } diff --git a/src/main/java/org/apache/sysds/runtime/util/FastBufferedDataOutputStream.java b/src/main/java/org/apache/sysds/runtime/util/FastBufferedDataOutputStream.java index adf9f0abd54..1804bc78e04 100644 --- a/src/main/java/org/apache/sysds/runtime/util/FastBufferedDataOutputStream.java +++ b/src/main/java/org/apache/sysds/runtime/util/FastBufferedDataOutputStream.java @@ -191,7 +191,7 @@ public void writeUTF(String s) throws IOException { for( int i=0; i _bufflen) flushBuffer(); - char c = s.charAt(i); + final char c = s.charAt(i); if( c>= 0x0001 && c<=0x007F ) //1 byte range _buff[_count++] = (byte) c; else if( c>=0x0800 ) { //3 byte range From 351828d6184c234e5ffa10279ad7c370834b59e5 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Thu, 19 Oct 2023 13:37:26 +0200 Subject: [PATCH 05/28] [MINOR] DML Startup At startup the first thing we do is to call Hadoop to parse the Hadoop specific arguments. This takes ~ 200 ms at startup before we start our timing of SystemDS. The script: 'print("Hello, World!")' Before the change it ran 1,6187 sec on my laptop and 1.6764 on a scale out cluster node. With this commit change, it speeds up to: 1,4366 on the laptop and 1.519 on a scale out cluster node. 
Closes #1926 --- .../java/org/apache/sysds/api/DMLScript.java | 12 ++++-------- .../org/apache/sysds/test/AutomatedTestBase.java | 16 ++++++---------- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java index bf638dfcf77..aa680a97f37 100644 --- a/src/main/java/org/apache/sysds/api/DMLScript.java +++ b/src/main/java/org/apache/sysds/api/DMLScript.java @@ -41,10 +41,8 @@ import org.apache.commons.lang3.StringUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; -import org.apache.hadoop.util.GenericOptionsParser; import org.apache.sysds.common.Types.ExecMode; import org.apache.sysds.conf.CompilerConfig; import org.apache.sysds.conf.ConfigurationManager; @@ -204,16 +202,15 @@ public static boolean isActiveAM(){ public static void main(String[] args) { try{ - Configuration conf = new Configuration(ConfigurationManager.getCachedJobConf()); - String[] otherArgs = new GenericOptionsParser(conf, args).getRemainingArgs(); - DMLScript.executeScript(conf, otherArgs); + DMLScript.executeScript(args); } catch(Exception e){ - errorPrint(e); for(String s: args){ if(s.trim().contains("-debug")){ e.printStackTrace(); + return; } } + errorPrint(e); } } @@ -221,12 +218,11 @@ public static void main(String[] args) * Single entry point for all public invocation alternatives (e.g., * main, executeScript, JaqlUdf etc) * - * @param conf Hadoop configuration * @param args arguments * @return true if success, false otherwise * @throws IOException If an internal IOException happens. 
*/ - public static boolean executeScript( Configuration conf, String[] args ) + public static boolean executeScript( String[] args ) throws IOException, ParseException, DMLScriptException { //parse arguments and set execution properties diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java index 354fa12febb..f63fbb987a0 100644 --- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java @@ -19,6 +19,11 @@ package org.apache.sysds.test; +import static java.lang.Math.ceil; +import static java.lang.Thread.sleep; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + import java.io.ByteArrayOutputStream; import java.io.File; import java.io.IOException; @@ -38,18 +43,12 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import static java.lang.Math.ceil; -import static java.lang.Thread.sleep; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.util.GenericOptionsParser; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.SparkSession.Builder; import org.apache.sysds.api.DMLScript; @@ -59,7 +58,6 @@ import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.FileFormat; import org.apache.sysds.common.Types.ValueType; -import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.hops.fedplanner.FTypes.FType; @@ -1568,9 +1566,7 @@ private ByteArrayOutputStream 
runTestWithTimeout(boolean newWay, boolean excepti * @throws IOException if an IOException occurs in the hadoop GenericOptionsParser */ public static void main(String[] args) throws IOException, ParseException, DMLScriptException { - Configuration conf = new Configuration(ConfigurationManager.getCachedJobConf()); - String[] otherArgs = new GenericOptionsParser(conf, args).getRemainingArgs(); - DMLScript.executeScript(conf, otherArgs); + DMLScript.executeScript(args); } private void addProgramIndependentArguments(ArrayList args, String[] otherArgs) { From 7561f61a14dc1097e3bfcfee497a90451b4564f1 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 25 Oct 2023 10:38:02 +0200 Subject: [PATCH 06/28] [SYSTEMDS-3640] Hash Column This commit adds a new value type HASH64 for that can contain hashes of 16 hex encoded characters. It behaves internally as if it is a string column, but allocate a single long value per cell. This reduce the allocation of columns with hash values from 40+ byte per value to 8 byte. 
Closes #1933 --- .../java/org/apache/sysds/common/Types.java | 17 +- .../runtime/compress/colgroup/APreAgg.java | 2 +- .../runtime/compress/lib/CLALibScalar.java | 2 +- .../runtime/frame/data/columns/Array.java | 11 + .../frame/data/columns/ArrayFactory.java | 33 +- .../frame/data/columns/BitSetArray.java | 8 + .../frame/data/columns/BooleanArray.java | 8 + .../runtime/frame/data/columns/CharArray.java | 8 + .../runtime/frame/data/columns/DDCArray.java | 5 + .../frame/data/columns/DoubleArray.java | 11 + .../frame/data/columns/FloatArray.java | 8 + .../frame/data/columns/HashLongArray.java | 414 ++++++++++++++++++ .../frame/data/columns/IntegerArray.java | 8 + .../runtime/frame/data/columns/LongArray.java | 5 + .../frame/data/columns/OptionalArray.java | 17 + .../frame/data/columns/RaggedArray.java | 5 + .../frame/data/columns/StringArray.java | 31 +- .../frame/data/lib/FrameLibApplySchema.java | 1 + .../runtime/frame/data/lib/FrameUtil.java | 20 +- .../sysds/runtime/util/UtilFunctions.java | 16 +- .../java/org/apache/sysds/test/TestUtils.java | 1 + .../frame/array/CustomArrayTests.java | 55 ++- .../frame/array/FrameArrayConstantTests.java | 2 + .../frame/array/FrameArrayTests.java | 159 ++++++- .../frame/iterators/IteratorTest.java | 37 +- 25 files changed, 835 insertions(+), 49 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index 4b8f1c3a006..84019e8078c 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -77,17 +77,21 @@ public boolean isUnknown() { public enum ValueType { UINT4, UINT8, // Used for parsing in UINT values from numpy. FP32, FP64, INT32, INT64, BOOLEAN, STRING, UNKNOWN, + HASH64, // Indicate that the value is a hash of 64 bit. 
CHARACTER; public boolean isNumeric() { return this == UINT8 || this == INT32 || this == INT64 || this == FP32 || this == FP64 || this== UINT4; } + public boolean isUnknown() { return this == UNKNOWN; } + public boolean isPseudoNumeric() { return isNumeric() || this == BOOLEAN || this == CHARACTER; } + public String toExternalString() { switch(this) { case FP32: @@ -100,10 +104,13 @@ public String toExternalString() { default: return toString(); } } + public static ValueType fromExternalString(String value) { //for now we support both internal and external strings //until we have completely changed the external types - String lValue = (value != null) ? value.toUpperCase() : null; + if(value == null) + throw new DMLRuntimeException("Unknown null value type"); + final String lValue = value.toUpperCase(); switch(lValue) { case "FP32": return FP32; case "FP64": @@ -117,6 +124,7 @@ public static ValueType fromExternalString(String value) { case "STRING": return STRING; case "CHARACTER": return CHARACTER; case "UNKNOWN": return UNKNOWN; + case "HASH64": return HASH64; default: throw new DMLRuntimeException("Unknown value type: "+value); } @@ -143,6 +151,13 @@ else if(b == UNKNOWN) switch(a){ case CHARACTER: return STRING; + case HASH64: + switch(b){ + case STRING: + return b; + default: + return a; + } case STRING: return a; case FP64: diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java index 655bfc496fa..17f210865be 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java @@ -154,7 +154,7 @@ private void tsmmAPreAgg(APreAgg lg, MatrixBlock result) { final boolean left = shouldPreAggregateLeft(lg); if(!loggedWarningForDirect && shouldDirectMultiply(lg, leftIdx.size(), rightIdx.size(), left)) { loggedWarningForDirect = true; - LOG.warn("Not implemented direct tsmm 
colgroup"); + LOG.warn("Not implemented direct tsmm colgroup: " + lg.getClass().getSimpleName() + " %*% " + this.getClass().getSimpleName() ); } if(left) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java index 0da3f2d9690..3dea7f577a9 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java @@ -58,7 +58,7 @@ private CLALibScalar() { public static MatrixBlock scalarOperations(ScalarOperator sop, CompressedMatrixBlock m1, MatrixValue result) { if(isInvalidForCompressedOutput(m1, sop)) { - LOG.warn("scalar overlapping not supported for op: " + sop.fn); + LOG.warn("scalar overlapping not supported for op: " + sop.fn.getClass().getSimpleName()); MatrixBlock m1d = m1.decompress(sop.getNumThreads()); return m1d.scalarOperations(sop, result); } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java index 874364255f3..11accc814bf 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java @@ -423,6 +423,8 @@ public Array changeTypeWithNulls(ValueType t) { case UINT4: case UINT8: throw new NotImplementedException(); + case HASH64: + return new OptionalArray<>(changeTypeHash64(), nulls); case INT32: return new OptionalArray<>(changeTypeInteger(), nulls); case INT64: @@ -457,6 +459,8 @@ public final Array changeType(ValueType t) { case UINT4: case UINT8: throw new NotImplementedException(); + case HASH64: + return changeTypeHash64(); case INT32: return changeTypeInteger(); case INT64: @@ -513,6 +517,13 @@ public final Array changeType(ValueType t) { */ protected abstract Array changeTypeLong(); + /** + * Change type to a Hash46 array type + * + * @return A Hash64 array + */ + 
protected abstract Array changeTypeHash64(); + /** * Change type to a String array type * diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java index 12ca401c6b9..2fd6a74837b 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.util.BitSet; +import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.ValueType; @@ -35,13 +36,27 @@ public interface ArrayFactory { public final static int bitSetSwitchPoint = 64; public enum FrameArrayType { - STRING, BOOLEAN, BITSET, INT32, INT64, FP32, FP64, CHARACTER, RAGGED, OPTIONAL, DDC; + STRING, BOOLEAN, BITSET, INT32, INT64, FP32, FP64, + CHARACTER, RAGGED, OPTIONAL, DDC, + HASH64; } public static StringArray create(String[] col) { return new StringArray(col); } + public static HashLongArray createHash64(String[] col){ + return new HashLongArray(col); + } + + public static OptionalArray createHash64Opt(String[] col){ + return new OptionalArray(col, ValueType.HASH64); + } + + public static HashLongArray createHash64(long[] col){ + return new HashLongArray(col); + } + public static BooleanArray create(boolean[] col) { return new BooleanArray(col); } @@ -81,6 +96,8 @@ public static RaggedArray create(T[] col, int m) { public static long getInMemorySize(ValueType type, int _numRows, boolean containsNull) { if(containsNull) { switch(type) { + case HASH64: + type = ValueType.INT64; case BOOLEAN: case INT64: case FP64: @@ -108,6 +125,7 @@ public static long getInMemorySize(ValueType type, int _numRows, boolean contain else return BooleanArray.estimateInMemorySize(_numRows); case INT64: + case HASH64: return Array.baseMemoryCost() + 
(long) MemoryEstimates.longArrayCost(_numRows); case FP64: return Array.baseMemoryCost() + (long) MemoryEstimates.doubleArrayCost(_numRows); @@ -154,6 +172,8 @@ public static Array allocateOptional(ValueType v, int nRow) { return new OptionalArray<>(new DoubleArray(new double[nRow]), true); case CHARACTER: return new OptionalArray<>(new CharArray(new char[nRow]), true); + case HASH64: + return new OptionalArray<>(new HashLongArray(new long[nRow]), true); case UNKNOWN: case STRING: default: @@ -184,6 +204,8 @@ public static Array allocate(ValueType v, int nRow) { return new DoubleArray(new double[nRow]); case CHARACTER: return new CharArray(new char[nRow]); + case HASH64: + return new HashLongArray(new long[nRow]); case UNKNOWN: case STRING: default: @@ -222,9 +244,14 @@ public static Array read(DataInput in, int nRow) throws IOException { return OptionalArray.readOpt(in, nRow); case DDC: return DDCArray.read(in); - default: // String + case STRING: arr = new StringArray(new String[nRow]); break; + case HASH64: + arr = new HashLongArray(new long[nRow]); + break; + default: + throw new NotImplementedException(v + ""); } arr.readFields(in); return arr; @@ -325,6 +352,8 @@ public static Object parseString(String s, ValueType v) { return IntegerArray.parseInt(s); case INT64: return LongArray.parseLong(s); + case HASH64: + return HashLongArray.parseHashLong(s); case STRING: case UNKNOWN: default: diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java index dbd5d7328c3..710d8a8debe 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java @@ -465,6 +465,14 @@ protected Array changeTypeLong() { return new LongArray(ret); } + @Override + protected Array changeTypeHash64(){ + long[] ret = new long[size()]; + for(int i = 0; i < size(); i++) + ret[i] = get(i) 
? 1L : 0L; + return new HashLongArray(ret); + } + @Override protected Array changeTypeString() { String[] ret = new String[size()]; diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java index da874555d33..b44845bc34c 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java @@ -265,6 +265,14 @@ protected Array changeTypeLong() { return new LongArray(ret); } + @Override + protected Array changeTypeHash64(){ + long[] ret = new long[size()]; + for(int i = 0; i < size(); i++) + ret[i] = _data[i] ? 1L : 0L; + return new HashLongArray(ret); + } + @Override protected Array changeTypeString() { String[] ret = new String[size()]; diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java index 9862974ad77..14fcfd9f692 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java @@ -253,6 +253,14 @@ protected Array changeTypeLong() { return new LongArray(ret); } + @Override + protected Array changeTypeHash64(){ + long[] ret = new long[size()]; + for(int i = 0; i < size(); i++) + ret[i] = _data[i]; + return new HashLongArray(ret); + } + @Override protected Array changeTypeString() { String[] ret = new String[size()]; diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java index 4ddc3e4367c..b634cfe6ff3 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java @@ -231,6 +231,11 @@ protected Array changeTypeLong() { return new 
DDCArray<>(dict.changeTypeLong(), map); } + @Override + protected Array changeTypeHash64(){ + return new DDCArray<>(dict.changeTypeHash64(), map); + } + @Override protected Array changeTypeString() { return new DDCArray<>(dict.changeTypeString(), map); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java index 754748a28b3..e4e1a76b6ac 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java @@ -312,6 +312,17 @@ protected Array changeTypeLong() { return new LongArray(ret); } + @Override + protected Array changeTypeHash64() { + long[] ret = new long[size()]; + for(int i = 0; i < size(); i++) { + if(_data[i] != (long) _data[i]) + throw new DMLRuntimeException("Unable to change to Long from Double array because of value:" + _data[i]); + ret[i] = (long) _data[i]; + } + return new HashLongArray(ret); + } + @Override protected Array changeTypeString() { String[] ret = new String[size()]; diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java index 51d29b167db..47627894d92 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java @@ -253,6 +253,14 @@ protected Array changeTypeLong() { return new LongArray(ret); } + @Override + protected Array changeTypeHash64() { + long[] ret = new long[size()]; + for(int i = 0; i < size(); i++) + ret[i] = (int) _data[i]; + return new HashLongArray(ret); + } + @Override protected Array changeTypeFloat() { return this; diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java new file mode 100644 
index 00000000000..506c5d435f4 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java @@ -0,0 +1,414 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.frame.data.columns; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.Arrays; +import java.util.BitSet; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; +import org.apache.sysds.runtime.matrix.data.Pair; +import org.apache.sysds.utils.MemoryEstimates; + +public class HashLongArray extends Array { + private long[] _data; + + public HashLongArray(long[] data) { + super(data.length); + _data = data; + } + + public HashLongArray(String[] data) { + super(data.length); + _data = new long[data.length]; + for(int i = 0; i < data.length; i++) { + _data[i] = parseHashLong(data[i]); + } + } + + @Override + public Object get() { + throw new NotImplementedException("Invalid to get underlying array in Hash"); + } + + @Override + public Object get(int 
index) { + return Long.toHexString(_data[index]); + } + + public long getLong(int index) { + return _data[index]; + } + + @Override + public void set(int index, Object value) { + if(value instanceof String) + _data[index] = parseHashLong((String) value); + else if(value instanceof Long) + _data[index] = (Long) value; + else if (value == null) + _data[index] = 0L; + else + throw new NotImplementedException("not supported : " + value); + } + + @Override + public void set(int index, String value) { + _data[index] = parseHashLong((String) value); + } + + @Override + public void set(int index, double value) { + _data[index] = (long) value; + } + + @Override + public void set(int rl, int ru, Array value) { + set(rl, ru, value, 0); + } + + @Override + public void setFromOtherType(int rl, int ru, Array value) { + for(int i = rl; i <= ru; i++) + _data[i] = parseHashLong(value.get(i)); + } + + @Override + public void setNz(int rl, int ru, Array value) { + if(value instanceof HashLongArray) { + long[] thatVals = ((HashLongArray) value)._data; + for(int i = rl; i <= ru; i++) + if(thatVals[i] != 0) + _data[i] = thatVals[i]; + } + else { + throw new NotImplementedException("Not supported type of array: " + value.getClass().getSimpleName()); + } + } + + @Override + public void setFromOtherTypeNz(int rl, int ru, Array value) { + if(value instanceof HashLongArray) + setNz(rl, ru, (HashLongArray) value); + else if(value instanceof StringArray) { + StringArray st = ((StringArray) value); + for(int i = rl; i <= ru; i++) + if(st.get(i) != null) + _data[i] = parseHashLong(st.get(i)); + } + else { + throw new NotImplementedException("Not supported type of array: " + value.getClass().getSimpleName()); + } + } + + @Override + public void append(Object value) { + append(parseHashLong(value)); + } + + @Override + public void append(String value) { + append(parseHashLong(value)); + } + + public void append(long value) { + if(_data.length <= _size) + _data = Arrays.copyOf(_data, newSize()); + 
_data[_size++] = value; + } + + @Override + public Array append(Array other) { + if(other instanceof HashLongArray) { + + final int endSize = this._size + other.size(); + final long[] ret = new long[endSize]; + System.arraycopy(_data, 0, ret, 0, this._size); + System.arraycopy(((HashLongArray) other)._data, 0, ret, this._size, other.size()); + if(other instanceof OptionalArray) + return OptionalArray.appendOther((OptionalArray) other, new HashLongArray(ret)); + else + return new HashLongArray(ret); + } + else if(other instanceof OptionalArray) { + + OptionalArray ot = (OptionalArray) other; + if(ot._a instanceof HashLongArray) { + Array a = this.append((HashLongArray) ot._a); + return OptionalArray.appendOther(ot, a); + } + else { + throw new NotImplementedException("Invalid call with not hashArray"); + } + } + else { + throw new NotImplementedException(other.getClass().getSimpleName() + " not append supported in hashColumn"); + } + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeByte(FrameArrayType.HASH64.ordinal()); + for(int i = 0; i < _size; i++) + out.writeLong(_data[i]); + } + + @Override + public void readFields(DataInput in) throws IOException { + _size = _data.length; + for(int i = 0; i < _size; i++) + _data[i] = in.readLong(); + } + + @Override + public Array clone() { + return new HashLongArray(Arrays.copyOf(_data, _size)); + } + + @Override + public Array slice(int rl, int ru) { + return new HashLongArray(Arrays.copyOfRange(_data, rl, ru)); + } + + @Override + public void reset(int size) { + if(_data.length < size || _data.length > 2 * size) + _data = new long[size]; + else + for(int i = 0; i < size; i++) + _data[i] = 0; + _size = size; + } + + @Override + public byte[] getAsByteArray() { + throw new NotImplementedException("Unclear how this byte array should look like for Hash"); + } + + @Override + public ValueType getValueType() { + return ValueType.HASH64; + } + + @Override + public Pair analyzeValueType() { + 
return new Pair<>(ValueType.HASH64, false); + } + + @Override + public FrameArrayType getFrameArrayType() { + return FrameArrayType.HASH64; + } + + @Override + public long getInMemorySize() { + long size = super.getInMemorySize(); // object header + object reference + size += MemoryEstimates.longArrayCost(_data.length); + return size; + } + + @Override + public long getExactSerializedSize() { + return 1 + 8 * _data.length; + } + + @Override + protected Array changeTypeBitSet() { + BitSet ret = new BitSet(size()); + for(int i = 0; i < size(); i++) { + if(_data[i] != 0 && _data[i] != 1) + throw new DMLRuntimeException( + "Unable to change to Boolean from Integer array because of value:" + _data[i]); + ret.set(i, _data[i] == 0 ? false : true); + } + return new BitSetArray(ret, size()); + } + + @Override + protected Array changeTypeBoolean() { + boolean[] ret = new boolean[size()]; + for(int i = 0; i < size(); i++) { + if(_data[i] < 0 || _data[i] > 1) + throw new DMLRuntimeException( + "Unable to change to Boolean from Integer array because of value:" + _data[i]); + ret[i] = _data[i] == 0 ? 
false : true; + } + return new BooleanArray(ret); + } + + @Override + protected Array changeTypeDouble() { + double[] ret = new double[size()]; + for(int i = 0; i < size(); i++) + ret[i] = _data[i]; + return new DoubleArray(ret); + } + + @Override + protected Array changeTypeFloat() { + float[] ret = new float[size()]; + for(int i = 0; i < size(); i++) + ret[i] = _data[i]; + return new FloatArray(ret); + } + + @Override + protected Array changeTypeInteger() { + int[] ret = new int[size()]; + for(int i = 0; i < size(); i++) { + if(Math.abs(_data[i]) > Integer.MAX_VALUE) + throw new DMLRuntimeException("Unable to change to integer from long array because of value:" + _data[i]); + ret[i] = (int) _data[i]; + } + return new IntegerArray(ret); + } + + @Override + protected Array changeTypeLong() { + return new LongArray(_data); + } + + @Override + protected Array changeTypeHash64() { + return this; + } + + @Override + protected Array changeTypeString() { + String[] ret = new String[size()]; + for(int i = 0; i < size(); i++) + ret[i] = get(i).toString(); + return new StringArray(ret); + } + + @Override + public void fill(String value) { + fill(parseHashLong(value)); + } + + @Override + public void fill(Object value) { + fill(parseHashLong(value)); + } + + public void fill(Long value) { + Arrays.fill(_data, value != null ? 
value : 0L); + } + + @Override + public double getAsDouble(int i) { + return _data[i]; + } + + public static long parseHashLong(Object s) { + if(s == null) + return 0L; + else if(s instanceof String) + return parseHashLong((String) s); + else if(s instanceof Long) + return (Long) s; + else + throw new NotImplementedException("not supported" + s); + } + + public static long parseHashLong(String s) { + if(s == null || s.isEmpty()) + return 0L; + return Long.parseUnsignedLong(s, 16); + } + + @Override + public Array changeTypeCharacter() { + char[] ret = new char[size()]; + for(int i = 0; i < size(); i++) + ret[i] = get(i).toString().charAt(0); + return new CharArray(ret); + } + + @Override + public boolean isShallowSerialize() { + return true; + } + + @Override + public boolean isEmpty() { + for(int i = 0; i < _data.length; i++) + if(_data[i] != 0L) + return false; + return true; + } + + @Override + public Array select(int[] indices) { + final long[] ret = new long[indices.length]; + for(int i = 0; i < indices.length; i++) + ret[i] = _data[indices[i]]; + return new HashLongArray(ret); + } + + @Override + public Array select(boolean[] select, int nTrue) { + final long[] ret = new long[nTrue]; + int k = 0; + for(int i = 0; i < select.length; i++) + if(select[i]) + ret[k++] = _data[i]; + return new HashLongArray(ret); + } + + @Override + public final boolean isNotEmpty(int i) { + return _data[i] != 0; + } + + @Override + public double hashDouble(int idx) { + return Long.hashCode(_data[idx]); + } + + @Override + public boolean equals(Array other) { + if(other instanceof HashLongArray) + return Arrays.equals(_data, ((HashLongArray) other)._data); + else + return false; + } + + @Override + public boolean possiblyContainsNaN() { + return false; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(_data.length * 5 + 2); + sb.append(super.toString() + ":["); + for(int i = 0; i < _size - 1; i++) + sb.append(_data[i] + ","); + 
sb.append(_data[_size - 1]); + sb.append("]"); + return sb.toString(); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java index df60803ddad..4a180e264ce 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java @@ -255,6 +255,14 @@ protected Array changeTypeLong() { return new LongArray(ret); } + @Override + protected Array changeTypeHash64() { + long[] ret = new long[size()]; + for(int i = 0; i < size(); i++) + ret[i] = _data[i]; + return new HashLongArray(ret); + } + @Override protected Array changeTypeString() { String[] ret = new String[size()]; diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java index c1e0fe06c9b..4d90190f672 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java @@ -258,6 +258,11 @@ protected Array changeTypeLong() { return this; } + @Override + protected Array changeTypeHash64() { + return new HashLongArray(_data); + } + @Override protected Array changeTypeString() { String[] ret = new String[size()]; diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java index 99444015d43..6699f1050aa 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java @@ -63,6 +63,17 @@ else if(a instanceof Character[]) // Character } } + @SuppressWarnings("unchecked") + public OptionalArray(T[] a, ValueType vt){ + super(a.length); + _a = (Array) ArrayFactory.allocate(vt, a.length); + _n = 
ArrayFactory.allocateBoolean(a.length); + for(int i = 0; i < a.length; i++) { + _a.set(i, a[i]); + _n.set(i, a[i] != null); + } + } + public OptionalArray(Array a, boolean empty) { super(a.size()); if(a instanceof OptionalArray) @@ -342,6 +353,12 @@ protected Array changeTypeLong() { return new OptionalArray<>(a, _n); } + @Override + protected Array changeTypeHash64() { + Array a = _a.changeTypeHash64(); + return new OptionalArray<>(a, _n); + } + @Override protected Array changeTypeCharacter() { Array a = _a.changeTypeCharacter(); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java index a63026b1484..94a30f4980e 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java @@ -288,6 +288,11 @@ protected Array changeTypeLong() { return _a.changeTypeLong(); } + @Override + protected Array changeTypeHash64() { + return _a.changeTypeHash64(); + } + @Override protected Array changeTypeString() { return _a.changeTypeString(); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java index fd86286972b..03c2c7cc82c 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java @@ -22,7 +22,6 @@ import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; -import java.nio.charset.Charset; import java.util.Arrays; import java.util.BitSet; import java.util.HashMap; @@ -236,11 +235,17 @@ public ValueType getValueType() { } private static final ValueType getHighest(ValueType state, ValueType c) { - switch(state) { + case FP64: + switch(c) { + case HASH64: + return c; + default: + } case FP32: switch(c) { case FP64: + case HASH64: 
return c; default: } @@ -249,6 +254,7 @@ private static final ValueType getHighest(ValueType state, ValueType c) { switch(c) { case FP64: case FP32: + case HASH64: return c; default: } @@ -258,6 +264,7 @@ private static final ValueType getHighest(ValueType state, ValueType c) { case FP64: case FP32: case INT64: + case HASH64: return c; default: } @@ -269,6 +276,7 @@ private static final ValueType getHighest(ValueType state, ValueType c) { case INT64: case INT32: case CHARACTER: + case HASH64: return c; default: } @@ -286,9 +294,8 @@ public Pair analyzeValueType() { boolean nulls = false; for(int i = 0; i < _size; i++) { final ValueType c = FrameUtil.isType(_data[i], state); - if(c == ValueType.STRING) { + if(c == ValueType.STRING) return new Pair<>(ValueType.STRING, false); - } else if(c == ValueType.UNKNOWN) nulls = true; else @@ -560,6 +567,22 @@ protected Array changeTypeLong() { } } + @Override + protected Array changeTypeHash64() { + try { + long[] ret = new long[size()]; + for(int i = 0; i < size(); i++) { + final String s = _data[i]; + if(s != null) + ret[i] = Long.parseLong(s, 16); + } + return new HashLongArray(ret); + } + catch(NumberFormatException e) { + throw new DMLRuntimeException("Unable to change to Hash64 from String array", e); + } + } + @Override public Array changeTypeCharacter() { char[] ret = new char[size()]; diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java index 92372ecab23..f782933307f 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java @@ -20,6 +20,7 @@ package org.apache.sysds.runtime.frame.data.lib; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; diff --git 
a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java index 705aeb24c37..309560c46d6 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java @@ -122,6 +122,18 @@ else if(integerFloatPattern.matcher(val).matches()) { return null; } + public static ValueType isHash(final String val, final int len) { + if(len == 8) { + for(int i = 0; i < 8; i++) { + char v = val.charAt(i); + if(v < '0' || v > 'f') + return null; + } + return ValueType.HASH64; + } + return null; + } + public static ValueType isFloatType(final String val, final int len) { if(len <= 30 && (simpleFloatMatch(val, len) || floatPattern.matcher(val).matches())) { if(len <= 7 || (len == 8 && val.charAt(0) == '-')) @@ -169,7 +181,7 @@ private static boolean simpleFloatMatch(final String val, final int len) { final char c = val.charAt(i); if(c >= '0' && c <= '9') continue; - else if(c == '.' || c == ','){ + else if(c == '.' 
|| c == ',') { if(encounteredDot == true) return false; else @@ -209,7 +221,7 @@ public static ValueType isType(String val, ValueType minType) { switch(minType) { case UNKNOWN: case BOOLEAN: - // case CHARACTER: + // case CHARACTER: if(isBooleanType(val, len) != null) return ValueType.BOOLEAN; case UINT8: @@ -226,6 +238,10 @@ public static ValueType isType(String val, ValueType minType) { case CHARACTER: if(len == 1) return ValueType.CHARACTER; + case HASH64: + r = isHash(val, len); + if(r != null) + return r; case STRING: default: return ValueType.STRING; diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java index 967855814fa..b46792da029 100644 --- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java +++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java @@ -46,6 +46,7 @@ import org.apache.sysds.runtime.data.TensorIndexes; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.CharArray; +import org.apache.sysds.runtime.frame.data.columns.HashLongArray; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; import org.apache.sysds.runtime.matrix.data.Pair; @@ -483,15 +484,16 @@ public static Object doubleToObject(ValueType vt, double in, boolean sparse) { public static Object stringToObject(ValueType vt, String in) { if( in == null || in.isEmpty() ) return null; switch( vt ) { - case STRING: return in; - case BOOLEAN: return Boolean.parseBoolean(in); + case STRING: return in; + case BOOLEAN: return Boolean.parseBoolean(in); case UINT4: case UINT8: - case INT32: return Integer.parseInt(in); - case INT64: return Long.parseLong(in); - case FP64: return Double.parseDouble(in); - case FP32: return Float.parseFloat(in); + case INT32: return Integer.parseInt(in); + case INT64: return Long.parseLong(in); + case FP64: return 
Double.parseDouble(in); + case FP32: return Float.parseFloat(in); case CHARACTER: return CharArray.parseChar(in); + case HASH64: return HashLongArray.parseHashLong(in); default: throw new RuntimeException("Unsupported value type: "+vt); } } @@ -674,7 +676,7 @@ else if(in instanceof String && ((String)in).trim().length() == 0) public static Object objectToObject(ValueType vt, Object in) { if( in instanceof Double && vt == ValueType.FP64 || in instanceof Float && vt == ValueType.FP32 - || in instanceof Long && vt == ValueType.INT64 + || in instanceof Long && (vt == ValueType.INT64 || vt == ValueType.HASH64) || in instanceof Integer && vt == ValueType.INT32 || in instanceof Boolean && vt == ValueType.BOOLEAN || in instanceof String && vt == ValueType.STRING ) diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java index 45fe79a4a3a..acda5eaf839 100644 --- a/src/test/java/org/apache/sysds/test/TestUtils.java +++ b/src/test/java/org/apache/sysds/test/TestUtils.java @@ -2549,6 +2549,7 @@ public static Object generateRandomValueFromValueType(ValueType valueType, Rando case INT32: return random.nextInt(); case INT64: return random.nextLong(); case BOOLEAN: return random.nextBoolean(); + case HASH64: return Long.toHexString(random.nextLong()); case STRING: return random.ints('a', 'z' + 1) .limit(10) diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java b/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java index f0dcbf9c6eb..94a5810bf4d 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java +++ b/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java @@ -45,6 +45,7 @@ import org.apache.sysds.runtime.frame.data.columns.DDCArray; import org.apache.sysds.runtime.frame.data.columns.DoubleArray; import org.apache.sysds.runtime.frame.data.columns.FloatArray; +import 
org.apache.sysds.runtime.frame.data.columns.HashLongArray; import org.apache.sysds.runtime.frame.data.columns.IntegerArray; import org.apache.sysds.runtime.frame.data.columns.LongArray; import org.apache.sysds.runtime.frame.data.columns.OptionalArray; @@ -857,7 +858,7 @@ public void testDDCIn() { try { Array a = null; Array b = new DDCArray(new LongArray(new long[] {1, 2, 3, 4}), // - MapToFactory.create(10, new int[] {0, 0, 0, 0, 1, 1, 1, 2, 2, 3,3}, 4)); + MapToFactory.create(10, new int[] {0, 0, 0, 0, 1, 1, 1, 2, 2, 3, 3}, 4)); Array c = ArrayFactory.set(a, b, 10, 19, 20); assertEquals((long) c.get(0), 0L); assertEquals((long) c.get(10), 1L); @@ -873,7 +874,7 @@ public void testDDCInOptional() { try { Array a = null; Array b = new DDCArray(new OptionalArray(new Long[] {1L, 2L, 3L, 4L}), // - MapToFactory.create(10, new int[] {0, 0, 0, 0, 1, 1, 1, 2, 2, 3,3}, 4)); + MapToFactory.create(10, new int[] {0, 0, 0, 0, 1, 1, 1, 2, 2, 3, 3}, 4)); Array c = ArrayFactory.set(a, b, 10, 19, 20); assertEquals(c.get(0), null); assertEquals((long) c.get(10), 1L); @@ -884,8 +885,6 @@ public void testDDCInOptional() { } } - - @Test public void testSetOptionalB() { try { @@ -1364,4 +1363,52 @@ public void hashDoubleOnStringNull() { assertEquals(a.hashDouble(i), Double.NaN, 0.0); } } + + @Test + public void parseHash() { + assertEquals(10, HashLongArray.parseHashLong("a")); + } + + @Test + public void parseHash_ff() { + assertEquals(255, HashLongArray.parseHashLong("ff")); + } + + @Test + public void parseHash_fff() { + assertEquals(4095, HashLongArray.parseHashLong("fff")); + } + + @Test + public void parseHash_ffff() { + assertEquals(65535, HashLongArray.parseHashLong("ffff")); + } + + + @Test + public void parseHash_fffff() { + assertEquals(1048575, HashLongArray.parseHashLong("fffff")); + } + + @Test + public void parseHash_ffffff() { + assertEquals(16777215, HashLongArray.parseHashLong("ffffff")); + } + + @Test + public void parseHash_fffffff() { + assertEquals(268435455L, 
HashLongArray.parseHashLong("fffffff")); + } + + + @Test + public void parseHash_ffffffff() { + assertEquals(4294967295L, HashLongArray.parseHashLong("ffffffff")); + } + + @Test + public void parseHash_ffffffff_ffffffff() { + assertEquals(-1, HashLongArray.parseHashLong("ffffffffffffffff")); + } + } diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayConstantTests.java b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayConstantTests.java index ca707b7156b..645eb30ad4b 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayConstantTests.java +++ b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayConstantTests.java @@ -102,6 +102,8 @@ public void testConstruction_default() { @Test public void testConstruction_1() { try { + if(t == ValueType.HASH64) + return; Array a = ArrayFactory.allocate(t, nRow, "1.0"); for(int i = 0; i < nRow; i++) assertEquals(a.getAsDouble(i), 1.0, 0.0000000001); diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java index 35d4d0e87c9..71211ab52c1 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java +++ b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java @@ -49,6 +49,7 @@ import org.apache.sysds.runtime.frame.data.columns.DDCArray; import org.apache.sysds.runtime.frame.data.columns.DoubleArray; import org.apache.sysds.runtime.frame.data.columns.FloatArray; +import org.apache.sysds.runtime.frame.data.columns.HashLongArray; import org.apache.sysds.runtime.frame.data.columns.IntegerArray; import org.apache.sysds.runtime.frame.data.columns.LongArray; import org.apache.sysds.runtime.frame.data.columns.OptionalArray; @@ -377,7 +378,7 @@ public void getStatistics() { @Test public void setWithDDC() { if(a.size() > 31) { - try{ + try { Array t = a.clone(); Array ddc = 
DDCArray.compressToDDC(// @@ -388,20 +389,20 @@ public void setWithDDC() { assertEquals(t.get(0), (Boolean) false); break; default: - + } } - catch(DMLCompressionException e){ + catch(DMLCompressionException e) { // valid error, Illegal to set range in a compressed array. } - catch(DMLRuntimeException e){ + catch(DMLRuntimeException e) { // is intentional here. - if(!e.getMessage().contains("RaggedArray")){ + if(!e.getMessage().contains("RaggedArray")) { e.printStackTrace(); fail(e.getMessage()); } } - catch(Exception e){ + catch(Exception e) { e.printStackTrace(); fail(e.getMessage()); } @@ -468,6 +469,7 @@ public void get() { x = a.get(); break; case RAGGED: + case HASH64: case OPTIONAL: try { a.get(); @@ -538,6 +540,9 @@ public void testSetRange(int start, int end, int off) { case CHARACTER: ((Array) aa).set(start, end, (Array) a, off); break; + case HASH64: + ((Array) aa).set(start, end, (Array) a, off); + break; default: throw new NotImplementedException(); } @@ -593,6 +598,9 @@ public void testSetRange(int start, int end, int otherSize, int seed) { case CHARACTER: ((Array) aa).set(start, end, (Array) other); break; + case HASH64: + ((Array) aa).set(start, end, (Array) other); + break; default: throw new NotImplementedException(); } @@ -602,6 +610,9 @@ public void testSetRange(int start, int end, int otherSize, int seed) { catch(DMLCompressionException e) { return;// valid } + catch(NumberFormatException e){ + return; // valid + } catch(Exception e) { e.printStackTrace(); fail(e.getMessage()); @@ -650,6 +661,16 @@ public void set() { ((Array) a).set(0, c); assertEquals(((Array) a).get(0), c); return; + case HASH64: + String hash = "abcdefaaaa"; + ((Array) a).set(0, hash); + assertEquals(((Array) a).get(0), hash); + if(a instanceof HashLongArray) { + long hashL = Long.parseUnsignedLong("abcdefaaaa", 16); + ((HashLongArray) a).set(0, hashL); + assertEquals(((HashLongArray) a).get(0), hash); + } + return; default: throw new NotImplementedException(); } @@ -689,6 
+710,9 @@ public void setDouble() { case CHARACTER: assertEquals((int) ((Array) a).get(0), 1); return; + case HASH64: + assertEquals(((Array) a).get(0), "1"); + return; default: throw new NotImplementedException(); } @@ -728,6 +752,9 @@ public void setDouble_2() { case CHARACTER: assertEquals(((Array) a).get(0), Character.valueOf((char) 0)); return; + case HASH64: + assertEquals(((Array) a).get(0), "0"); + return; default: throw new NotImplementedException(); } @@ -928,6 +955,15 @@ public void appendString() { aa.append(vci); assertEquals((char) aa.get(aa.size() - 1), vc); break; + case HASH64: + String hash = "aaaab"; + aa.append(hash); + assertEquals(aa.get(aa.size() - 1), hash); + + hash = "abbbbaa"; + aa.append(hash); + assertEquals(aa.get(aa.size() - 1), hash); + break; case UNKNOWN: default: throw new DMLRuntimeException("Invalid type"); @@ -973,6 +1009,9 @@ public void appendNull() { case CHARACTER: assertEquals((char) aa.get(aa.size() - 1), 0); break; + case HASH64: + assertEquals(aa.get(aa.size() - 1), "0"); + break; case UNKNOWN: default: throw new DMLRuntimeException("Invalid type"); @@ -1020,6 +1059,9 @@ public void append60Null() { case CHARACTER: assertEquals((char) aa.get(aa.size() - 1), 0); break; + case HASH64: + assertEquals(aa.get(aa.size() - 1), "0"); + break; case UNKNOWN: default: throw new DMLRuntimeException("Invalid type"); @@ -1060,6 +1102,9 @@ public void testSetNzSelf() { case CHARACTER: ((Array) aa).setNz((Array) a); break; + case HASH64: + ((Array) aa).setNz((Array) a); + break; case UNKNOWN: default: throw new DMLRuntimeException("Invalid type"); @@ -1082,7 +1127,6 @@ public void testSetNzString() { Array aa = a.clone(); Array af = (Array) aa.changeType(ValueType.STRING); try { - aa.setFromOtherTypeNz(af); } catch(DMLCompressionException e) { @@ -1102,7 +1146,6 @@ public void testSetNzStringWithNull() { Array aa = a.clone(); Array af = (Array) aa.changeTypeWithNulls(ValueType.STRING); try { - aa.setFromOtherTypeNz(af); } 
catch(DMLCompressionException e) { @@ -1122,7 +1165,6 @@ public void testSetFromString() { Array aa = a.clone(); Array af = (Array) aa.changeType(ValueType.STRING); try { - aa.setFromOtherType(0, af.size() - 1, af); } catch(DMLCompressionException e) { @@ -1140,8 +1182,11 @@ public void testSetFromString() { public void testSetFromStringWithNull() { Array aa = a.clone(); Array af; - if(aa.getFrameArrayType() == FrameArrayType.OPTIONAL && aa.getValueType() != ValueType.STRING) + if(aa.getFrameArrayType() == FrameArrayType.OPTIONAL // + && aa.getValueType() != ValueType.STRING // + && aa.getValueType() != ValueType.HASH64) { af = aa.changeTypeWithNulls(ValueType.FP64); + } else af = aa.changeTypeWithNulls(ValueType.STRING); @@ -1289,7 +1334,6 @@ public void setNullType() { ((Array) aa).set(0, (Character) null); assertTrue(aa.get(0) == null || aa.get(0).equals(Character.valueOf((char) 0))); break; - case FP32: ((Array) aa).set(0, (Float) null); assertTrue(aa.get(0) == null || aa.get(0).equals(Float.valueOf(0.0f))); @@ -1310,12 +1354,17 @@ public void setNullType() { ((Array) aa).set(0, (Integer) null); assertTrue(aa.get(0) == null || aa.get(0).equals(Integer.valueOf(0))); break; - default: + case HASH64: + aa.set(0, (String) null); + assertTrue(aa.get(0) == null || aa.get(0).equals("0")); + break; case STRING: case UNKNOWN: aa.set(0, (String) null); assertTrue(aa.get(0) == null); break; + default: + throw new NotImplementedException(); } } catch(DMLCompressionException e) { @@ -1374,6 +1423,12 @@ public void testAppendArray() { for(int i = 0; i < 10; i++) assertEquals(aa.get(i + a.size()), null); break; + case HASH64: + aa = ((Array) aa).append(new HashLongArray(new long[10])); + assertEquals(aa.size(), a.size() + 10); + for(int i = 0; i < 10; i++) + assertEquals(aa.get(i + a.size()), "0"); + break; case UNKNOWN: default: throw new NotImplementedException("Not supported"); @@ -1385,6 +1440,10 @@ public void testAppendArray() { catch(DMLCompressionException e) { 
return; // valid } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } } @@ -1439,6 +1498,12 @@ public void testAppendValue() { if(!isOptional) assertEquals(aa.get(a.size()), null); break; + case HASH64: + aa.append((String) null); + assertEquals(aa.size(), a.size() + 1); + if(!isOptional) + assertEquals(aa.get(a.size()), "0"); + break; case UNKNOWN: default: throw new NotImplementedException("Not supported"); @@ -1490,6 +1555,9 @@ public void testAppendArrayOptional() { case INT64: aa = ((Array) aa).append(new OptionalArray<>(new Long[10])); break; + case HASH64: + aa = ((Array) aa).append(new OptionalArray<>(new HashLongArray(new long[10]), true)); + break; case STRING: return; // not relevant case UNKNOWN: @@ -1555,6 +1623,11 @@ public void fillNull() { for(int i = 0; i < aa.size(); i++) assertEquals(aa.get(i), null); break; + case HASH64: + if(!isOptional) + for(int i = 0; i < aa.size(); i++) + assertEquals(aa.get(i), "0"); + break; case UNKNOWN: default: throw new NotImplementedException("Not supported"); @@ -1567,6 +1640,10 @@ public void fillNull() { catch(DMLCompressionException e) { return;// valid } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } } @Test @@ -1606,6 +1683,10 @@ public void fill1String() { for(int i = 0; i < aa.size(); i++) assertEquals(aa.get(i), "1"); break; + case HASH64: + for(int i = 0; i < aa.size(); i++) + assertEquals(aa.get(i), "1"); + break; case UNKNOWN: default: throw new NotImplementedException("Not supported"); @@ -1659,6 +1740,11 @@ public void fill1Value() { for(int i = 0; i < aa.size(); i++) assertEquals(aa.get(i), "1"); break; + case HASH64: + aa.fill("1"); + for(int i = 0; i < aa.size(); i++) + assertEquals(aa.get(i), "1"); + break; case UNKNOWN: default: throw new NotImplementedException("Not supported"); @@ -1721,6 +1807,12 @@ public void fill1ValueNull() { for(int i = 0; i < aa.size(); i++) assertEquals(aa.get(i), null); break; + case HASH64: + ((Array) aa).fill((Object) 
null); + if(!isOptional) + for(int i = 0; i < aa.size(); i++) + assertEquals(aa.get(i), "0"); + break; case UNKNOWN: default: throw new NotImplementedException("Not supported"); @@ -1788,7 +1880,6 @@ public void changeTypeWithNulls() { } catch(Exception e) { e.printStackTrace(); - LOG.error(a); fail(e.getMessage()); } } @@ -1868,7 +1959,7 @@ protected static Array serializeAndBack(Array g) { DataOutputStream fos = new DataOutputStream(bos); g.write(fos); DataInputStream fis = new DataInputStream(new ByteArrayInputStream(bos.toByteArray())); - Array gr = ArrayFactory.read(fis, nRow); + Array gr = ArrayFactory.read(fis, nRow); return gr; } catch(Exception e) { @@ -1900,6 +1991,9 @@ protected static Array createDDC(FrameArrayType t, int size, int seed) { case CHARACTER: return DDCArray .compressToDDC(ArrayFactory.create(generateRandomCharacterNUniqueLengthOpt(size, seed, nUnique))); + case HASH64: + return DDCArray + .compressToDDC(ArrayFactory.createHash64(generateRandomHash64OptNUnique(size, seed, nUnique))); case OPTIONAL: Random r = new Random(seed); switch(r.nextInt(7)) { @@ -1985,6 +2079,8 @@ protected static Array createOptional(FrameArrayType t, int size, int seed) { return ArrayFactory.create(generateRandomDoubleOpt(size, seed)); case CHARACTER: return ArrayFactory.create(generateRandomCharacterOpt(size, seed)); + case HASH64: + return ArrayFactory.createHash64Opt(generateRandomHash64Opt(size, seed)); case OPTIONAL: case RAGGED: // lets not test this case here. 
Random r = new Random(seed); @@ -2051,6 +2147,8 @@ protected static Array create(FrameArrayType t, int size, int seed) { return ArrayFactory.create(generateRandomDouble(size, seed)); case CHARACTER: return ArrayFactory.create(generateRandomChar(size, seed)); + case HASH64: + return ArrayFactory.createHash64(generateRandomHash64(size, seed)); case RAGGED: Random rand = new Random(seed); switch(rand.nextInt(7)) { @@ -2082,6 +2180,8 @@ protected static Array create(FrameArrayType t, int size, int seed) { return ArrayFactory.create(generateRandomFloatOpt(size, seed)); case 4: return ArrayFactory.create(generateRandomCharacterOpt(size, seed)); + case 5: + return ArrayFactory.create(generateRandomHash64Opt(size, seed)); default: return ArrayFactory.create(generateRandomBooleanOpt(size, seed)); } @@ -2163,6 +2263,18 @@ public static String[] generateRandomStringNUniqueLengthOpt(int size, int seed, return ret; } + public static String[] generateRandomHash64OptNUnique(int size, int seed, int nUnique) { + nUnique = Math.max(1, nUnique); + String[] rands = generateRandomHash64(nUnique, seed); + rands[rands.length - 1] = null; + Random r = new Random(seed + 1); + + String[] ret = new String[size]; + for(int i = 0; i < size; i++) + ret[i] = rands[r.nextInt(nUnique)]; + return ret; + } + public static Character[] generateRandomCharacterNUniqueLengthOpt(int size, int seed, int nUnique) { Character[] rands = generateRandomCharacterOpt(nUnique, seed); rands[rands.length - 1] = null; @@ -2228,6 +2340,25 @@ public static String[] generateRandomStringOpt(int size, int seed) { return ret; } + public static String[] generateRandomHash64(int size, int seed) { + Random r = new Random(seed); + String[] ret = new String[size]; + for(int i = 0; i < size; i++) { + ret[i] = Long.toHexString(r.nextLong()); + } + return ret; + } + + public static String[] generateRandomHash64Opt(int size, int seed) { + Random r = new Random(seed); + String[] ret = new String[size]; + for(int i = 0; i < size; 
i++) { + if(r.nextBoolean()) + ret[i] = Long.toHexString(r.nextLong()); + } + return ret; + } + public static String[] generateRandom01String(int size, int seed) { Random r = new Random(seed); String[] ret = new String[size]; diff --git a/src/test/java/org/apache/sysds/test/component/frame/iterators/IteratorTest.java b/src/test/java/org/apache/sysds/test/component/frame/iterators/IteratorTest.java index c6f5bfd621a..8ad57f3c52d 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/iterators/IteratorTest.java +++ b/src/test/java/org/apache/sysds/test/component/frame/iterators/IteratorTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import java.util.Arrays; @@ -36,8 +37,20 @@ public class IteratorTest { - private final FrameBlock fb1 = TestUtils.generateRandomFrameBlock(10, 10, 23); - private final FrameBlock fb2 = TestUtils.generateRandomFrameBlock(40, 30, 22); + private final FrameBlock fb1; + private final FrameBlock fb2; + + public IteratorTest() { + try { + fb1 = TestUtils.generateRandomFrameBlock(10, 10, 23); + fb2 = TestUtils.generateRandomFrameBlock(40, 30, 22); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + throw new RuntimeException(e); + } + } @Test public void StringObjectStringFB1() { @@ -236,29 +249,27 @@ public void iteratorWithSchema() { compareIterators(a, b); } - - @Test(expected= DMLRuntimeException.class) - public void invalidRange1(){ + @Test(expected = DMLRuntimeException.class) + public void invalidRange1() { IteratorFactory.getStringRowIterator(fb2, -1, 1); } - @Test(expected= DMLRuntimeException.class) - public void invalidRange2(){ + @Test(expected = DMLRuntimeException.class) + public void invalidRange2() { IteratorFactory.getStringRowIterator(fb2, 132415, 132416); } - @Test(expected= DMLRuntimeException.class) - public void invalidRange3(){ + 
@Test(expected = DMLRuntimeException.class) + public void invalidRange3() { IteratorFactory.getStringRowIterator(fb2, 13, 4); } - @Test(expected= DMLRuntimeException.class) - public void remove(){ - RowIterator a =IteratorFactory.getStringRowIterator(fb2, 0, 4); + @Test(expected = DMLRuntimeException.class) + public void remove() { + RowIterator a = IteratorFactory.getStringRowIterator(fb2, 0, 4); a.remove(); } - private static void compareIterators(RowIterator a, RowIterator b) { while(a.hasNext()) { assertTrue(b.hasNext()); From bc277e546ded349bb1d646c11012479ce562b925 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Thu, 26 Oct 2023 14:00:01 +0200 Subject: [PATCH 07/28] [SYSTEMDS-3637] Manifest jar with ClassPath To improve the startup time further we can specify a manifest that is included in the SystemDS jar file. this improves performance of a hello world script to 1,1932 sec from 1.2804. It also change the way we can execute our scripts from java from: java -cp target/SystemDS.jar:./lib/*:./target/lib/* org.apache.sysds.api.DMLScript -f tmp/test.dml to: java -jar target/SystemDS.jar -f tmp/test.dml The old versions of executing systemds still is supported, but it is recommended to change the scripts to execute with the new version Closes #1934 --- pom.xml | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/pom.xml b/pom.xml index 4fc1dd2ccdc..b97cfb30ada 100644 --- a/pom.xml +++ b/pom.xml @@ -166,6 +166,14 @@ src/assembly/bin META-INF/ + + + conf + + log4j.properties + + ./ + @@ -213,8 +221,13 @@ true + true + lib/ org.apache.sysds.api.DMLScript + + SystemDS.jar ${project.artifactId}-${project.version}.jar + @@ -227,7 +240,6 @@ perf org/apache/sysds/performance/** - log4j.properties @@ -236,7 +248,7 @@ org.apache.sysds.performance.Main - SystemDS.jar ${project.build.directory}/${project.artifactId}-${project.version}-tests.jar + SystemDS.jar ${project.artifactId}-${project.version}-tests.jar @@ -1427,7 +1439,7 @@ 
io.netty netty-all 4.1.68.Final - provided + org.apache.logging.log4j From f038787f3887787f32d0d8a52aa1856ad14a7e00 Mon Sep 17 00:00:00 2001 From: Matthias Boehm Date: Thu, 26 Oct 2023 11:39:24 +0200 Subject: [PATCH 08/28] [MINOR] Fix warnings and code quality issues --- src/main/java/org/apache/sysds/hops/FunctionOp.java | 2 -- .../sysds/hops/ipa/InterProceduralAnalysis.java | 2 -- .../apache/sysds/hops/rewrite/ProgramRewriter.java | 2 -- .../java/org/apache/sysds/parser/DMLProgram.java | 1 - .../apache/sysds/parser/ParForStatementBlock.java | 2 -- .../apache/sysds/runtime/frame/data/FrameBlock.java | 5 ----- .../runtime/frame/data/lib/FrameFromMatrixBlock.java | 1 - .../instructions/gpu/context/GPUContextPool.java | 1 - .../sysds/runtime/lineage/LineageCacheConfig.java | 2 ++ .../sysds/runtime/lineage/LineageCacheEntry.java | 2 +- .../sysds/runtime/matrix/data/LibMatrixMult.java | 1 - .../sysds/performance/micro/InformationLoss.java | 4 ++-- .../sysds/test/component/frame/FrameCustomTest.java | 6 +++--- .../test/component/frame/array/CustomArrayTests.java | 4 ++-- .../test/component/frame/array/FrameArrayTests.java | 12 ++++++------ .../component/frame/array/NegativeArrayTests.java | 8 ++++---- .../transform/TransformCompressedTestSingleCol.java | 1 - .../compress/matrixByBin/CompressByBinTest.java | 7 ++++--- .../pipelines/BuiltinImageTransformLinTest.java | 1 - 19 files changed, 24 insertions(+), 40 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/FunctionOp.java b/src/main/java/org/apache/sysds/hops/FunctionOp.java index e5e66f2ccfb..45f21e975b5 100644 --- a/src/main/java/org/apache/sysds/hops/FunctionOp.java +++ b/src/main/java/org/apache/sysds/hops/FunctionOp.java @@ -25,8 +25,6 @@ import org.apache.sysds.api.DMLScript; import org.apache.sysds.lops.Compression; -import org.apache.sysds.lops.Data; -import org.apache.sysds.lops.DeCompression; import org.apache.sysds.lops.FunctionCallCP; import org.apache.sysds.lops.Lop; import 
org.apache.sysds.common.Types.ExecType; diff --git a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java index d7a3fc7f8ce..d0ea21a8aac 100644 --- a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java +++ b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java @@ -21,8 +21,6 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.apache.log4j.Level; -import org.apache.log4j.Logger; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.conf.ConfigurationManager; diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java index 21c70a6850a..1754b72b5ea 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java @@ -22,8 +22,6 @@ import java.util.ArrayList; import java.util.List; -import org.apache.log4j.Level; -import org.apache.log4j.Logger; import org.apache.sysds.api.DMLScript; import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.conf.CompilerConfig.ConfigType; diff --git a/src/main/java/org/apache/sysds/parser/DMLProgram.java b/src/main/java/org/apache/sysds/parser/DMLProgram.java index 5b5b5be60fc..99de236651d 100644 --- a/src/main/java/org/apache/sysds/parser/DMLProgram.java +++ b/src/main/java/org/apache/sysds/parser/DMLProgram.java @@ -20,7 +20,6 @@ package org.apache.sysds.parser; import java.util.ArrayList; -import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; diff --git a/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java b/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java index 5ed91adbd3b..8856ee07b23 100644 --- 
a/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java +++ b/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java @@ -35,8 +35,6 @@ import org.apache.sysds.hops.rewrite.HopRewriteUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.apache.log4j.Level; -import org.apache.log4j.Logger; import org.apache.sysds.common.Builtins; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.OpOp1; diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java index 456fc9afc7b..3efafbb30bc 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java @@ -24,9 +24,7 @@ import java.io.Externalizable; import java.io.IOException; import java.io.ObjectInput; -import java.io.ObjectInputStream; import java.io.ObjectOutput; -import java.io.ObjectOutputStream; import java.io.Serializable; import java.lang.ref.SoftReference; import java.lang.reflect.InvocationTargetException; @@ -71,7 +69,6 @@ import org.apache.sysds.runtime.instructions.cp.ScalarObject; import org.apache.sysds.runtime.io.IOUtilFunctions; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.matrix.data.MatrixBlockDataInput; import org.apache.sysds.runtime.matrix.data.Pair; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.meta.DataCharacteristics; @@ -79,8 +76,6 @@ import org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.runtime.util.DMVUtils; import org.apache.sysds.runtime.util.EMAUtils; -import org.apache.sysds.runtime.util.FastBufferedDataInputStream; -import org.apache.sysds.runtime.util.FastBufferedDataOutputStream; import org.apache.sysds.runtime.util.IndexRange; import org.apache.sysds.runtime.util.UtilFunctions; import 
org.apache.sysds.utils.MemoryEstimates; diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameFromMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameFromMatrixBlock.java index 7e75c621a57..eeac27e2e1d 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameFromMatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameFromMatrixBlock.java @@ -20,7 +20,6 @@ package org.apache.sysds.runtime.frame.data.lib; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUContextPool.java b/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUContextPool.java index aa9bb8563af..049f4923b60 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUContextPool.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUContextPool.java @@ -30,7 +30,6 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.lineage.LineageGPUCacheEviction; import org.apache.sysds.utils.GPUStatistics; import jcuda.driver.JCudaDriver; diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java index 67eeed9481a..5df5302c979 100644 --- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java +++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java @@ -100,7 +100,9 @@ public static boolean isNone() { protected static final double CPU_CACHE_FRAC = 0.05; // 5% of JVM heap size private static ReuseCacheType _cacheType = null; + @SuppressWarnings("unused") private static CachedItemHead _itemH = null; + @SuppressWarnings("unused") private 
static CachedItemTail _itemT = null; private static boolean _compilerAssistedRW = false; private static boolean _onlyEstimate = false; diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java index e019f9810dd..3bf53cbeeaa 100644 --- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java +++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java @@ -355,7 +355,7 @@ protected synchronized void initiateScoreGPU(Map removeLis if (_timestamp < 0) throw new DMLRuntimeException ("Execution timestamp shouldn't be -ve. Key: "+_key); // Weights for scoring components in GPU - double w1 = 0; + //double w1 = 0; double w2 = 1; double w3 = 1; // Generate initial score diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java index abf6621a780..711197352ac 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java @@ -49,7 +49,6 @@ import org.apache.sysds.runtime.data.SparseBlockCSR; import org.apache.sysds.runtime.data.SparseBlockFactory; import org.apache.sysds.runtime.data.SparseBlockMCSR; -import org.apache.sysds.runtime.data.SparseRow; import org.apache.sysds.runtime.data.SparseRowScalar; import org.apache.sysds.runtime.functionobjects.SwapIndex; import org.apache.sysds.runtime.functionobjects.ValueFunction; diff --git a/src/test/java/org/apache/sysds/performance/micro/InformationLoss.java b/src/test/java/org/apache/sysds/performance/micro/InformationLoss.java index 88fd767e76d..68dbdb21cec 100644 --- a/src/test/java/org/apache/sysds/performance/micro/InformationLoss.java +++ b/src/test/java/org/apache/sysds/performance/micro/InformationLoss.java @@ -113,8 +113,8 @@ private static Pair readFrame(String path) throws Excep private static Pair 
getMinMax(final MatrixBlock org) throws Exception { ExecutorService pool = CommonThreadPool.get(16); - Future minF = pool.submit(() -> (MatrixBlock) org.colMin(16)); - Future maxF = pool.submit(() -> (MatrixBlock) org.colMax(16)); + Future minF = pool.submit(() -> org.colMin(16)); + Future maxF = pool.submit(() -> org.colMax(16)); MatrixBlock min = minF.get(); MatrixBlock max = maxF.get(); diff --git a/src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java b/src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java index f5fbfd8588e..9d6c7aa4827 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java +++ b/src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java @@ -32,7 +32,7 @@ public class FrameCustomTest { @Test public void castToFrame() { - double maxp1 = ((double) Integer.MAX_VALUE) + 1.0; + double maxp1 = Integer.MAX_VALUE + 1.0; MatrixBlock mb = TestUtils.generateTestMatrixBlock(100, 100, maxp1, maxp1, 1.0, 23); FrameBlock f = DataConverter.convertToFrameBlock(mb); assertTrue(f.getSchema()[0] == ValueType.INT64); @@ -40,7 +40,7 @@ public void castToFrame() { @Test public void castToFrame3() { - double maxp1 = ((double) Integer.MAX_VALUE) - 1.0; + double maxp1 = Integer.MAX_VALUE - 1.0; MatrixBlock mb = TestUtils.generateTestMatrixBlock(100, 100, maxp1, maxp1, 1.0, 23); FrameBlock f = DataConverter.convertToFrameBlock(mb); assertTrue(f.getSchema()[0] == ValueType.INT32); @@ -56,7 +56,7 @@ public void castErrorValue() { @Test public void castToFrame2() { - double maxp1 = ((double) Integer.MAX_VALUE) + 1.1111; + double maxp1 = Integer.MAX_VALUE + 1.1111; MatrixBlock mb = TestUtils.generateTestMatrixBlock(100, 100, maxp1, maxp1, 1.0, 23); FrameBlock f = DataConverter.convertToFrameBlock(mb); assertTrue(f.getSchema()[0] == ValueType.FP64); diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java 
b/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java index 94a5810bf4d..90c41db1a4e 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java +++ b/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java @@ -857,7 +857,7 @@ public void testSetBChangeType() { public void testDDCIn() { try { Array a = null; - Array b = new DDCArray(new LongArray(new long[] {1, 2, 3, 4}), // + Array b = new DDCArray<>(new LongArray(new long[] {1, 2, 3, 4}), // MapToFactory.create(10, new int[] {0, 0, 0, 0, 1, 1, 1, 2, 2, 3, 3}, 4)); Array c = ArrayFactory.set(a, b, 10, 19, 20); assertEquals((long) c.get(0), 0L); @@ -873,7 +873,7 @@ public void testDDCIn() { public void testDDCInOptional() { try { Array a = null; - Array b = new DDCArray(new OptionalArray(new Long[] {1L, 2L, 3L, 4L}), // + Array b = new DDCArray<>(new OptionalArray<>(new Long[] {1L, 2L, 3L, 4L}), // MapToFactory.create(10, new int[] {0, 0, 0, 0, 1, 1, 1, 2, 2, 3, 3}, 4)); Array c = ArrayFactory.set(a, b, 10, 19, 20); assertEquals(c.get(0), null); diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java index 71211ab52c1..165f1327b20 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java +++ b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java @@ -309,10 +309,10 @@ public void equalsOtherType() { switch(a.getValueType()) { case BOOLEAN: - assertFalse(a.equals((Object) ArrayFactory.create(new char[] {'a', 'b'}))); + assertFalse(a.equals(ArrayFactory.create(new char[] {'a', 'b'}))); break; default: - assertFalse(a.equals((Object) ArrayFactory.create(new boolean[] {true, false}))); + assertFalse(a.equals(ArrayFactory.create(new boolean[] {true, false}))); } } catch(Exception e) { @@ -324,7 +324,7 @@ public void equalsOtherType() { @Test public void 
equalsSelf() { try { - assertTrue(a.equals((Object) a)); + assertTrue(a.equals(a)); } catch(Exception e) { e.printStackTrace(); @@ -335,7 +335,7 @@ public void equalsSelf() { @Test public void equalsClone() { try { - assertTrue(a.equals((Object) a.clone())); + assertTrue(a.equals(a.clone())); } catch(Exception e) { e.printStackTrace(); @@ -358,7 +358,7 @@ public void notEqualsRandomObject() { public void sameValueTypeNotEquals() { try { Array b = ArrayFactory.allocate(a.getValueType(), a.size() == 1 ? 2 : 1); - assertFalse(a.equals((Object) b)); + assertFalse(a.equals(b)); } catch(Exception e) { e.printStackTrace(); @@ -386,7 +386,7 @@ public void setWithDDC() { ArrayFactory.set(t, ddc, 0, 29, t.size()); switch(t.getValueType()) { case BOOLEAN: - assertEquals(t.get(0), (Boolean) false); + assertEquals(t.get(0), false); break; default: diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/NegativeArrayTests.java b/src/test/java/org/apache/sysds/test/component/frame/array/NegativeArrayTests.java index dd471e3159f..105785ebc99 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/array/NegativeArrayTests.java +++ b/src/test/java/org/apache/sysds/test/component/frame/array/NegativeArrayTests.java @@ -294,7 +294,7 @@ public void set3() { @Test(expected = DMLRuntimeException.class) public void testInvalidRLen() { Array a = null; - Array b = new OptionalArray(new Long[] {1L, 2L, 3L, 4L}); + Array b = new OptionalArray<>(new Long[] {1L, 2L, 3L, 4L}); ArrayFactory.set(a, b, 10, 20, 20); } @@ -308,21 +308,21 @@ public void testNull() { @Test(expected = DMLRuntimeException.class) public void testInvalidBLength() { Array a = null; - Array b = new OptionalArray(new Long[] {1L, 2L, 3L, 4L}); + Array b = new OptionalArray<>(new Long[] {1L, 2L, 3L, 4L}); ArrayFactory.set(a, b, 10, 15, 20);// one to short } @Test(expected = DMLRuntimeException.class) public void testInvalidALength() { Array a = ArrayFactory.allocate( ValueType.INT32, 10); - Array b = new 
OptionalArray(new Long[] {1L, 2L, 3L, 4L}); + Array b = new OptionalArray<>(new Long[] {1L, 2L, 3L, 4L}); ArrayFactory.set(a, b, 10, 14, 20);// one to short } @Test(expected = DMLRuntimeException.class) public void testInvalidRL() { Array a = ArrayFactory.allocate( ValueType.INT32, 10); - Array b = new OptionalArray(new Long[] {1L, 2L, 3L, 4L}); + Array b = new OptionalArray<>(new Long[] {1L, 2L, 3L, 4L}); ArrayFactory.set(a, b, -1, 15, 20);// one to short } } diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java index 14a552c9934..3a5d05919e9 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java @@ -147,7 +147,6 @@ public void test(String spec) { data.getNumColumns(), meta); MatrixBlock outNormal = encoderNormal.encode(data, k); - meta = null; MultiColumnEncoder encoderCompressed = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), meta); MatrixBlock outCompressed = encoderCompressed.encode(data, k, true); diff --git a/src/test/java/org/apache/sysds/test/functions/compress/matrixByBin/CompressByBinTest.java b/src/test/java/org/apache/sysds/test/functions/compress/matrixByBin/CompressByBinTest.java index 8265b261d2c..1fe40002c29 100644 --- a/src/test/java/org/apache/sysds/test/functions/compress/matrixByBin/CompressByBinTest.java +++ b/src/test/java/org/apache/sysds/test/functions/compress/matrixByBin/CompressByBinTest.java @@ -52,7 +52,7 @@ public class CompressByBinTest extends AutomatedTestBase { private final static int nbins = 10; - private final static int[] dVector = new int[cols]; + //private final static int[] dVector = new int[cols]; @Override public void setUp() { @@ -146,7 +146,7 @@ private double[][] 
generateMatrixData(ColumnEncoderBin.BinMethod binMethod) { // Create one column for(int i = 0, j = 0; i < rows; i++) { X[i][c] = vals[j]; - if(i == (int) ((j + 1) * (rows / nbins))) + if(i == ((j + 1) * (rows / nbins))) j++; } } @@ -156,6 +156,7 @@ private double[][] generateMatrixData(ColumnEncoderBin.BinMethod binMethod) { return X; } + @SuppressWarnings("unchecked") private FrameBlock generateFrameData(ColumnEncoderBin.BinMethod binMethod, Types.ValueType[] schema) { FrameBlock Xf; if(binMethod == ColumnEncoderBin.BinMethod.EQUI_WIDTH) { @@ -173,7 +174,7 @@ private FrameBlock generateFrameData(ColumnEncoderBin.BinMethod binMethod, Types Array f = (Array) ArrayFactory.allocate(Types.ValueType.FP32, rows); for(int i = 0, j = 0; i < rows; i++) { f.set(i, vals[j]); - if(i == (int) ((j + 1) * (rows / nbins))) + if(i == ((j + 1) * (rows / nbins))) j++; } Xf.appendColumn(f); diff --git a/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinImageTransformLinTest.java b/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinImageTransformLinTest.java index e240b8e6fa7..8448af7262f 100644 --- a/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinImageTransformLinTest.java +++ b/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinImageTransformLinTest.java @@ -24,7 +24,6 @@ import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; -import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; From aab900c181fc5b33483d527792fec889d86862cb Mon Sep 17 00:00:00 2001 From: Matthias Boehm Date: Thu, 26 Oct 2023 21:35:27 +0200 Subject: [PATCH 09/28] [SYSTEMDS-3636] Sparse Transpose-Self MatMult w/ Sparse Outputs So far all dense and sparse tsmm operations always worked with dense outputs and only finally converted the output to sparse where needed. 
On large graph operations like G %*% t(G) this quickly runs out of memory. This patch adds for tsmm right a dedicated kernel that directly outputs sparse representations because we can perform sparse dot products for row-column combinations. --- .../runtime/matrix/data/LibMatrixMult.java | 165 +++++++++++++++--- .../TransposeMatrixMultiplicationTest.java | 1 - ...MatrixMultiplicationTransposeSelfTest.java | 38 ++++ .../TransposeSelfMatrixMultiplication1.dml | 2 +- .../TransposeSelfMatrixMultiplication2.dml | 2 +- 5 files changed, 185 insertions(+), 23 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java index 711197352ac..03dcbc359be 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java @@ -50,6 +50,7 @@ import org.apache.sysds.runtime.data.SparseBlockFactory; import org.apache.sysds.runtime.data.SparseBlockMCSR; import org.apache.sysds.runtime.data.SparseRowScalar; +import org.apache.sysds.runtime.data.SparseRowVector; import org.apache.sysds.runtime.functionobjects.SwapIndex; import org.apache.sysds.runtime.functionobjects.ValueFunction; import org.apache.sysds.runtime.matrix.operators.ReorgOperator; @@ -434,13 +435,13 @@ public static void matrixMultTransposeSelf(MatrixBlock m1, MatrixBlock ret, bool //Timing time = new Timing(true); //pre-processing - ret.sparse = false; - ret.allocateDenseBlock(); + double sp = m1.getSparsity(); + double osp = OptimizerUtils.getMatMultSparsity(sp, sp, m1.rlen, m1.clen, m1.rlen, false); + ret.sparse = !leftTranspose && m1.sparse && osp < MatrixBlock.SPARSITY_TURN_POINT; + ret.allocateBlock(); - if( m1.sparse ) - matrixMultTransposeSelfSparse(m1, ret, leftTranspose, 0, ret.rlen); - else - matrixMultTransposeSelfDense(m1, ret, leftTranspose, 0, ret.rlen ); + //core tsmm operation + matrixMultTransposeSelf(m1, 
ret, leftTranspose, 0, m1.rlen); //post-processing if(copyToLowerTriangle){ @@ -476,15 +477,18 @@ public static void matrixMultTransposeSelf(MatrixBlock m1, MatrixBlock ret, bool //Timing time = new Timing(true); //pre-processing (no need to check isThreadSafe) - ret.sparse = false; - ret.allocateDenseBlock(); - + double sp = m1.getSparsity(); + ret.sparse = !leftTranspose && m1.sparse && MatrixBlock.SPARSITY_TURN_POINT > + OptimizerUtils.getMatMultSparsity(sp, sp, m1.rlen, m1.clen, m1.rlen, false); + + ret.allocateBlock(); + //core multi-threaded matrix mult computation ExecutorService pool = CommonThreadPool.get(k); try { ArrayList tasks = new ArrayList<>(); - //load balance via #tasks=2k due to triangular shape - int blklen = (int)(Math.ceil((double)ret.rlen / (2 * k))); + //load balance via #tasks=4k due to triangular shape + int blklen = (int)(Math.ceil((double)ret.rlen / (4 * k))); for(int i = 0; i < ret.rlen; i += blklen) tasks.add(new MatrixMultTransposeTask(m1, ret, leftTranspose, i, Math.min(i+blklen, ret.rlen))); for( Future rtask : pool.invokeAll(tasks) ) @@ -500,7 +504,7 @@ public static void matrixMultTransposeSelf(MatrixBlock m1, MatrixBlock ret, bool //post-processing long nnz = copyUpperToLowerTriangle(ret); ret.setNonZeros(nnz); - ret.examSparsity(); + ret.examSparsity(); //System.out.println("TSMM k="+k+" ("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+","+leftTranspose+") in "+time.stop()); } @@ -2236,6 +2240,15 @@ private static void matrixMultTransposeSelfDense( MatrixBlock m1, MatrixBlock re } } + private static void matrixMultTransposeSelf(MatrixBlock m1, MatrixBlock ret, boolean leftTranspose, int rl, int ru) { + if(m1.sparse && ret.sparse) + matrixMultTransposeSelfUltraSparse(m1, ret, leftTranspose, rl, ru); + else if( m1.sparse ) + matrixMultTransposeSelfSparse(m1, ret, leftTranspose, rl, ru); + else + matrixMultTransposeSelfDense(m1, ret, leftTranspose, rl, ru ); + } + private static void 
matrixMultTransposeSelfSparse( MatrixBlock m1, MatrixBlock ret, boolean leftTranspose, int rl, int ru ) { //2) transpose self matrix multiply sparse // (compute only upper-triangular matrix due to symmetry) @@ -2357,6 +2370,46 @@ private static void matrixMultTransposeSelfSparse( MatrixBlock m1, MatrixBlock r } } } + + private static void matrixMultTransposeSelfUltraSparse( MatrixBlock m1, MatrixBlock ret, boolean leftTranspose, int rl, int ru ) { + if( leftTranspose ) + throw new DMLRuntimeException("Left tsmm with sparse output not supported"); + + // Operation X%*%t(X), sparse input and output + SparseBlock a = m1.sparseBlock; + SparseBlock c = ret.sparseBlock; + int m = m1.rlen; + + final int blocksize = 256; + for(int bi=rl; bibix[bsize-1] || aix[asize-1] bixk ) + k2++; + else { // === + v += a[k] * b[k2]; + k++; k2++; + } + //note: branchless version slower + //v += (aixk==bixk) ? a[k] * b[k2] : 0; + //k += (aixk <= bixk) ? 1 : 0; + //k2 += (aixk >= bixk) ? 1 : 0; + } + return v; + } //note: public for use by codegen for consistency public static void vectMultiplyAdd( final double aval, double[] b, double[] c, int bi, int ci, final int len ) @@ -4025,6 +4108,13 @@ private static double dotProductGeneric(MatrixBlock a, MatrixBlock b) return val; } + public static long copyUpperToLowerTriangle( MatrixBlock ret ) { + return ret.sparse ? + copyUpperToLowerTriangleSparse(ret) : + copyUpperToLowerTriangleDense(ret); + } + + /** * Used for all version of TSMM where the result is known to be symmetric. 
* Hence, we compute only the upper triangular matrix and copy this partial @@ -4033,7 +4123,7 @@ private static double dotProductGeneric(MatrixBlock a, MatrixBlock b) * @param ret matrix * @return number of non zeros */ - public static long copyUpperToLowerTriangle( MatrixBlock ret ) + public static long copyUpperToLowerTriangleDense( MatrixBlock ret ) { //ret is guaranteed to be a squared, symmetric matrix if( ret.rlen != ret.clen ) @@ -4074,18 +4164,56 @@ public static long copyUpperToLowerTriangle( MatrixBlock ret ) return nnz; } + public static long copyUpperToLowerTriangleSparse( MatrixBlock ret ) + { + //ret is guaranteed to be a squared, symmetric matrix + if( ret.rlen != ret.clen ) + throw new RuntimeException("Invalid non-squared input matrix."); + + SparseBlock c = ret.getSparseBlock(); + int n = ret.rlen; + long nnz = 0; + + //copy non-diagonal values from upper-triangular matrix + for(int i=0; i i ) { + c.append(cix[k], i, cvals[k]); + nnz += 2; + } + } + } + + //sort sparse rows (because append out of order) + c.sort(); + + return nnz; + } + public static MatrixBlock prepMatrixMultTransposeSelfInput( MatrixBlock m1, boolean leftTranspose, boolean par ) { MatrixBlock ret = m1; final int rlen = m1.rlen; final int clen = m1.clen; + double sp = m1.getSparsity(); + double osp = OptimizerUtils.getMatMultSparsity(sp, sp, m1.rlen, m1.clen, m1.rlen, false); + boolean retSparse = !leftTranspose && m1.sparse && osp < MatrixBlock.SPARSITY_TURN_POINT; - if( !leftTranspose && m1.sparse && rlen > 1) { //X%*%t(X) SPARSE MATRIX + if( !leftTranspose && !retSparse && m1.sparse && rlen > 1) { //X%*%t(X) SPARSE MATRIX //directly via LibMatrixReorg in order to prevent sparsity change MatrixBlock tmpBlock = new MatrixBlock(clen, rlen, m1.sparse); LibMatrixReorg.reorg(m1, tmpBlock, new ReorgOperator(SwapIndex.getSwapIndexFnObject())); ret = tmpBlock; } - else if( leftTranspose && m1.sparse && m1.sparseBlock instanceof SparseBlockCSR ) { + else if( leftTranspose && !retSparse 
&& m1.sparse && m1.sparseBlock instanceof SparseBlockCSR ) { //for a special case of CSR inputs where all non-empty rows are dense, we can //create a shallow copy of the values arrays to a "dense" block and perform //tsmm with the existing dense block operations w/o unnecessary gather/scatter @@ -4158,7 +4286,7 @@ public static boolean satisfiesMultiThreadingConstraints(MatrixBlock m1, MatrixB (sharedTP ? PAR_MINFLOP_THRESHOLD2 : PAR_MINFLOP_THRESHOLD1)); } - private static boolean satisfiesMultiThreadingConstraintsTSMM(MatrixBlock m1, boolean leftTranspose, long FPfactor, int k) { + private static boolean satisfiesMultiThreadingConstraintsTSMM(MatrixBlock m1, boolean leftTranspose, double FPfactor, int k) { boolean sharedTP = (InfrastructureAnalyzer.getLocalParallelism() == k); double threshold = sharedTP ? PAR_MINFLOP_THRESHOLD2 : PAR_MINFLOP_THRESHOLD1; return k > 1 && LOW_LEVEL_OPTIMIZATION && (leftTranspose?m1.clen:m1.rlen)!=1 @@ -4425,10 +4553,7 @@ protected MatrixMultTransposeTask( MatrixBlock m1, MatrixBlock ret, boolean left @Override public Object call() { - if( _m1.sparse ) - matrixMultTransposeSelfSparse(_m1, _ret, _left, _rl, _ru); - else - matrixMultTransposeSelfDense(_m1, _ret, _left, _rl, _ru); + matrixMultTransposeSelf(_m1, _ret, _left, _rl, _ru); return null; } } diff --git a/src/test/java/org/apache/sysds/test/functions/binary/matrix/TransposeMatrixMultiplicationTest.java b/src/test/java/org/apache/sysds/test/functions/binary/matrix/TransposeMatrixMultiplicationTest.java index 6a73499f909..ba65889a958 100644 --- a/src/test/java/org/apache/sysds/test/functions/binary/matrix/TransposeMatrixMultiplicationTest.java +++ b/src/test/java/org/apache/sysds/test/functions/binary/matrix/TransposeMatrixMultiplicationTest.java @@ -285,5 +285,4 @@ private void runTransposeMatrixMultiplicationTest( boolean sparseM1, boolean spa DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } } - } \ No newline at end of file diff --git 
a/src/test/java/org/apache/sysds/test/functions/binary/matrix_full_other/FullMatrixMultiplicationTransposeSelfTest.java b/src/test/java/org/apache/sysds/test/functions/binary/matrix_full_other/FullMatrixMultiplicationTransposeSelfTest.java index 6413db4e063..1599bdac860 100644 --- a/src/test/java/org/apache/sysds/test/functions/binary/matrix_full_other/FullMatrixMultiplicationTransposeSelfTest.java +++ b/src/test/java/org/apache/sysds/test/functions/binary/matrix_full_other/FullMatrixMultiplicationTransposeSelfTest.java @@ -119,6 +119,10 @@ public void testVVRightSparseCP() { runTransposeSelfVectorMultiplicationTest(MMTSJType.RIGHT, ExecType.CP, true); } + @Test + public void testRightUltraSparseCP() { + runTransposeSelfUltraSparseTest(MMTSJType.RIGHT); + } private void runTransposeSelfMatrixMultiplicationTest( MMTSJType type, ExecType instType, boolean sparse ) { @@ -261,4 +265,38 @@ private void runTransposeSelfVectorMultiplicationTest( MMTSJType type, ExecType rtplatform = platformOld; } } + + private void runTransposeSelfUltraSparseTest( MMTSJType type ) + { + //rtplatform for MR + ExecMode platformOld = rtplatform; + rtplatform = ExecMode.SINGLE_NODE; + + try { + loadTestConfiguration(getTestConfiguration(TEST_NAME2)); + int dim = 10000; + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME2 + ".dml"; + programArgs = new String[]{"-stats","-args", input("A"), + String.valueOf(dim), String.valueOf(dim), output("B") }; + fullRScriptName = HOME + TEST_NAME2 + ".R"; + rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir(); + + //generate actual dataset + double[][] A = getRandomMatrix(dim, dim, 0, 1, 0.0002, 7); + writeInputMatrix("A", A, true); + + runTest(true, false, null, -1); + //runRScript(true); + + //compare matrices + //HashMap dmlfile = readDMLMatrixFromOutputDir("B"); + //HashMap rfile = readRMatrixFromExpectedDir("B"); + //TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + } + 
finally { + rtplatform = platformOld; + } + } } diff --git a/src/test/scripts/functions/binary/matrix_full_other/TransposeSelfMatrixMultiplication1.dml b/src/test/scripts/functions/binary/matrix_full_other/TransposeSelfMatrixMultiplication1.dml index 9e8779f41cb..562fafa7bfa 100644 --- a/src/test/scripts/functions/binary/matrix_full_other/TransposeSelfMatrixMultiplication1.dml +++ b/src/test/scripts/functions/binary/matrix_full_other/TransposeSelfMatrixMultiplication1.dml @@ -24,4 +24,4 @@ A = read($1, rows=$2, cols=$3, format="text"); B = t(A) %*% A; -write(B, $4, format="text"); \ No newline at end of file +write(B, $4, format="text"); diff --git a/src/test/scripts/functions/binary/matrix_full_other/TransposeSelfMatrixMultiplication2.dml b/src/test/scripts/functions/binary/matrix_full_other/TransposeSelfMatrixMultiplication2.dml index 4679990870c..91ecd00f69c 100644 --- a/src/test/scripts/functions/binary/matrix_full_other/TransposeSelfMatrixMultiplication2.dml +++ b/src/test/scripts/functions/binary/matrix_full_other/TransposeSelfMatrixMultiplication2.dml @@ -24,4 +24,4 @@ A = read($1, rows=$2, cols=$3, format="text"); B = A %*% t(A); -write(B, $4, format="text"); \ No newline at end of file +write(B, $4, format="text"); From ba20e11d3103df6e76c154e356a7fb89e84d8604 Mon Sep 17 00:00:00 2001 From: Matthias Boehm Date: Thu, 26 Oct 2023 22:08:20 +0200 Subject: [PATCH 10/28] [SYSTEMDS-3636] Fix arg passing in single-threaded tsmm kernel dispatch --- .../org/apache/sysds/runtime/matrix/data/LibMatrixMult.java | 2 +- .../java/org/apache/sysds/test/component/matrix/TSMMTest.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java index 03dcbc359be..0f08176fe11 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java +++ 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java @@ -441,7 +441,7 @@ public static void matrixMultTransposeSelf(MatrixBlock m1, MatrixBlock ret, bool ret.allocateBlock(); //core tsmm operation - matrixMultTransposeSelf(m1, ret, leftTranspose, 0, m1.rlen); + matrixMultTransposeSelf(m1, ret, leftTranspose, 0, ret.rlen); //post-processing if(copyToLowerTriangle){ diff --git a/src/test/java/org/apache/sysds/test/component/matrix/TSMMTest.java b/src/test/java/org/apache/sysds/test/component/matrix/TSMMTest.java index e99f4803dd1..9e464077d05 100644 --- a/src/test/java/org/apache/sysds/test/component/matrix/TSMMTest.java +++ b/src/test/java/org/apache/sysds/test/component/matrix/TSMMTest.java @@ -90,7 +90,7 @@ public void testTSMMLeftSparseVSDense() { final MMTSJType mType = MMTSJType.LEFT; final MatrixBlock expected = in.transposeSelfMatrixMultOperations(null, mType, 1); - if(k > 1) // test multithread + if(k > 1) // test multi-threaded testCompare(expected, in, "Compare single vs multithread"); final boolean isSparse = in.isInSparseFormat(); From 15bef0394d71eefee98769a1f50eb3dadb336aaa Mon Sep 17 00:00:00 2001 From: Matthias Boehm Date: Fri, 27 Oct 2023 14:28:01 +0200 Subject: [PATCH 11/28] [SYSTEMDS-3636] Fix new ultra-sparse tsmm right, and new tests --- .../runtime/matrix/data/LibMatrixMult.java | 21 ++++--- ...MatrixMultiplicationTransposeSelfTest.java | 55 +++++++++---------- 2 files changed, 35 insertions(+), 41 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java index 0f08176fe11..0b8bd216f42 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java @@ -435,9 +435,7 @@ public static void matrixMultTransposeSelf(MatrixBlock m1, MatrixBlock ret, bool //Timing time = new Timing(true); //pre-processing - double sp = 
m1.getSparsity(); - double osp = OptimizerUtils.getMatMultSparsity(sp, sp, m1.rlen, m1.clen, m1.rlen, false); - ret.sparse = !leftTranspose && m1.sparse && osp < MatrixBlock.SPARSITY_TURN_POINT; + ret.sparse = isSparseOutputTSMM(m1, leftTranspose); ret.allocateBlock(); //core tsmm operation @@ -477,10 +475,7 @@ public static void matrixMultTransposeSelf(MatrixBlock m1, MatrixBlock ret, bool //Timing time = new Timing(true); //pre-processing (no need to check isThreadSafe) - double sp = m1.getSparsity(); - ret.sparse = !leftTranspose && m1.sparse && MatrixBlock.SPARSITY_TURN_POINT > - OptimizerUtils.getMatMultSparsity(sp, sp, m1.rlen, m1.clen, m1.rlen, false); - + ret.sparse = isSparseOutputTSMM(m1, leftTranspose); ret.allocateBlock(); //core multi-threaded matrix mult computation @@ -3549,7 +3544,7 @@ private static double dotProduct(double[] a, int[] aix, final int apos, final in double v = 0; while( k bixk ) @@ -4203,9 +4198,7 @@ public static MatrixBlock prepMatrixMultTransposeSelfInput( MatrixBlock m1, bool MatrixBlock ret = m1; final int rlen = m1.rlen; final int clen = m1.clen; - double sp = m1.getSparsity(); - double osp = OptimizerUtils.getMatMultSparsity(sp, sp, m1.rlen, m1.clen, m1.rlen, false); - boolean retSparse = !leftTranspose && m1.sparse && osp < MatrixBlock.SPARSITY_TURN_POINT; + boolean retSparse = isSparseOutputTSMM(m1, leftTranspose); if( !leftTranspose && !retSparse && m1.sparse && rlen > 1) { //X%*%t(X) SPARSE MATRIX //directly via LibMatrixReorg in order to prevent sparsity change @@ -4323,6 +4316,12 @@ public static boolean isSparseOutputMatrixMult(MatrixBlock m1, MatrixBlock m2) { boolean sparseOut = MatrixBlock.evalSparseFormatInMemory(m1.rlen, m2.clen, estNnz); return m2.clen < 4*1024 && sparseOut; } + + public static boolean isSparseOutputTSMM(MatrixBlock m1, boolean leftTranspose) { + double sp = m1.getSparsity(); + double osp = OptimizerUtils.getMatMultSparsity(sp, sp, m1.rlen, m1.clen, m1.rlen, false); + return !leftTranspose && 
m1.sparse && osp < MatrixBlock.ULTRA_SPARSITY_TURN_POINT2; + } public static boolean isOuterProductTSMM(int rlen, int clen, boolean left) { return left ? rlen == 1 & clen > 1 : rlen > 1 & clen == 1; diff --git a/src/test/java/org/apache/sysds/test/functions/binary/matrix_full_other/FullMatrixMultiplicationTransposeSelfTest.java b/src/test/java/org/apache/sysds/test/functions/binary/matrix_full_other/FullMatrixMultiplicationTransposeSelfTest.java index 1599bdac860..3d1b4c239f2 100644 --- a/src/test/java/org/apache/sysds/test/functions/binary/matrix_full_other/FullMatrixMultiplicationTransposeSelfTest.java +++ b/src/test/java/org/apache/sysds/test/functions/binary/matrix_full_other/FullMatrixMultiplicationTransposeSelfTest.java @@ -22,11 +22,16 @@ import java.util.HashMap; import org.junit.AfterClass; +import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; import org.apache.sysds.common.Types.ExecMode; import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.lops.MMTSJ.MMTSJType; +import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; +import org.apache.sysds.runtime.matrix.data.LibMatrixMult; +import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; @@ -268,35 +273,25 @@ private void runTransposeSelfVectorMultiplicationTest( MMTSJType type, ExecType private void runTransposeSelfUltraSparseTest( MMTSJType type ) { - //rtplatform for MR - ExecMode platformOld = rtplatform; - rtplatform = ExecMode.SINGLE_NODE; - - try { - loadTestConfiguration(getTestConfiguration(TEST_NAME2)); - int dim = 10000; - - String HOME = SCRIPT_DIR + TEST_DIR; - fullDMLScriptName = HOME + TEST_NAME2 + ".dml"; - programArgs = new String[]{"-stats","-args", input("A"), - String.valueOf(dim), 
String.valueOf(dim), output("B") }; - fullRScriptName = HOME + TEST_NAME2 + ".R"; - rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir(); - - //generate actual dataset - double[][] A = getRandomMatrix(dim, dim, 0, 1, 0.0002, 7); - writeInputMatrix("A", A, true); - - runTest(true, false, null, -1); - //runRScript(true); - - //compare matrices - //HashMap dmlfile = readDMLMatrixFromOutputDir("B"); - //HashMap rfile = readRMatrixFromExpectedDir("B"); - //TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); - } - finally { - rtplatform = platformOld; - } + //compare sparse tsmm and gemm directly to avoid unnecessary overhead (e.g., R) + int dim = 10000; + + MatrixBlock G = MatrixBlock.randOperations(dim, dim, 0.0002, 0, 1, "uniform", 7); + MatrixBlock Gt = LibMatrixReorg.transpose(G); + MatrixBlock Gtt = LibMatrixReorg.transpose(Gt); + TestUtils.compareMatrices(G, Gtt, 1e-16); + + //single-threaded core operations + MatrixBlock R11 = G.transposeSelfMatrixMultOperations(new MatrixBlock(), MMTSJType.RIGHT); + MatrixBlock R12 = LibMatrixMult.matrixMult(G, Gt); + Assert.assertEquals(R11.getNonZeros(), R12.getNonZeros()); + TestUtils.compareMatrices(R11, R12, 1e-8); + + //multi-threaded core operations + int k = InfrastructureAnalyzer.getLocalParallelism(); + MatrixBlock R21 = G.transposeSelfMatrixMultOperations(new MatrixBlock(), MMTSJType.RIGHT, k); + MatrixBlock R22 = LibMatrixMult.matrixMult(G, Gt, k); + Assert.assertEquals(R21.getNonZeros(), R22.getNonZeros()); + TestUtils.compareMatrices(R21, R22, 1e-8); } } From a997b4194c81d6026780da4d0463d62ac5c34c59 Mon Sep 17 00:00:00 2001 From: Matthias Boehm Date: Fri, 27 Oct 2023 21:15:21 +0200 Subject: [PATCH 12/28] [SYSTEMDS-3636] Alternative ultra-sparse tsmm kernel (still disabled) --- .../runtime/matrix/data/LibMatrixMult.java | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java index 0b8bd216f42..6fda33ad096 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java @@ -2405,6 +2405,47 @@ private static void matrixMultTransposeSelfUltraSparse( MatrixBlock m1, MatrixBl } } } + + //alternative matrixMultTransposeSelfUltraSparse2 w/ IKJ iteration order and dense buffering + //(for moderately large graphs 4x improvement compared to above, but for large graphs slower -> non-conclusive) + @SuppressWarnings("unused") + private static void matrixMultTransposeSelfUltraSparse2( MatrixBlock m1, MatrixBlock m1t, MatrixBlock ret, boolean leftTranspose, int rl, int ru ) { + if( leftTranspose ) + throw new DMLRuntimeException("Left tsmm with sparse output not supported"); + + // Operation X%*%t(X), sparse input and output + SparseBlock a = m1.sparseBlock; + SparseBlock b = m1t.sparseBlock; + SparseBlock c = ret.sparseBlock; + int m = m1.rlen; + double[] tmp = new double[m]; + + for(int i=rl; i Date: Fri, 27 Oct 2023 21:40:23 +0200 Subject: [PATCH 13/28] [MINOR] Robustness random forest for very small sampling fractions --- scripts/builtin/randomForest.dml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scripts/builtin/randomForest.dml b/scripts/builtin/randomForest.dml index ccebd59d868..8daeb5bc7f0 100644 --- a/scripts/builtin/randomForest.dml +++ b/scripts/builtin/randomForest.dml @@ -110,6 +110,8 @@ m_randomForest = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] cty if( sample_frac < 1.0 ) { si1 = as.integer(as.scalar(randSeeds[3*(i-1)+1,1])); I1 = rand(rows=nrow(X), cols=1, seed=si1) <= sample_frac; + if( sum(I1) <= 1 ) # min 2 tuples + I1[1:2,] = matrix(1,2,1); Xi = removeEmpty(target=X, margin="rows", select=I1); yi = removeEmpty(target=y, margin="rows", select=I1); } From 9426792b009b638667a8415c58552945e1be3d1b Mon Sep 17 00:00:00 2001 From: Sebastian 
Baunsgaard Date: Mon, 30 Oct 2023 12:22:00 +0100 Subject: [PATCH 14/28] [MINOR] Remove potential for compression Scalars --- .../apache/sysds/hops/rewrite/RewriteCompressedReblock.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java index 8dd323dd44e..ec917b01458 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteCompressedReblock.java @@ -156,7 +156,7 @@ else if(hop.getDim1() >= 1){ public static boolean satisfiesCompressionCondition(Hop hop) { boolean satisfies = false; if(satisfiesSizeConstraintsForCompression(hop)){ - satisfies |= HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD); + satisfies |= HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD) && !hop.isScalar(); satisfies |= HopRewriteUtils.isTransformEncode(hop); } return satisfies; @@ -171,7 +171,7 @@ public static boolean satisfiesAggressiveCompressionCondition(Hop hop) { satisfies |= HopRewriteUtils.isTernary(hop, OpOp3.CTABLE) && hop.getInput(0).getDataType().isMatrix() && hop.getInput(1).getDataType().isMatrix(); - satisfies |= HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD); + satisfies |= HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD) && !hop.isScalar(); satisfies |= HopRewriteUtils.isUnary(hop, OpOp1.ROUND, OpOp1.FLOOR, OpOp1.NOT, OpOp1.CEIL); satisfies |= HopRewriteUtils.isBinary(hop, OpOp2.EQUAL, OpOp2.NOTEQUAL, OpOp2.LESS, OpOp2.LESSEQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL, OpOp2.AND, OpOp2.OR, OpOp2.MODULUS); From 7f04d1642c1a679457c8dc4d6f9003e5e2fc4bf3 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 30 Oct 2023 12:24:06 +0100 Subject: [PATCH 15/28] [MINOR] Filter pre-aggregate warning In compressed linear algebra, we print a warning in case of uncompressed matrix multiplication. 
This commit filters that error out if the input is one column. The one-column case is special since transposition is a no-op that only touches metadata. Therefore, we filter this error. Also introduced in this commit is an error where we try to allocate a pre-aggregate output larger than Integer.MAX_VALUE. This happens in cases where the number of columns in a single-column group is large, such as in a recode-bin encoding scenario of transform encoding. --- .../sysds/runtime/compress/colgroup/APreAgg.java | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java index 17f210865be..8b8a7b7df02 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java @@ -19,6 +19,7 @@ package org.apache.sysds.runtime.compress.colgroup; +import org.apache.commons.lang.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; @@ -84,9 +85,11 @@ else if(lhs instanceof ColGroupUncompressed) * @return A aggregate dictionary */ public final IDictionary preAggregateThatIndexStructure(APreAgg that) { - int outputLength = that._colIndexes.size() * this.getNumValues(); + long outputLength = (long)that._colIndexes.size() * this.getNumValues(); + if(outputLength > Integer.MAX_VALUE) + throw new NotImplementedException("Not supported pre aggregate of above integer length"); // create empty Dictionary that we slowly fill, hence the dictionary is empty and no check - final Dictionary ret = Dictionary.createNoCheck(new double[outputLength]); + final Dictionary ret = Dictionary.createNoCheck(new double[(int)outputLength]); if(that instanceof ColGroupDDC) 
preAggregateThatDDCStructure((ColGroupDDC) that, ret); @@ -224,7 +227,8 @@ else if(shouldPreAggregateLeft(lhs)) {// left preAgg } private void leftMultByUncompressedColGroup(ColGroupUncompressed lhs, MatrixBlock result) { - LOG.warn("Transpose of uncompressed to fit to template need t(a) %*% b"); + if(lhs.getNumCols() != 1) + LOG.warn("Transpose of uncompressed to fit to template need t(a) %*% b"); final MatrixBlock tmp = LibMatrixReorg.transpose(lhs.getData(), InfrastructureAnalyzer.getLocalParallelism()); final int numVals = getNumValues(); final MatrixBlock preAgg = new MatrixBlock(tmp.getNumRows(), numVals, false); From c21fa9997deadc7534b40b3b303a445b3c68c630 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 30 Oct 2023 12:34:43 +0100 Subject: [PATCH 16/28] [SYSTEMDS-3642] CLA NaN in Dictionaries replace This commit fixes a bug of replace in ColumnGroups that did not correctly replace NaN values with replacement values. Example: X_test = replace(target=X_test, pattern=NaN, replacement=0); --- .../compress/colgroup/ColGroupDDC.java | 40 ++++++++++++++++--- .../compress/colgroup/ColGroupDDCFOR.java | 5 ++- .../compress/colgroup/ColGroupSDC.java | 3 +- .../compress/colgroup/ColGroupSDCFOR.java | 3 +- .../compress/colgroup/ColGroupSDCSingle.java | 3 +- .../colgroup/ColGroupUncompressed.java | 14 ++++--- .../compress/colgroup/mapping/AMapToData.java | 13 ++++++ 7 files changed, 66 insertions(+), 15 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java index 8f5fccaf7d6..6340affede3 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java @@ -35,6 +35,8 @@ import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import 
org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; +import org.apache.sysds.runtime.compress.colgroup.mapping.MapToByte; +import org.apache.sysds.runtime.compress.colgroup.mapping.MapToChar; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory; import org.apache.sysds.runtime.compress.colgroup.offset.AOffsetIterator; import org.apache.sysds.runtime.compress.colgroup.scheme.DDCScheme; @@ -78,8 +80,8 @@ private ColGroupDDC(IColIndex colIndexes, IDictionary dict, AMapToData data, int int[] c = getCounts(); if(c.length != dict.getNumberOfValues(colIndexes.size())) throw new DMLCompressionException("Invalid DDC Construction"); + data.verify(); } - } public static AColGroup create(IColIndex colIndexes, IDictionary dict, AMapToData data, int[] cachedCounts) { @@ -157,8 +159,37 @@ public AMapToData getMapToData() { private final void decompressToDenseBlockDenseDictSingleColOutContiguous(DenseBlock db, int rl, int ru, int offR, int offC, double[] values) { final double[] c = db.values(0); - for(int i = rl, offT = rl + offR + _colIndexes.get(0) + offC; i < ru; i++, offT++) - c[offT] += values[_data.getIndex(i)]; + decompressToDenseBlockDenseDictSingleColOutContiguous(c, rl, ru, offR + _colIndexes.get(0), values, _data); + } + + private final static void decompressToDenseBlockDenseDictSingleColOutContiguous(double[] c, int rl, int ru, int offR, + double[] values, AMapToData data) { + + if(data instanceof MapToByte) + decompressToDenseBlockDenseDictSingleColOutContiguousByteM(c, rl, ru, offR, values, (MapToByte) data); + else if(data instanceof MapToChar) + decompressToDenseBlockDenseDictSingleColOutContiguousCharM(c, rl, ru, offR, values, (MapToChar) data); + else + decompressToDenseBlockDenseDictSingleColOutContiguousGenM(c, rl, ru, offR, values, data); + + } + + private final static void decompressToDenseBlockDenseDictSingleColOutContiguousByteM(double[] c, int rl, int ru, + int offR, double[] values, MapToByte data) { + for(int i = rl, offT = 
rl + offR; i < ru; i++, offT++) + c[offT] += values[data.getIndex(i)]; + } + + private final static void decompressToDenseBlockDenseDictSingleColOutContiguousCharM(double[] c, int rl, int ru, + int offR, double[] values, MapToChar data) { + for(int i = rl, offT = rl + offR; i < ru; i++, offT++) + c[offT] += values[data.getIndex(i)]; + } + + private final static void decompressToDenseBlockDenseDictSingleColOutContiguousGenM(double[] c, int rl, int ru, + int offR, double[] values, AMapToData data) { + for(int i = rl, offT = rl + offR; i < ru; i++, offT++) + c[offT] += values[data.getIndex(i)]; } private final void decompressToDenseBlockDenseDictAllColumnsContiguous(DenseBlock db, int rl, int ru, int offR, @@ -287,8 +318,7 @@ private void leftMultByMatrixNoPreAggSingleCol(MatrixBlock matrix, MatrixBlock r lmSparseMatrixNoPreAggSingleCol(matrix.getSparseBlock(), nColM, retV, nColRet, dictVals, rl, ru); } else - lmDenseMatrixNoPreAggSingleCol(matrix.getDenseBlockValues(), nColM, retV, nColRet, dictVals, rl, ru, cl, - cu); + lmDenseMatrixNoPreAggSingleCol(matrix.getDenseBlockValues(), nColM, retV, nColRet, dictVals, rl, ru, cl, cu); } private void lmSparseMatrixNoPreAggSingleCol(SparseBlock sb, int nColM, double[] retV, int nColRet, double[] vals, diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java index 85f48ae5b6f..d09ba4e624d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java @@ -26,8 +26,8 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import 
org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; @@ -39,6 +39,7 @@ import org.apache.sysds.runtime.compress.estim.EstimationFactors; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; +import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.Divide; import org.apache.sysds.runtime.functionobjects.Minus; @@ -251,7 +252,7 @@ public AColGroup replace(double pattern, double replace) { if(patternInReference) { double[] nRef = new double[_reference.length]; for(int i = 0; i < _reference.length; i++) - if(pattern == _reference[i]) + if(Util.eq(pattern ,_reference[i])) nRef[i] = replace; else nRef[i] = _reference[i]; diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java index 476e86c9730..a905e401e42 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java @@ -42,6 +42,7 @@ import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; +import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -458,7 +459,7 @@ public AColGroup replace(double pattern, double replace) { IDictionary replaced = _dict.replace(pattern, replace, _colIndexes.size()); double[] newDefaultTuple = new 
double[_defaultTuple.length]; for(int i = 0; i < _defaultTuple.length; i++) - newDefaultTuple[i] = _defaultTuple[i] == pattern ? replace : _defaultTuple[i]; + newDefaultTuple[i] = Util.eq(_defaultTuple[i],pattern) ? replace : _defaultTuple[i]; return create(_colIndexes, _numRows, replaced, newDefaultTuple, _indexes, _data, getCachedCounts()); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java index 65a5a42aa40..dfb9a605118 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java @@ -41,6 +41,7 @@ import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; +import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.Divide; import org.apache.sysds.runtime.functionobjects.Minus; @@ -260,7 +261,7 @@ public AColGroup replace(double pattern, double replace) { if(patternInReference) { double[] nRef = new double[_reference.length]; for(int i = 0; i < _reference.length; i++) - if(pattern == _reference[i]) + if(Util.eq(pattern, _reference[i])) nRef[i] = replace; else nRef[i] = _reference[i]; diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java index 5d3be0d3f11..a13150c12c6 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java @@ -41,6 +41,7 @@ import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import 
org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; +import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; @@ -423,7 +424,7 @@ public AColGroup replace(double pattern, double replace) { IDictionary replaced = _dict.replace(pattern, replace, _colIndexes.size()); double[] newDefaultTuple = new double[_defaultTuple.length]; for(int i = 0; i < _defaultTuple.length; i++) - newDefaultTuple[i] = _defaultTuple[i] == pattern ? replace : _defaultTuple[i]; + newDefaultTuple[i] = Util.eq(_defaultTuple[i], pattern) ? replace : _defaultTuple[i]; return create(_colIndexes, _numRows, replaced, newDefaultTuple, _indexes, getCachedCounts()); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java index 17e954b2197..c4713d6e59a 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java @@ -568,9 +568,11 @@ private void leftMultByAPreAggColGroup(APreAgg paCG, MatrixBlock result) { final MatrixBlock dictM = paCG._dict.getMBDict(nCols).getMatrixBlock(); if(dictM == null) return; - LOG.warn("\nInefficient transpose of uncompressed to fit to" - + " t(AColGroup) %*% UncompressedColGroup mult by colGroup uncompressed column" - + "\nCurrently solved by t(t(Uncompressed) %*% AColGroup)"); + if(paCG.getNumCols() != 1) { + LOG.warn("\nInefficient transpose of uncompressed to fit to" + + " t(AColGroup) %*% UncompressedColGroup mult by colGroup uncompressed column" + + "\nCurrently solved by t(t(Uncompressed) %*% AColGroup)"); + } final int k = InfrastructureAnalyzer.getLocalParallelism(); 
final MatrixBlock ucCGT = LibMatrixReorg.transpose(getData(), k); final MatrixBlock preAgg = new MatrixBlock(1, paCG.getNumValues(), false); @@ -606,10 +608,12 @@ private void leftMultByAPreAggColGroup(APreAgg paCG, MatrixBlock result) { } private void leftMultByAColGroupUncompressed(ColGroupUncompressed lhs, MatrixBlock result) { - LOG.warn("Inefficient Left Matrix Multiplication with transpose of left hand side : t(l) %*% r"); final MatrixBlock tmpRet = new MatrixBlock(lhs.getNumCols(), _colIndexes.size(), 0); final int k = InfrastructureAnalyzer.getLocalParallelism(); - + + if(lhs._data.getNumColumns() != 1){ + LOG.warn("Inefficient Left Matrix Multiplication with transpose of left hand side : t(l) %*% r"); + } // multiply to temp MatrixBlock lhData = lhs._data; MatrixBlock transposed = LibMatrixReorg.transpose(lhData, k); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java index d22234945d4..b12461bf7c4 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java @@ -27,6 +27,8 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.IMapToDataGroup; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; @@ -860,6 +862,17 @@ public boolean equals(Object e) { */ public abstract boolean equals(AMapToData e); + /** Debugging verification that this mapping is correctly made. 
*/ + public void verify() { + if(CompressedMatrixBlock.debug) { + for(int i = 0; i < size(); i++) { + if(getIndex(i) >= nUnique) { + throw new DMLCompressionException("invalid construction of Mapping data containing values above unique"); + } + } + } + } + @Override public String toString() { final int sz = size(); From a826c10a5149f139918395151ce6d573a97dd663 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 30 Oct 2023 13:24:32 +0100 Subject: [PATCH 17/28] [MINOR] JIT optimize LMM Pre-aggregate Because of abstract classes the efficiency of the JIT compiler is subpar in the AMapToData instance. To improve this i have added individual overwritten instructions in some of the Map types. This duplicate code, but improve performance by 30-50% according to the profiler. --- .../compress/colgroup/mapping/AMapToData.java | 85 ++++++++++++------- .../compress/colgroup/mapping/MapToByte.java | 27 +++--- .../compress/colgroup/mapping/MapToChar.java | 52 +++++++++--- .../colgroup/mapping/MapToCharPByte.java | 23 +++++ .../compress/colgroup/mapping/MapToInt.java | 28 +++--- .../compress/colgroup/mapping/MapToUByte.java | 28 +++--- 6 files changed, 167 insertions(+), 76 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java index b12461bf7c4..b66c7ddb877 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java @@ -129,8 +129,8 @@ public void set(int n, Integer v) { * * @param n index to set. * @param v the value to set it to. 
- * @return v as encoded, note this value can be different that the one put in if the map is not able to represent - * the value + * @return v as encoded, note this value can be different that the one put in if the map is not able to represent the + * value */ public abstract int setAndGet(int n, int v); @@ -235,16 +235,19 @@ protected void preAggregateDenseToRowBy8(double[] mV, double[] preAV, int cl, in off += cl; for(int rc = cl; rc < cl + h; rc++, off++) preAV[getIndex(rc)] += mV[off]; - for(int rc = cl + h; rc < cu; rc += 8, off += 8) { - preAV[getIndex(rc)] += mV[off]; - preAV[getIndex(rc + 1)] += mV[off + 1]; - preAV[getIndex(rc + 2)] += mV[off + 2]; - preAV[getIndex(rc + 3)] += mV[off + 3]; - preAV[getIndex(rc + 4)] += mV[off + 4]; - preAV[getIndex(rc + 5)] += mV[off + 5]; - preAV[getIndex(rc + 6)] += mV[off + 6]; - preAV[getIndex(rc + 7)] += mV[off + 7]; - } + for(int rc = cl + h; rc < cu; rc += 8, off += 8) + preAggregateDenseToRowVec8(mV, preAV, rc, off); + } + + protected void preAggregateDenseToRowVec8(double[] mV, double[] preAV, int rc, int off){ + preAV[getIndex(rc)] += mV[off]; + preAV[getIndex(rc + 1)] += mV[off + 1]; + preAV[getIndex(rc + 2)] += mV[off + 2]; + preAV[getIndex(rc + 3)] += mV[off + 3]; + preAV[getIndex(rc + 4)] += mV[off + 4]; + preAV[getIndex(rc + 5)] += mV[off + 5]; + preAV[getIndex(rc + 6)] += mV[off + 6]; + preAV[getIndex(rc + 7)] += mV[off + 7]; } /** @@ -329,8 +332,7 @@ protected void preAggregateDenseMultiRowContiguousBy1(double[] mV, int nCol, int * @param cu The column in m to end at (not inclusive) * @param indexes The Offset Indexes to iterate through */ - public final void preAggregateDense(MatrixBlock m, double[] preAV, int rl, int ru, int cl, int cu, - AOffset indexes) { + public final void preAggregateDense(MatrixBlock m, double[] preAV, int rl, int ru, int cl, int cu, AOffset indexes) { indexes.preAggregateDenseMap(m, preAV, rl, ru, cl, cu, getUnique(), this); } @@ -417,6 +419,8 @@ public final int[] getCounts() { * 
@param nCol The number of columns */ public final void preAggregateDDC_DDC(AMapToData tm, IDictionary td, Dictionary ret, int nCol) { + if(td.getNumberOfValues(nCol) != tm.nUnique) + throw new DMLCompressionException("Invalid map and dict combination"); if(nCol == 1) preAggregateDDC_DDCSingleCol(tm, td.getValues(), ret.getValues()); else @@ -431,31 +435,55 @@ public final void preAggregateDDC_DDC(AMapToData tm, IDictionary td, Dictionary * @param ret The output dictionary to aggregate into */ protected void preAggregateDDC_DDCSingleCol(AMapToData tm, double[] td, double[] v) { + final int sz = size(); - for(int r = 0; r < sz; r++) + final int h = sz % 8; + for(int r = 0; r < h; r++) v[getIndex(r)] += td[tm.getIndex(r)]; + for(int r = h; r < sz; r += 8) + preAggregateDDC_DDCSingleCol_vec(tm, td, v, r); + + } + + protected void preAggregateDDC_DDCSingleCol_vec(AMapToData tm, double[] td, double[] v, int r) { + final int r2 = r + 1, r3 = r + 2, r4 = r + 3, r5 = r + 4, r6 = r + 5, r7 = r + 6, r8 = r + 7; + v[getIndex(r)] += td[tm.getIndex(r)]; + v[getIndex(r2)] += td[tm.getIndex(r2)]; + v[getIndex(r3)] += td[tm.getIndex(r3)]; + v[getIndex(r4)] += td[tm.getIndex(r4)]; + v[getIndex(r5)] += td[tm.getIndex(r5)]; + v[getIndex(r6)] += td[tm.getIndex(r6)]; + v[getIndex(r7)] += td[tm.getIndex(r7)]; + v[getIndex(r8)] += td[tm.getIndex(r8)]; } /** * PreAggregate into dictionary with two sides of DDC guaranteed to multiple column tuples. 
* - * @param tm Map of other side + * @param tm Map of other side that indicate the indexes to take out and put into ret * @param td Dictionary to take values from (other side dictionary) * @param ret The output dictionary to aggregate into - * @param nCol The number of columns + * @param nCol The number of columns in td */ - protected void preAggregateDDC_DDCMultiCol(AMapToData tm, IDictionary td, double[] v, int nCol) { + protected void preAggregateDDC_DDCMultiCol(final AMapToData tm, final IDictionary td, final double[] v, + final int nCol) { + final int sz = size(); final int h = sz % 8; for(int r = 0; r < h; r++) td.addToEntry(v, tm.getIndex(r), getIndex(r), nCol); + for(int r = h; r < sz; r += 8) + preAggregateDDC_DDCMultiCol_vec(tm, td, v, nCol, r); - for(int r = h; r < sz; r += 8) { - int r2 = r + 1, r3 = r + 2, r4 = r + 3, r5 = r + 4, r6 = r + 5, r7 = r + 6, r8 = r + 7; - td.addToEntryVectorized(v, tm.getIndex(r), tm.getIndex(r2), tm.getIndex(r3), tm.getIndex(r4), - tm.getIndex(r5), tm.getIndex(r6), tm.getIndex(r7), tm.getIndex(r8), getIndex(r), getIndex(r2), - getIndex(r3), getIndex(r4), getIndex(r5), getIndex(r6), getIndex(r7), getIndex(r8), nCol); - } + } + + protected void preAggregateDDC_DDCMultiCol_vec(final AMapToData tm, final IDictionary td, final double[] v, + final int nCol, final int r) { + final int r2 = r + 1, r3 = r + 2, r4 = r + 3, r5 = r + 4, r6 = r + 5, r7 = r + 6, r8 = r + 7; + td.addToEntryVectorized(v, // + tm.getIndex(r), tm.getIndex(r2), tm.getIndex(r3), tm.getIndex(r4), tm.getIndex(r5), tm.getIndex(r6), + tm.getIndex(r7), tm.getIndex(r8), getIndex(r), // + getIndex(r2), getIndex(r3), getIndex(r4), getIndex(r5), getIndex(r6), getIndex(r7), getIndex(r8), nCol); } /** @@ -577,8 +605,8 @@ private int preAggregateSDCZ_DDCMultiCol_vect(AMapToData tm, IDictionary td, dou final int h = size % 8; int i = 0; while(i < size - h) { - int t1 = getIndex(i), t2 = getIndex(i + 1), t3 = getIndex(i + 2), t4 = getIndex(i + 3), - t5 = getIndex(i + 4), 
t6 = getIndex(i + 5), t7 = getIndex(i + 6), t8 = getIndex(i + 7); + int t1 = getIndex(i), t2 = getIndex(i + 1), t3 = getIndex(i + 2), t4 = getIndex(i + 3), t5 = getIndex(i + 4), + t6 = getIndex(i + 5), t7 = getIndex(i + 6), t8 = getIndex(i + 7); int f1 = it.value(), f2 = it.next(), f3 = it.next(), f4 = it.next(), f5 = it.next(), f6 = it.next(), f7 = it.next(), f8 = it.next(); @@ -607,8 +635,7 @@ public final void preAggregateSDCZ_SDCZ(AMapToData tm, IDictionary td, AOffset t preAggregateSDCZ_SDCZMultiCol(tm, td, tof, of, ret.getValues(), nCol); } - private final void preAggregateSDCZ_SDCZSingleCol(AMapToData tm, double[] td, AOffset tof, AOffset of, - double[] dv) { + private final void preAggregateSDCZ_SDCZSingleCol(AMapToData tm, double[] td, AOffset tof, AOffset of, double[] dv) { final AOffsetIterator itThat = tof.getOffsetIterator(); final AOffsetIterator itThis = of.getOffsetIterator(); final int tSize = tm.size() - 1, size = size() - 1; @@ -872,7 +899,7 @@ public void verify() { } } } - + @Override public String toString() { final int sz = size(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java index 837468d3ebf..fcbc84ce984 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java @@ -48,6 +48,7 @@ public MapToByte(int unique, int size) { protected MapToByte(int unique, byte[] data) { super(unique); _data = data; + verify(); } protected MapToUByte toUByte() { @@ -155,17 +156,21 @@ protected void preAggregateDenseToRowBy8(double[] mV, double[] preAV, int cl, in final int h = (cu - cl) % 8; off += cl; for(int rc = cl; rc < cl + h; rc++, off++) - preAV[_data[rc] & 0xFF] += mV[off]; - for(int rc = cl + h; rc < cu; rc += 8, off += 8) { - preAV[_data[rc] & 0xFF] += mV[off]; - preAV[_data[rc + 1] & 0xFF] += mV[off + 1]; 
- preAV[_data[rc + 2] & 0xFF] += mV[off + 2]; - preAV[_data[rc + 3] & 0xFF] += mV[off + 3]; - preAV[_data[rc + 4] & 0xFF] += mV[off + 4]; - preAV[_data[rc + 5] & 0xFF] += mV[off + 5]; - preAV[_data[rc + 6] & 0xFF] += mV[off + 6]; - preAV[_data[rc + 7] & 0xFF] += mV[off + 7]; - } + preAV[getIndex(rc)] += mV[off]; + for(int rc = cl + h; rc < cu; rc += 8, off += 8) + preAggregateDenseToRowVec8(mV, preAV, rc, off); + } + + @Override + protected void preAggregateDenseToRowVec8(double[] mV, double[] preAV, int rc, int off) { + preAV[getIndex(rc)] += mV[off]; + preAV[getIndex(rc + 1)] += mV[off + 1]; + preAV[getIndex(rc + 2)] += mV[off + 2]; + preAV[getIndex(rc + 3)] += mV[off + 3]; + preAV[getIndex(rc + 4)] += mV[off + 4]; + preAV[getIndex(rc + 5)] += mV[off + 5]; + preAV[getIndex(rc + 6)] += mV[off + 6]; + preAV[getIndex(rc + 7)] += mV[off + 7]; } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java index bdab7891b82..1f46cc3886f 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java @@ -49,6 +49,7 @@ public MapToChar(int unique, int size) { public MapToChar(int unique, char[] data) { super(unique); _data = data; + verify(); } @Override @@ -113,8 +114,8 @@ public void write(DataOutput out) throws IOException { out.writeInt(_data.length); final int BS = 100; if(_data.length > BS) { - final byte[] buff = new byte[BS*2]; - for(int i = 0; i < _data.length; ) { + final byte[] buff = new byte[BS * 2]; + for(int i = 0; i < _data.length;) { if(i + BS <= _data.length) { for(int o = 0; o < BS; o++) { IOUtilFunctions.shortToBa(_data[i++], buff, o * 2); @@ -152,17 +153,21 @@ protected void preAggregateDenseToRowBy8(double[] mV, double[] preAV, int cl, in final int h = (cu - cl) % 8; off += cl; for(int rc = cl; rc < cl + h; rc++, 
off++) - preAV[_data[rc]] += mV[off]; - for(int rc = cl + h; rc < cu; rc += 8, off += 8) { - preAV[_data[rc]] += mV[off]; - preAV[_data[rc + 1]] += mV[off + 1]; - preAV[_data[rc + 2]] += mV[off + 2]; - preAV[_data[rc + 3]] += mV[off + 3]; - preAV[_data[rc + 4]] += mV[off + 4]; - preAV[_data[rc + 5]] += mV[off + 5]; - preAV[_data[rc + 6]] += mV[off + 6]; - preAV[_data[rc + 7]] += mV[off + 7]; - } + preAV[getIndex(rc)] += mV[off]; + for(int rc = cl + h; rc < cu; rc += 8, off += 8) + preAggregateDenseToRowVec8(mV, preAV, rc, off); + } + + @Override + protected void preAggregateDenseToRowVec8(double[] mV, double[] preAV, int rc, int off){ + preAV[getIndex(rc)] += mV[off]; + preAV[getIndex(rc + 1)] += mV[off + 1]; + preAV[getIndex(rc + 2)] += mV[off + 2]; + preAV[getIndex(rc + 3)] += mV[off + 3]; + preAV[getIndex(rc + 4)] += mV[off + 4]; + preAV[getIndex(rc + 5)] += mV[off + 5]; + preAV[getIndex(rc + 6)] += mV[off + 6]; + preAV[getIndex(rc + 7)] += mV[off + 7]; } @Override @@ -304,4 +309,25 @@ public boolean equals(AMapToData e) { e.getUnique() == getUnique() && // Arrays.equals(((MapToChar) e)._data, _data); } + + @Override + protected void preAggregateDDC_DDCSingleCol_vec(AMapToData tm, double[] td, double[] v, int r) { + if(tm instanceof MapToChar) + preAggregateDDC_DDCSingleCol_vecChar((MapToChar) tm, td, v, r); + else + super.preAggregateDDC_DDCSingleCol_vec(tm, td, v, r); + } + + protected final void preAggregateDDC_DDCSingleCol_vecChar(MapToChar tm, double[] td, double[] v, int r) { + final int r2 = r + 1, r3 = r + 2, r4 = r + 3, r5 = r + 4, r6 = r + 5, r7 = r + 6, r8 = r + 7; + v[getIndex(r)] += td[tm.getIndex(r)]; + v[getIndex(r2)] += td[tm.getIndex(r2)]; + v[getIndex(r3)] += td[tm.getIndex(r3)]; + v[getIndex(r4)] += td[tm.getIndex(r4)]; + v[getIndex(r5)] += td[tm.getIndex(r5)]; + v[getIndex(r6)] += td[tm.getIndex(r6)]; + v[getIndex(r7)] += td[tm.getIndex(r7)]; + v[getIndex(r8)] += td[tm.getIndex(r8)]; + } + } diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java index cb7d6199cf2..99d53878844 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java @@ -53,6 +53,7 @@ public MapToCharPByte(int unique, char[] data_c, byte[] data_b) { super(unique); _data_c = data_c; _data_b = data_b; + verify(); } @Override @@ -278,4 +279,26 @@ public boolean equals(AMapToData e) { Arrays.equals(((MapToCharPByte) e)._data_b, _data_b) && // Arrays.equals(((MapToCharPByte) e)._data_c, _data_c); } + + @Override + protected void preAggregateDenseToRowBy8(double[] mV, double[] preAV, int cl, int cu, int off) { + final int h = (cu - cl) % 8; + off += cl; + for(int rc = cl; rc < cl + h; rc++, off++) + preAV[getIndex(rc)] += mV[off]; + for(int rc = cl + h; rc < cu; rc += 8, off += 8) + preAggregateDenseToRowVec8(mV, preAV, rc, off); + } + + @Override + protected void preAggregateDenseToRowVec8(double[] mV, double[] preAV, int rc, int off){ + preAV[getIndex(rc)] += mV[off]; + preAV[getIndex(rc + 1)] += mV[off + 1]; + preAV[getIndex(rc + 2)] += mV[off + 2]; + preAV[getIndex(rc + 3)] += mV[off + 3]; + preAV[getIndex(rc + 4)] += mV[off + 4]; + preAV[getIndex(rc + 5)] += mV[off + 5]; + preAV[getIndex(rc + 6)] += mV[off + 6]; + preAV[getIndex(rc + 7)] += mV[off + 7]; + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java index b3c509b78cf..20b2c77c7c8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java @@ -48,6 +48,7 @@ public MapToInt(int unique, int size) { private MapToInt(int unique, int[] data) { 
super(unique); _data = data; + verify(); } protected int[] getData() { @@ -130,19 +131,24 @@ protected void preAggregateDenseToRowBy8(double[] mV, double[] preAV, int cl, in final int h = (cu - cl) % 8; off += cl; for(int rc = cl; rc < cl + h; rc++, off++) - preAV[_data[rc]] += mV[off]; - for(int rc = cl + h; rc < cu; rc += 8, off += 8) { - preAV[_data[rc]] += mV[off]; - preAV[_data[rc + 1]] += mV[off + 1]; - preAV[_data[rc + 2]] += mV[off + 2]; - preAV[_data[rc + 3]] += mV[off + 3]; - preAV[_data[rc + 4]] += mV[off + 4]; - preAV[_data[rc + 5]] += mV[off + 5]; - preAV[_data[rc + 6]] += mV[off + 6]; - preAV[_data[rc + 7]] += mV[off + 7]; - } + preAV[getIndex(rc)] += mV[off]; + for(int rc = cl + h; rc < cu; rc += 8, off += 8) + preAggregateDenseToRowVec8(mV, preAV, rc, off); } + @Override + protected void preAggregateDenseToRowVec8(double[] mV, double[] preAV, int rc, int off){ + preAV[getIndex(rc)] += mV[off]; + preAV[getIndex(rc + 1)] += mV[off + 1]; + preAV[getIndex(rc + 2)] += mV[off + 2]; + preAV[getIndex(rc + 3)] += mV[off + 3]; + preAV[getIndex(rc + 4)] += mV[off + 4]; + preAV[getIndex(rc + 5)] += mV[off + 5]; + preAV[getIndex(rc + 6)] += mV[off + 6]; + preAV[getIndex(rc + 7)] += mV[off + 7]; + } + + @Override protected void preAggregateDenseMultiRowContiguousBy8(double[] mV, int nCol, int nVal, double[] preAV, int rl, int ru, int cl, int cu) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToUByte.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToUByte.java index f94e95a9ed3..d545c362996 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToUByte.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToUByte.java @@ -95,17 +95,21 @@ protected void preAggregateDenseToRowBy8(double[] mV, double[] preAV, int cl, in final int h = (cu - cl) % 8; off += cl; for(int rc = cl; rc < cl + h; rc++, off++) - preAV[_data[rc]] += mV[off]; - for(int rc = cl + h; 
rc < cu; rc += 8, off += 8) { - preAV[_data[rc]] += mV[off]; - preAV[_data[rc + 1]] += mV[off + 1]; - preAV[_data[rc + 2]] += mV[off + 2]; - preAV[_data[rc + 3]] += mV[off + 3]; - preAV[_data[rc + 4]] += mV[off + 4]; - preAV[_data[rc + 5]] += mV[off + 5]; - preAV[_data[rc + 6]] += mV[off + 6]; - preAV[_data[rc + 7]] += mV[off + 7]; - } + preAV[getIndex(rc)] += mV[off]; + for(int rc = cl + h; rc < cu; rc += 8, off += 8) + preAggregateDenseToRowVec8(mV, preAV, rc, off); + } + + @Override + protected void preAggregateDenseToRowVec8(double[] mV, double[] preAV, int rc, int off) { + preAV[getIndex(rc)] += mV[off]; + preAV[getIndex(rc + 1)] += mV[off + 1]; + preAV[getIndex(rc + 2)] += mV[off + 2]; + preAV[getIndex(rc + 3)] += mV[off + 3]; + preAV[getIndex(rc + 4)] += mV[off + 4]; + preAV[getIndex(rc + 5)] += mV[off + 5]; + preAV[getIndex(rc + 6)] += mV[off + 6]; + preAV[getIndex(rc + 7)] += mV[off + 7]; } @Override @@ -121,7 +125,7 @@ public int[] getCounts(int[] ret) { } @Override - public int getMaxPossible(){ + public int getMaxPossible() { return 128; } From 8ca0bf1eb4e4e5c55f4aa610d2cc54ce9705b77b Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 30 Oct 2023 13:51:01 +0100 Subject: [PATCH 18/28] [MINOR] Refine Error on Scalar compression --- .../instructions/cp/CompressionCPInstruction.java | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java index b59e4d9db85..c9dd5c8961e 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; +import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.lang3.tuple.Pair; import 
org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -122,10 +123,13 @@ private void processSimpleCompressInstruction(ExecutionContext ec) { final int k = OptimizerUtils.getConstrainedNumThreads(-1); - if(ec.isMatrixObject(input1.getName())) - processMatrixBlockCompression(ec, ec.getMatrixInput(input1.getName()), k, root); - else + if(ec.isFrameObject(input1.getName())) processFrameBlockCompression(ec, ec.getFrameInput(input1.getName()), k, root); + else if(ec.isMatrixObject(input1.getName())) + processMatrixBlockCompression(ec, ec.getMatrixInput(input1.getName()), k, root); + else{ + throw new NotImplementedException("Not supported other types of input for compression than frame and matrix"); + } } private void processMatrixBlockCompression(ExecutionContext ec, MatrixBlock in, int k, WTreeRoot root) { From c398e8ec5e163647706ac309b8c854a62b594c97 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 30 Oct 2023 13:55:26 +0100 Subject: [PATCH 19/28] [SYSTEMDS-3644] Compressed-Compressed Transform Encode (PassThrough) Initial instance of direct compressed frame to compressed matrix transform encode, to start with in the case of PassThrough. 
--- .../runtime/frame/data/columns/DDCArray.java | 6 +++++- .../transform/encode/CompressedEncode.java | 19 +++++++++++++++++++ .../transform/encode/MultiColumnEncoder.java | 5 +++-- 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java index b634cfe6ff3..8f3dcd9dcba 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java @@ -55,10 +55,14 @@ public DDCArray(Array dict, AMapToData map) { } } - protected Array getDict(){ + public Array getDict(){ return dict; } + public AMapToData getMap(){ + return map; + } + /** * Try to compress array into DDC format. * diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index 8ca8b6d9fc2..7fbdb1ea3c8 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -49,7 +49,9 @@ import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.ACompressedArray; import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.columns.DDCArray; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.runtime.util.UtilFunctions; @@ -164,6 +166,7 @@ private AColGroup recodeToDummy(ColumnEncoderComposite c) { IColIndex colIndexes = ColIndexFactory.create(0, domain); if(domain == 1 && !containsNull) return ColGroupConst.create(colIndexes, new double[] {1}); 
+ ADictionary d = new IdentityDictionary(colIndexes.size(), containsNull); AMapToData m = createMappingAMapToData(a, map, containsNull); return ColGroupDDC.create(colIndexes, d, m, null); @@ -288,6 +291,22 @@ private AColGroup passThrough(ColumnEncoderComposite c) { IColIndex colIndexes = ColIndexFactory.create(1); int colId = c._colID; Array a = in.getColumn(colId - 1); + if(a instanceof ACompressedArray){ + switch(a.getFrameArrayType()) { + case DDC: + DDCArray aDDC = (DDCArray) a; + Array dict = aDDC.getDict(); + double[] vals = new double[dict.size()]; + for(int i = 0; i < dict.size(); i++) { + vals[i] = dict.getAsDouble(i); + } + ADictionary d = Dictionary.create(vals); + + return ColGroupDDC.create(colIndexes, d, aDDC.getMap(), null); + default: + throw new NotImplementedException(); + } + } boolean containsNull = a.containsNull(); HashMap map = (HashMap) a.getRecodeMap(); final int blockSz = ConfigurationManager.getDMLConfig().getIntValue(DMLConfig.DEFAULT_BLOCK_SIZE); diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java index f1813e29a77..bd9e2ba79f8 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java @@ -102,11 +102,12 @@ public MatrixBlock encode(CacheBlock in, boolean compressedOut) { } public MatrixBlock encode(CacheBlock in, int k, boolean compressedOut){ - deriveNumRowPartitions(in, k); try { if(isCompressedTransformEncode(in, compressedOut)) return CompressedEncode.encode(this, (FrameBlock ) in, k); - else if(k > 1 && !MULTI_THREADED_STAGES && !hasLegacyEncoder()) { + + deriveNumRowPartitions(in, k); + if(k > 1 && !MULTI_THREADED_STAGES && !hasLegacyEncoder()) { MatrixBlock out = new MatrixBlock(); DependencyThreadPool pool = new DependencyThreadPool(k); LOG.debug("Encoding with full DAG on " + k + " 
Threads"); From 0ba2aa994f8f3006a2a660c8cad4fdd8e78ac94f Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 30 Oct 2023 13:30:15 +0100 Subject: [PATCH 20/28] [SYSTEMDS-3643] Fused Scaling Compressed Multiplication This commit contains the code to fuse the scaling part into the Matrix Multiplication kernels of CLA. This is used to not allocate new Dictionaries, when the two column group sides have identical index structures. The change improve instructions such as MMChain and TSMM. The improvements are biggest if there are few column groups. Closes #1936 --- .../runtime/compress/colgroup/APreAgg.java | 5 +- .../dictionary/DictLibMatrixMult.java | 127 +++++++++- .../colgroup/dictionary/Dictionary.java | 48 +++- .../colgroup/dictionary/IDictionary.java | 94 ++++--- .../dictionary/IdentityDictionary.java | 168 +++++++++++-- .../dictionary/IdentityDictionarySlice.java | 23 +- .../dictionary/MatrixBlockDictionary.java | 71 +++++- .../colgroup/dictionary/PlaceHolderDict.java | 18 ++ .../colgroup/dictionary/QDictionary.java | 18 ++ .../runtime/data/SparseBlockFactory.java | 45 +++- .../java/org/apache/sysds/test/TestUtils.java | 11 + .../compress/dictionary/DictionaryTests.java | 232 +++++++++++++++++- .../test/component/matrix/SparseFactory.java | 42 ++++ 13 files changed, 821 insertions(+), 81 deletions(-) create mode 100644 src/test/java/org/apache/sysds/test/component/matrix/SparseFactory.java diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java index 8b8a7b7df02..7f585f2d7ac 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java @@ -85,9 +85,12 @@ else if(lhs instanceof ColGroupUncompressed) * @return A aggregate dictionary */ public final IDictionary preAggregateThatIndexStructure(APreAgg that) { - long outputLength = (long)that._colIndexes.size() * 
this.getNumValues(); + final long outputLength = (long)that._colIndexes.size() * this.getNumValues(); if(outputLength > Integer.MAX_VALUE) throw new NotImplementedException("Not supported pre aggregate of above integer length"); + if(outputLength <= 0) // if the pre aggregate output is empty or nothing, return null + return null; + // create empty Dictionary that we slowly fill, hence the dictionary is empty and no check final Dictionary ret = Dictionary.createNoCheck(new double[(int)outputLength]); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java index 240e57cc124..9aba711a30e 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java @@ -65,11 +65,7 @@ else if(row > col) // swap because in lower triangle */ public static void MMDictsWithScaling(IDictionary left, IDictionary right, IColIndex leftRows, IColIndex rightColumns, MatrixBlock result, int[] counts) { - LOG.warn("Inefficient double allocation of dictionary"); - final boolean modifyRight = right.getInMemorySize() > left.getInMemorySize(); - final IDictionary rightM = modifyRight ? right.scaleTuples(counts, rightColumns.size()) : right; - final IDictionary leftM = modifyRight ? 
left : left.scaleTuples(counts, leftRows.size()); - MMDicts(leftM, rightM, leftRows, rightColumns, result); + left.MMDictScaling(right, leftRows, rightColumns, result, counts); } /** @@ -198,17 +194,43 @@ protected static void TSMMDictsSparseWithScaling(SparseBlock sb, IColIndex rowsL protected static void MMDictsDenseDense(double[] left, double[] right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { - final int commonDim = Math.min(left.length / rowsLeft.size(), right.length / colsRight.size()); + final int leftSide = rowsLeft.size(); + final int rightSide = colsRight.size(); + final int commonDim = Math.min(left.length / leftSide, right.length / rightSide); final int resCols = result.getNumColumns(); final double[] resV = result.getDenseBlockValues(); + for(int k = 0; k < commonDim; k++) { - final int offL = k * rowsLeft.size(); - final int offR = k * colsRight.size(); - for(int i = 0; i < rowsLeft.size(); i++) { + final int offL = k * leftSide; + final int offR = k * rightSide; + for(int i = 0; i < leftSide; i++) { final int offOut = rowsLeft.get(i) * resCols; final double vl = left[offL + i]; if(vl != 0) { - for(int j = 0; j < colsRight.size(); j++) + for(int j = 0; j < rightSide; j++) + resV[offOut + colsRight.get(j)] += vl * right[offR + j]; + } + } + } + } + + protected static void MMDictsScalingDenseDense(double[] left, double[] right, IColIndex rowsLeft, + IColIndex colsRight, MatrixBlock result, int[] scaling) { + final int leftSide = rowsLeft.size(); + final int rightSide = colsRight.size(); + final int commonDim = Math.min(left.length / leftSide, right.length / rightSide); + final int resCols = result.getNumColumns(); + final double[] resV = result.getDenseBlockValues(); + + for(int k = 0; k < commonDim; k++) { + final int offL = k * leftSide; + final int offR = k * rightSide; + final int s = scaling[k]; + for(int i = 0; i < leftSide; i++) { + final int offOut = rowsLeft.get(i) * resCols; + final double vl = left[offL + i] * s; + if(vl 
!= 0) { + for(int j = 0; j < rightSide; j++) resV[offOut + colsRight.get(j)] += vl * right[offR + j]; } } @@ -236,10 +258,34 @@ protected static void MMDictsSparseDense(SparseBlock left, double[] right, IColI } } + protected static void MMDictsScalingSparseDense(SparseBlock left, double[] right, IColIndex rowsLeft, + IColIndex colsRight, MatrixBlock result, int[] scaling) { + final double[] resV = result.getDenseBlockValues(); + final int commonDim = Math.min(left.numRows(), right.length / colsRight.size()); + for(int i = 0; i < commonDim; i++) { + if(left.isEmpty(i)) + continue; + final int apos = left.pos(i); + final int alen = left.size(i) + apos; + final int[] aix = left.indexes(i); + final double[] leftVals = left.values(i); + final int offRight = i * colsRight.size(); + final int s = scaling[i]; + for(int k = apos; k < alen; k++) { + final int offOut = rowsLeft.get(aix[k]) * result.getNumColumns(); + final double v = leftVals[k] * s; + for(int j = 0; j < colsRight.size(); j++) + resV[offOut + colsRight.get(j)] += v * right[offRight + j]; + } + } + } + protected static void MMDictsDenseSparse(double[] left, SparseBlock right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { final double[] resV = result.getDenseBlockValues(); - final int commonDim = Math.min(left.length / rowsLeft.size(), right.numRows()); + final int leftSize = rowsLeft.size(); + final int commonDim = Math.min(left.length / leftSize, right.numRows()); + for(int i = 0; i < commonDim; i++) { if(right.isEmpty(i)) continue; @@ -247,8 +293,8 @@ protected static void MMDictsDenseSparse(double[] left, SparseBlock right, IColI final int alen = right.size(i) + apos; final int[] aix = right.indexes(i); final double[] rightVals = right.values(i); - final int offLeft = i * rowsLeft.size(); - for(int j = 0; j < rowsLeft.size(); j++) { + final int offLeft = i * leftSize; + for(int j = 0; j < leftSize; j++) { final int offOut = rowsLeft.get(j) * result.getNumColumns(); final double v = 
left[offLeft + j]; if(v != 0) { @@ -259,6 +305,32 @@ protected static void MMDictsDenseSparse(double[] left, SparseBlock right, IColI } } + protected static void MMDictsScalingDenseSparse(double[] left, SparseBlock right, IColIndex rowsLeft, IColIndex colsRight, + MatrixBlock result, int[] scaling) { + final double[] resV = result.getDenseBlockValues(); + final int leftSize = rowsLeft.size(); + final int commonDim = Math.min(left.length / leftSize, right.numRows()); + + for(int i = 0; i < commonDim; i++) { + if(right.isEmpty(i)) + continue; + final int apos = right.pos(i); + final int alen = right.size(i) + apos; + final int[] aix = right.indexes(i); + final double[] rightVals = right.values(i); + final int offLeft = i * leftSize; + final int s = scaling[i]; + for(int j = 0; j < leftSize; j++) { + final int offOut = rowsLeft.get(j) * result.getNumColumns(); + final double v = left[offLeft + j] * s; + if(v != 0) { + for(int k = apos; k < alen; k++) + resV[offOut + colsRight.get(aix[k])] += v * rightVals[k]; + } + } + } + } + protected static void MMDictsSparseSparse(SparseBlock left, SparseBlock right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { final int commonDim = Math.min(left.numRows(), right.numRows()); @@ -286,6 +358,35 @@ protected static void MMDictsSparseSparse(SparseBlock left, SparseBlock right, I } } + protected static void MMDictsScalingSparseSparse(SparseBlock left, SparseBlock right, IColIndex rowsLeft, + IColIndex colsRight, MatrixBlock result, int[] scaling) { + final int commonDim = Math.min(left.numRows(), right.numRows()); + final double[] resV = result.getDenseBlockValues(); + final int resCols = result.getNumColumns(); + // remember that the left side is transposed... 
+ for(int i = 0; i < commonDim; i++) { + if(left.isEmpty(i) || right.isEmpty(i)) + continue; + final int leftAPos = left.pos(i); + final int leftAlen = left.size(i) + leftAPos; + final int[] leftAix = left.indexes(i); + final double[] leftVals = left.values(i); + final int rightAPos = right.pos(i); + final int rightAlen = right.size(i) + rightAPos; + final int[] rightAix = right.indexes(i); + final double[] rightVals = right.values(i); + + final int s = scaling[i]; + + for(int k = leftAPos; k < leftAlen; k++) { + final int offOut = rowsLeft.get(leftAix[k]) * resCols; + final double v = leftVals[k] * s; + for(int j = rightAPos; j < rightAlen; j++) + resV[offOut + colsRight.get(rightAix[j])] += v * rightVals[j]; + } + } + } + protected static void MMToUpperTriangleSparseSparse(SparseBlock left, SparseBlock right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { final int commonDim = Math.min(left.numRows(), right.numRows()); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java index 983dc84b507..4f0bbfbee14 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java @@ -22,13 +22,16 @@ import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; +import java.lang.ref.SoftReference; import java.math.BigDecimal; import java.math.MathContext; import java.util.Arrays; +import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Builtin; import 
org.apache.sysds.runtime.functionobjects.Plus; @@ -51,6 +54,8 @@ public class Dictionary extends ADictionary { private static final long serialVersionUID = -6517136537249507753L; protected final double[] _values; + /** A Cache to contain a MatrixBlock version of the dictionary. */ + protected volatile SoftReference cache = null; protected Dictionary(double[] values) { _values = values; @@ -799,7 +804,14 @@ public IDictionary subtractTuple(double[] tuple) { @Override public MatrixBlockDictionary getMBDict(int nCol) { - return MatrixBlockDictionary.createDictionary(_values, nCol, true); + if(cache != null) { + MatrixBlockDictionary r = cache.get(); + if(r != null) + return r; + } + MatrixBlockDictionary ret = MatrixBlockDictionary.createDictionary(_values, nCol, true); + cache = new SoftReference<>(ret); + return ret; } @Override @@ -843,13 +855,15 @@ public IDictionary scaleTuples(int[] scaling, int nCol) { @Override public Dictionary preaggValuesFromDense(int numVals, IColIndex colIndexes, IColIndex aggregateColumns, double[] b, int cut) { - double[] ret = new double[numVals * aggregateColumns.size()]; - for(int k = 0, off = 0; k < numVals * colIndexes.size(); k += colIndexes.size(), off += aggregateColumns.size()) { - for(int h = 0; h < colIndexes.size(); h++) { - int idb = colIndexes.get(h) * cut; + final int cz = colIndexes.size(); + final int az = aggregateColumns.size(); + final double[] ret = new double[numVals * az]; + for(int k = 0, off = 0; k < numVals * cz; k += cz, off += az) { + for(int h = 0; h < cz; h++) { + final int idb = colIndexes.get(h) * cut; double v = _values[k + h]; if(v != 0) - for(int i = 0; i < aggregateColumns.size(); i++) + for(int i = 0; i < az; i++) ret[off + i] += v * b[idb + aggregateColumns.get(i)]; } } @@ -861,13 +875,15 @@ public IDictionary replace(double pattern, double replace, int nCol) { double[] retV = new double[_values.length]; for(int i = 0; i < _values.length; i++) { final double v = _values[i]; - retV[i] = v == pattern 
? replace : v; + retV[i] = Util.eq(v, pattern) ? replace : v; } return create(retV); } @Override public IDictionary replaceWithReference(double pattern, double replace, double[] reference) { + if(Util.eq(pattern, Double.NaN)) + throw new NotImplementedException(); final double[] retV = new double[_values.length]; final int nCol = reference.length; final int nRow = _values.length / nCol; @@ -1040,16 +1056,34 @@ public void MMDict(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, M right.MMDictDense(_values, rowsLeft, colsRight, result); } + @Override + public void MMDictScaling(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, + int[] scaling) { + right.MMDictScalingDense(_values, rowsLeft, colsRight, result, scaling); + } + @Override public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { DictLibMatrixMult.MMDictsDenseDense(left, _values, rowsLeft, colsRight, result); } + @Override + public void MMDictScalingDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, + int[] scaling) { + DictLibMatrixMult.MMDictsScalingDenseDense(left, _values, rowsLeft, colsRight, result, scaling); + } + @Override public void MMDictSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { DictLibMatrixMult.MMDictsSparseDense(left, _values, rowsLeft, colsRight, result); } + @Override + public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, + int[] scaling) { + DictLibMatrixMult.MMDictsScalingSparseDense(left, _values, rowsLeft, colsRight, result, scaling); + } + @Override public void TSMMToUpperTriangle(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { right.TSMMToUpperTriangleDense(_values, rowsLeft, colsRight, result); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java index 2f3d435673a..1047692f509 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java @@ -524,46 +524,46 @@ public IDictionary binOpRightWithReference(BinaryOperator op, double[] v, IColIn public long getNumberNonZerosWithReference(int[] counts, double[] reference, int nRows); /** - * Copies and adds the dictionary entry from this dictionary to the d dictionary + * Adds the dictionary entry from this dictionary to the d dictionary * - * @param v the target dictionary (dense double array) - * @param fr the from index - * @param to the to index - * @param nCol the number of columns + * @param v The target dictionary (dense double array) + * @param fr The from index is the tuple index to copy from. + * @param to The to index is the row index to copy into. + * @param nCol The number of columns in both cases */ public void addToEntry(double[] v, int fr, int to, int nCol); /** - * copies and adds the dictonary entry from this dictionary yo the d dictionary rep times. + * Adds the dictionary entry from this dictionary to the v dictionary rep times. * - * @param v the target dictionary (dense double array) - * @param fr the from index - * @param to the to index - * @param nCol the number of columns - * @param rep the number of repetitions to apply (simply multiply do not loop) + * @param v The target dictionary (dense double array) + * @param fr The from index is the tuple index to copy from. + * @param to The to index is the row index to copy into. + * @param nCol The number of columns in both cases + * @param rep The number of repetitions to apply (simply multiply do not loop) */ public void addToEntry(double[] v, int fr, int to, int nCol, int rep); /** * Vectorized add to entry, this call helps with a bit of locality for the cache. 
* - * @param v THe target dictionary (dense double array) - * @param f1 from index 1 - * @param f2 from index 2 - * @param f3 from index 3 - * @param f4 from index 4 - * @param f5 from index 5 - * @param f6 from index 6 - * @param f7 from index 7 - * @param f8 from index 8 - * @param t1 to index 1 - * @param t2 to index 2 - * @param t3 to index 3 - * @param t4 to index 4 - * @param t5 to index 5 - * @param t6 to index 6 - * @param t7 to index 7 - * @param t8 to index 8 + * @param v The target dictionary (dense double array) + * @param f1 From index 1 + * @param f2 From index 2 + * @param f3 From index 3 + * @param f4 From index 4 + * @param f5 From index 5 + * @param f6 From index 6 + * @param f7 From index 7 + * @param f8 From index 8 + * @param t1 To index 1 + * @param t2 To index 2 + * @param t3 To index 3 + * @param t4 To index 4 + * @param t5 To index 5 + * @param t6 To index 6 + * @param t7 To index 7 + * @param t8 To index 8 * @param nCol Number of columns in the dictionary */ public void addToEntryVectorized(double[] v, int f1, int f2, int f3, int f4, int f5, int f6, int f7, int f8, int t1, @@ -820,6 +820,20 @@ public CM_COV_Object centralMomentWithReference(CM_COV_Object ret, ValueFunction */ public void MMDict(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result); + /** + * Matrix multiplication of dictionaries + * + * Note the left is this, and it is transposed + * + * @param right Right hand side of multiplication + * @param rowsLeft Offset rows on the left + * @param colsRight Offset cols on the right + * @param result The output matrix block + * @param scaling The scaling + */ + public void MMDictScaling(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, + int[] scaling); + /** * Matrix multiplication of dictionaries left side dense and transposed right side is this. 
* @@ -830,6 +844,18 @@ public CM_COV_Object centralMomentWithReference(CM_COV_Object ret, ValueFunction */ public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result); + /** + * Matrix multiplication of dictionaries left side dense and transposed right side is this. + * + * @param left Dense left side + * @param rowsLeft Offset rows on the left + * @param colsRight Offset cols on the right + * @param result The output matrix block + * @param scaling The scaling + */ + public void MMDictScalingDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, + int[] scaling); + /** * Matrix multiplication of dictionaries left side sparse and transposed right side is this. * @@ -839,6 +865,18 @@ public CM_COV_Object centralMomentWithReference(CM_COV_Object ret, ValueFunction * @param result The output matrix block */ public void MMDictSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result); + +/** + * Matrix multiplication of dictionaries left side sparse and transposed right side is this. 
+ * + * @param left Sparse left side + * @param rowsLeft Offset rows on the left + * @param colsRight Offset cols on the right + * @param result The output matrix block + * @param scaling The scaling + */ + public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, + int[] scaling); /** * Matrix multiplication but allocate output in upper triangle and twice if on diagonal, note this is left diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java index 39712155e6b..74f5e5b0991 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java @@ -32,6 +32,9 @@ import org.apache.sysds.runtime.data.SparseBlockFactory; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode; +import org.apache.sysds.runtime.functionobjects.Divide; +import org.apache.sysds.runtime.functionobjects.Minus; +import org.apache.sysds.runtime.functionobjects.Plus; import org.apache.sysds.runtime.functionobjects.ValueFunction; import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -51,7 +54,7 @@ public class IdentityDictionary extends ADictionary { /** Specify if the Identity matrix should contain an empty row in the end. */ protected final boolean withEmpty; /** A Cache to contain a materialized version of the identity matrix. */ - protected SoftReference cache = null; + protected volatile SoftReference cache = null; /** * Create an identity matrix dictionary. 
It behaves as if allocated a Sparse Matrix block but exploits that the @@ -212,7 +215,29 @@ public IDictionary binOpLeftWithReference(BinaryOperator op, double[] v, IColInd @Override public IDictionary binOpRight(BinaryOperator op, double[] v, IColIndex colIndexes) { - return getMBDict().binOpRight(op, v, colIndexes); + boolean same = false; + if(op.fn instanceof Plus || op.fn instanceof Minus) { + same = true; + for(int i = 0; i < colIndexes.size(); i++) { + if(v[colIndexes.get(i)] != 0.0) { + same = false; + break; + } + } + } + if(op.fn instanceof Divide) { + same = true; + for(int i = 0; i < colIndexes.size(); i++) { + if(v[colIndexes.get(i)] != 1.0) { + same = false; + break; + } + } + } + if(same) + return this; + MatrixBlockDictionary mb = getMBDict(); + return mb.binOpRight(op, v, colIndexes); } @Override @@ -243,22 +268,33 @@ public DictType getDictType() { @Override public int getNumberOfValues(int ncol) { + if(ncol != nRowCol) + throw new DMLCompressionException("Invalid call to get Number of values assuming wrong number of columns"); return nRowCol + (withEmpty ? 
1 : 0); } @Override public double[] sumAllRowsToDouble(int nrColumns) { - double[] ret = new double[nRowCol]; - Arrays.fill(ret, 1); - return ret; + if(withEmpty) { + double[] ret = new double[nRowCol + 1]; + Arrays.fill(ret, 1); + ret[ret.length - 1] = 0; + return ret; + } + else { + double[] ret = new double[nRowCol]; + Arrays.fill(ret, 1); + return ret; + } } @Override public double[] sumAllRowsToDoubleWithDefault(double[] defaultTuple) { - double[] ret = new double[nRowCol]; - Arrays.fill(ret, 1); + double[] ret = new double[defaultTuple.length]; for(int i = 0; i < defaultTuple.length; i++) - ret[i] += defaultTuple[i]; + ret[i] += 1 + defaultTuple[i]; + if(withEmpty) + ret[ret.length - 1] += -1; return ret; } @@ -341,6 +377,8 @@ public double sum(int[] counts, int ncol) { double s = 0.0; for(int v : counts) s += v; + if(withEmpty) + s -= counts[counts.length - 1]; return s; } @@ -389,13 +427,54 @@ public void addToEntry(final double[] v, final int fr, final int to, final int n @Override public void addToEntry(final double[] v, final int fr, final int to, final int nCol, int rep) { - getMBDict().addToEntry(v, fr, to, nCol, rep); + if(withEmpty) { + if(fr < nRowCol) + v[to * nCol + fr] += rep; + } + else { + v[to * nCol + fr] += rep; + } } @Override public void addToEntryVectorized(double[] v, int f1, int f2, int f3, int f4, int f5, int f6, int f7, int f8, int t1, int t2, int t3, int t4, int t5, int t6, int t7, int t8, int nCol) { - getMBDict().addToEntryVectorized(v, f1, f2, f3, f4, f5, f6, f7, f8, t1, t2, t3, t4, t5, t6, t7, t8, nCol); + if(withEmpty) + addToEntryVectorizedWithEmpty(v, f1, f2, f3, f4, f5, f6, f7, f8, t1, t2, t3, t4, t5, t6, t7, t8, nCol); + else + addToEntryVectorizedNorm(v, f1, f2, f3, f4, f5, f6, f7, f8, t1, t2, t3, t4, t5, t6, t7, t8, nCol); + } + + private void addToEntryVectorizedWithEmpty(double[] v, int f1, int f2, int f3, int f4, int f5, int f6, int f7, + int f8, int t1, int t2, int t3, int t4, int t5, int t6, int t7, int t8, int nCol) 
{ + if(f1 < nRowCol) + v[t1 * nCol + f1] += 1; + if(f2 < nRowCol) + v[t2 * nCol + f2] += 1; + if(f3 < nRowCol) + v[t3 * nCol + f3] += 1; + if(f4 < nRowCol) + v[t4 * nCol + f4] += 1; + if(f5 < nRowCol) + v[t5 * nCol + f5] += 1; + if(f6 < nRowCol) + v[t6 * nCol + f6] += 1; + if(f7 < nRowCol) + v[t7 * nCol + f7] += 1; + if(f8 < nRowCol) + v[t8 * nCol + f8] += 1; + } + + private void addToEntryVectorizedNorm(double[] v, int f1, int f2, int f3, int f4, int f5, int f6, int f7, int f8, + int t1, int t2, int t3, int t4, int t5, int t6, int t7, int t8, int nCol) { + v[t1 * nCol + f1] += 1; + v[t2 * nCol + f2] += 1; + v[t3 * nCol + f3] += 1; + v[t4 * nCol + f4] += 1; + v[t5 * nCol + f5] += 1; + v[t6 * nCol + f6] += 1; + v[t7 * nCol + f7] += 1; + v[t8 * nCol + f8] += 1; } @Override @@ -466,7 +545,28 @@ public long getExactSizeOnDisk() { @Override public IDictionary preaggValuesFromDense(final int numVals, final IColIndex colIndexes, final IColIndex aggregateColumns, final double[] b, final int cut) { - return getMBDict().preaggValuesFromDense(numVals, colIndexes, aggregateColumns, b, cut); + /** + * This operations is Essentially a Identity matrix multiplication with a right hand side dense matrix, but we + * need to slice out the right hand side from the input. + * + * ColIndexes specify the rows to slice out of the right matrix. + * + * aggregate columns specify the columns to slice out from the right. 
+ */ + final int cs = colIndexes.size(); + final int s = aggregateColumns.size(); + + double[] ret = new double[s * numVals]; + int off = 0; + for(int i = 0; i < cs; i++) {// rows on right + final int offB = colIndexes.get(i) * cut; + for(int j = 0; j < s; j++) { + ret[off++] = b[offB + aggregateColumns.get(j)]; + } + } + + MatrixBlock db = new MatrixBlock(numVals, s, ret); + return new MatrixBlockDictionary(db); } @Override @@ -529,7 +629,10 @@ public IDictionary rexpandColsWithReference(int max, boolean ignore, boolean cas @Override public double getSparsity() { - return 1d / nRowCol; + if(withEmpty) + return 1d / (nRowCol + 1); + else + return 1d / nRowCol; } @Override @@ -545,13 +648,44 @@ public void TSMMWithScaling(int[] counts, IColIndex rows, IColIndex cols, Matrix @Override public void MMDict(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { getMBDict().MMDict(right, rowsLeft, colsRight, result); - // should replace with add to right to output cells. + } + + public void MMDictScaling(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, + int[] scaling) { + getMBDict().MMDictScaling(right, rowsLeft, colsRight, result, scaling); } @Override public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { - getMBDict().MMDictDense(left, rowsLeft, colsRight, result); + // getMBDict().MMDictDense(left, rowsLeft, colsRight, result); // should replace with add to right to output cells. 
+ final int leftSide = rowsLeft.size(); + final int resCols = result.getNumColumns(); + final int commonDim = Math.min(left.length / leftSide, nRowCol); + final double[] resV = result.getDenseBlockValues(); + for(int i = 0; i < leftSide; i++) {// rows in left side + final int offOut = rowsLeft.get(i) * resCols; + final int leftOff = i * leftSide; + for(int j = 0; j < commonDim; j++) { // cols in left side skipping empty from identity + resV[offOut + colsRight.get(j)] += left[leftOff + j]; + } + } + } + + @Override + public void MMDictScalingDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, + int[] scaling) { + final int leftSide = rowsLeft.size(); + final int resCols = result.getNumColumns(); + final int commonDim = Math.min(left.length / leftSide, nRowCol); + final double[] resV = result.getDenseBlockValues(); + for(int i = 0; i < leftSide; i++) {// rows in left side + final int offOut = rowsLeft.get(i) * resCols; + final int leftOff = i * leftSide; + for(int j = 0; j < commonDim; j++) { // cols in left side skipping empty from identity + resV[offOut + colsRight.get(j)] += left[leftOff + j] * scaling[j]; + } + } } @Override @@ -559,6 +693,12 @@ public void MMDictSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRig getMBDict().MMDictSparse(left, rowsLeft, colsRight, result); } + @Override + public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, + int[] scaling) { + getMBDict().MMDictScalingSparse(left, rowsLeft, colsRight, result, scaling); + } + @Override public void TSMMToUpperTriangle(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { getMBDict().TSMMToUpperTriangle(right, rowsLeft, colsRight, result); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java index 
6a282e8b267..167328871b4 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java @@ -69,13 +69,13 @@ public double[] getValues() { @Override public double getValue(int i) { throw new NotImplementedException(); - } @Override public final double getValue(int r, int c, int nCol) { - throw new NotImplementedException(); - + if(r < l || r > u) + return 0; + return super.getValue(r - l, c, nCol); } @Override @@ -278,6 +278,23 @@ public double getSparsity() { return 1d / nRowCol; } + @Override + public IDictionary preaggValuesFromDense(final int numVals, final IColIndex colIndexes, + final IColIndex aggregateColumns, final double[] b, final int cut) { + return getMBDict().preaggValuesFromDense(numVals, colIndexes, aggregateColumns, b, cut); + } + + @Override + public void addToEntryVectorized(double[] v, int f1, int f2, int f3, int f4, int f5, int f6, int f7, int f8, int t1, + int t2, int t3, int t4, int t5, int t6, int t7, int t8, int nCol) { + throw new NotImplementedException(); + } + + @Override + public void addToEntry(final double[] v, final int fr, final int to, final int nCol, int rep) { + throw new NotImplementedException(); + } + @Override public boolean equals(IDictionary o) { if(o instanceof IdentityDictionarySlice) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java index 3995fc4e364..2a800837c7c 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java @@ -88,8 +88,27 @@ else if(check) { } public static MatrixBlockDictionary createDictionary(double[] values, int nCol, boolean check) { - final MatrixBlock 
mb = Util.matrixBlockFromDenseArray(values, nCol, check); - return create(mb, check); + if(nCol <= 1) { + final MatrixBlock mb = Util.matrixBlockFromDenseArray(values, nCol, check); + return create(mb, check); + } + else { + final int nnz = checkNNz(values); + if((double) nnz / values.length < 0.4D) { + SparseBlock sb = SparseBlockFactory.createFromArray(values, nCol, nnz); + MatrixBlock mb = new MatrixBlock(values.length / nCol, nCol, nnz, sb); + return create(mb, false); + } + else + return create(Util.matrixBlockFromDenseArray(values, nCol, check), false); + } + } + + private static int checkNNz(double[] values) { + int nnz = 0; + for(int i = 0; i < values.length; i++) + nnz += values[i] == 0 ? 0 : 1; + return nnz; } public MatrixBlock getMatrixBlock() { @@ -837,6 +856,9 @@ public DictType getDictType() { @Override public int getNumberOfValues(int ncol) { + + if(ncol != _data.getNumColumns()) + throw new DMLCompressionException("Invalid call to get Number of values assuming wrong number of columns"); return _data.getNumRows(); } @@ -1771,15 +1793,15 @@ public MatrixBlockDictionary preaggValuesFromDense(final int numVals, final ICol } } else { - double[] values = _data.getDenseBlockValues(); - for(int k = 0, off = 0; - k < numVals * colIndexes.size(); - k += colIndexes.size(), off += aggregateColumns.size()) { - for(int h = 0; h < colIndexes.size(); h++) { - int idb = colIndexes.get(h) * cut; + final int cz = colIndexes.size(); + final int az = aggregateColumns.size(); + final double[] values = _data.getDenseBlockValues(); + for(int k = 0, off = 0; k < numVals * cz; k += cz, off += az) { + for(int h = 0; h < cz; h++) { + final int idb = colIndexes.get(h) * cut; double v = values[k + h]; if(v != 0) - for(int i = 0; i < aggregateColumns.size(); i++) + for(int i = 0; i < az; i++) ret[off + i] += v * b[idb + aggregateColumns.get(i)]; } } @@ -1801,10 +1823,14 @@ public IDictionary replace(double pattern, double replace, int nCol) { @Override public IDictionary 
replaceWithReference(double pattern, double replace, double[] reference) { + if(Util.eq(pattern, Double.NaN)) + throw new NotImplementedException(); + final int nRow = _data.getNumRows(); final int nCol = _data.getNumColumns(); final MatrixBlock ret = new MatrixBlock(nRow, nCol, false); ret.allocateDenseBlock(); + final double[] retV = ret.getDenseBlockValues(); int off = 0; if(_data.isInSparseFormat()) { @@ -2030,6 +2056,15 @@ public void MMDict(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, M right.MMDictDense(_data.getDenseBlockValues(), rowsLeft, colsRight, result); } + @Override + public void MMDictScaling(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, + int[] scaling) { + if(_data.isInSparseFormat()) + right.MMDictScalingSparse(_data.getSparseBlock(), rowsLeft, colsRight, result, scaling); + else + right.MMDictScalingDense(_data.getDenseBlockValues(), rowsLeft, colsRight, result, scaling); + } + @Override public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { if(_data.isInSparseFormat()) @@ -2038,6 +2073,15 @@ public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, DictLibMatrixMult.MMDictsDenseDense(left, _data.getDenseBlockValues(), rowsLeft, colsRight, result); } + @Override + public void MMDictScalingDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, + int[] scaling) { + if(_data.isInSparseFormat()) + DictLibMatrixMult.MMDictsScalingDenseSparse(left, _data.getSparseBlock(), rowsLeft, colsRight, result, scaling); + else + DictLibMatrixMult.MMDictsScalingDenseDense(left, _data.getDenseBlockValues(), rowsLeft, colsRight, result,scaling); + } + @Override public void MMDictSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { @@ -2047,6 +2091,15 @@ public void MMDictSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRig DictLibMatrixMult.MMDictsSparseDense(left, 
_data.getDenseBlockValues(), rowsLeft, colsRight, result); } + @Override + public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, + int[] scaling) { + if(_data.isInSparseFormat()) + DictLibMatrixMult.MMDictsScalingSparseSparse(left, _data.getSparseBlock(), rowsLeft, colsRight, result, scaling); + else + DictLibMatrixMult.MMDictsScalingSparseDense(left, _data.getDenseBlockValues(), rowsLeft, colsRight, result, scaling); + } + @Override public void TSMMToUpperTriangle(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result) { if(_data.isInSparseFormat()) diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java index 94fa9ef5289..51c41ffeec6 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java @@ -507,4 +507,22 @@ public IDictionary clone() { return new PlaceHolderDict(nVal); } + @Override + public void MMDictScaling(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, + int[] scaling) { + throw new RuntimeException(errMessage); + } + + @Override + public void MMDictScalingDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, + int[] scaling) { + throw new RuntimeException(errMessage); + } + + @Override + public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, + int[] scaling) { + throw new RuntimeException(errMessage); + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java index b55a291ae33..ae833dd7a9f 100644 --- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java @@ -613,4 +613,22 @@ public IDictionary cbind(IDictionary that, int nCol) { public IDictionary reorder(int[] reorder) { throw new NotImplementedException(); } + + @Override + public void MMDictScaling(IDictionary right, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, + int[] scaling) { + throw new NotImplementedException(); + } + + @Override + public void MMDictScalingDense(double[] left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, + int[] scaling) { + throw new NotImplementedException(); + } + + @Override + public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, IColIndex colsRight, MatrixBlock result, + int[] scaling) { + throw new NotImplementedException(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java b/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java index 66f07ab6ada..6b04cf6d71d 100644 --- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java +++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java @@ -19,10 +19,16 @@ package org.apache.sysds.runtime.data; +import java.util.Arrays; + +import org.apache.commons.lang.NotImplementedException; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -public abstract class SparseBlockFactory -{ +public abstract class SparseBlockFactory{ + protected static final Log LOG = LogFactory.getLog(SparseBlockFactory.class.getName()); + public static SparseBlock createSparseBlock(int rlen) { return createSparseBlock(MatrixBlock.DEFAULT_SPARSEBLOCK, rlen); @@ -117,4 +123,39 @@ public static SparseBlock createIdentityMatrixWithEmptyRow(int nRowCol){ rowPtr[nRowCol+1] = nRowCol; return new SparseBlockCSR(rowPtr, colIdx, vals, nnz); 
} + + /** + * Create a sparse block from an array. Note that the nnz count should be absolutely correct for this call to work. + * + * @param valsDense a double array of values linearized. + * @param nCol The number of columns in reach row. + * @param nnz The number of non zero values. + * @return A sparse block. + */ + public static SparseBlock createFromArray(final double[] valsDense, final int nCol, final int nnz) { + final int nRow = valsDense.length / nCol; + if(nnz > 0) { + + final int[] rowPtr = new int[nRow + 1]; + final int[] colIdx = new int[nnz]; + final double[] valsSparse = new double[nnz]; + int off = 0; + for(int i = 0; i < valsDense.length; i++) { + final int mod = i % nCol; + if(mod == 0) + rowPtr[i / nCol] = off; + if(valsDense[i] != 0) { + valsSparse[off] = valsDense[i]; + colIdx[off] = mod; + off++; + } + } + rowPtr[rowPtr.length -1] = off; + + return new SparseBlockCSR(rowPtr, colIdx, valsSparse, nnz); + } + else { + return new SparseBlockMCSR(nRow); // empty MCSR block + } + } } diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java index acda5eaf839..e090912e86b 100644 --- a/src/test/java/org/apache/sysds/test/TestUtils.java +++ b/src/test/java/org/apache/sysds/test/TestUtils.java @@ -2044,6 +2044,17 @@ public static double[][] generateTestMatrix(int rows, int cols, double min, doub return matrix; } + public static double[] generateTestVector(int cols, double min, double max, double sparsity, long seed) { + double[] vector = new double[cols]; + Random random = (seed == -1) ? TestUtils.random : new Random(seed); + for(int j = 0; j < cols; j++) { + if(random.nextDouble() > sparsity) + continue; + vector[j] = (random.nextDouble() * (max - min) + min); + } + return vector; + } + /** * * Generates a test matrix with the specified parameters as a MatrixBlock. 
diff --git a/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java b/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java index 91707565f3c..9307930f1d2 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java +++ b/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java @@ -21,21 +21,29 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.DMLCompressionException; -import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; +import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; +import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode; +import org.apache.sysds.runtime.functionobjects.Divide; +import org.apache.sysds.runtime.functionobjects.Minus; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.test.TestUtils; import org.junit.Test; import org.junit.runner.RunWith; @@ -72,6 +80,30 @@ public static Collection data() { addAll(tests, new double[] {1, 2, 3, 4, 5, 6}, 2); addAll(tests, new double[] {1, 2.2, 
3.3, 4.4, 5.5, 6.6}, 3); + tests.add(new Object[] {new IdentityDictionary(2), Dictionary.create(new double[] {1, 0, 0, 1}), 2, 2}); + tests.add(new Object[] {new IdentityDictionary(2, true), // + Dictionary.create(new double[] {1, 0, 0, 1, 0, 0}), 3, 2}); + tests.add(new Object[] {new IdentityDictionary(3), // + Dictionary.create(new double[] {1, 0, 0, 0, 1, 0, 0, 0, 1}), 3, 3}); + tests.add(new Object[] {new IdentityDictionary(3, true), // + Dictionary.create(new double[] {1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0}), 4, 3}); + + tests.add(new Object[] {new IdentityDictionary(4), // + Dictionary.create(new double[] {// + 1, 0, 0, 0, // + 0, 1, 0, 0, // + 0, 0, 1, 0, // + 0, 0, 0, 1,// + }), 4, 4}); + tests.add(new Object[] {new IdentityDictionary(4, true), // + Dictionary.create(new double[] {// + 1, 0, 0, 0, // + 0, 1, 0, 0, // + 0, 0, 1, 0, // + 0, 0, 0, 1, // + 0, 0, 0, 0}), + 5, 4}); + create(tests, 30, 300, 0.2); } catch(Exception e) { @@ -405,6 +437,170 @@ public void contains1WithReferenceMinus1() { containsValueWithReference(1.0, getReference(nCol, 3241, -1.0, -1.0)); } + @Test + public void equalsEl() { + assertEquals(a, b); + } + + @Test + public void opRightMinus() { + BinaryOperator op = new BinaryOperator(Minus.getMinusFnObject()); + double[] vals = TestUtils.generateTestVector(nCol, -1, 1, 1.0, 132L); + opRight(op, vals, ColIndexFactory.create(0, nCol)); + } + + @Test + public void opRightMinusNoCol() { + BinaryOperator op = new BinaryOperator(Minus.getMinusFnObject()); + double[] vals = TestUtils.generateTestVector(nCol, -1, 1, 1.0, 132L); + opRight(op, vals); + } + + @Test + public void opRightMinusZero() { + BinaryOperator op = new BinaryOperator(Minus.getMinusFnObject()); + double[] vals = new double[nCol]; + opRight(op, vals, ColIndexFactory.create(0, nCol)); + } + + @Test + public void opRightDivOne() { + BinaryOperator op = new BinaryOperator(Divide.getDivideFnObject()); + double[] vals = new double[nCol]; + Arrays.fill(vals, 1); + opRight(op, vals, 
ColIndexFactory.create(0, nCol)); + } + + @Test + public void opRightDiv() { + BinaryOperator op = new BinaryOperator(Divide.getDivideFnObject()); + double[] vals = TestUtils.generateTestVector(nCol, -1, 1, 1.0, 232L); + opRight(op, vals, ColIndexFactory.create(0, nCol)); + } + + private void opRight(BinaryOperator op, double[] vals, IColIndex cols) { + IDictionary aa = a.binOpRight(op, vals, cols); + IDictionary bb = b.binOpRight(op, vals, cols); + compare(aa, bb, nRow, nCol); + } + + private void opRight(BinaryOperator op, double[] vals) { + IDictionary aa = a.binOpRight(op, vals); + IDictionary bb = b.binOpRight(op, vals); + compare(aa, bb, nRow, nCol); + } + + @Test + public void testAddToEntry1() { + double[] ret1 = new double[nCol]; + a.addToEntry(ret1, 0, 0, nCol); + double[] ret2 = new double[nCol]; + b.addToEntry(ret2, 0, 0, nCol); + assertTrue(Arrays.equals(ret1, ret2)); + } + + @Test + public void testAddToEntry2() { + double[] ret1 = new double[nCol * 2]; + a.addToEntry(ret1, 0, 1, nCol); + double[] ret2 = new double[nCol * 2]; + b.addToEntry(ret2, 0, 1, nCol); + assertTrue(Arrays.equals(ret1, ret2)); + } + + @Test + public void testAddToEntry3() { + double[] ret1 = new double[nCol * 3]; + a.addToEntry(ret1, 0, 2, nCol); + double[] ret2 = new double[nCol * 3]; + b.addToEntry(ret2, 0, 2, nCol); + assertTrue(Arrays.equals(ret1, ret2)); + } + + @Test + public void testAddToEntry4() { + if(a.getNumberOfValues(nCol) > 2) { + + double[] ret1 = new double[nCol * 3]; + a.addToEntry(ret1, 2, 2, nCol); + double[] ret2 = new double[nCol * 3]; + b.addToEntry(ret2, 2, 2, nCol); + assertTrue(Arrays.equals(ret1, ret2)); + } + } + + @Test + public void testAddToEntryVectorized1() { + try { + double[] ret1 = new double[nCol * 3]; + a.addToEntryVectorized(ret1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 1, 2, 0, 1, nCol); + double[] ret2 = new double[nCol * 3]; + b.addToEntryVectorized(ret2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 1, 2, 0, 1, nCol); + assertTrue(Arrays.equals(ret1, 
ret2)); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testAddToEntryVectorized2() { + try { + + if(a.getNumberOfValues(nCol) > 1) { + double[] ret1 = new double[nCol * 3]; + a.addToEntryVectorized(ret1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 2, 0, 1, 2, 0, 1, nCol); + double[] ret2 = new double[nCol * 3]; + b.addToEntryVectorized(ret2, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 2, 0, 1, 2, 0, 1, nCol); + assertTrue("Error: " + a.getClass().getSimpleName() + " " + b.getClass().getSimpleName(), + Arrays.equals(ret1, ret2)); + } + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testAddToEntryVectorized3() { + try { + + if(a.getNumberOfValues(nCol) > 2) { + double[] ret1 = new double[nCol * 3]; + a.addToEntryVectorized(ret1, 1, 2, 1, 2, 1, 0, 1, 0, 0, 1, 2, 0, 1, 2, 0, 1, nCol); + double[] ret2 = new double[nCol * 3]; + b.addToEntryVectorized(ret2, 1, 2, 1, 2, 1, 0, 1, 0, 0, 1, 2, 0, 1, 2, 0, 1, nCol); + assertTrue("Error: " + a.getClass().getSimpleName() + " " + b.getClass().getSimpleName(), + Arrays.equals(ret1, ret2)); + } + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testAddToEntryVectorized4() { + try { + + if(a.getNumberOfValues(nCol) > 3) { + double[] ret1 = new double[nCol * 57]; + a.addToEntryVectorized(ret1, 3, 3, 0, 3, 0, 2, 0, 3, 20, 1, 12, 2, 10, 3, 6, 56, nCol); + double[] ret2 = new double[nCol * 57]; + b.addToEntryVectorized(ret2, 3, 3, 0, 3, 0, 2, 0, 3, 20, 1, 12, 2, 10, 3, 6, 56, nCol); + assertTrue("Error: " + a.getClass().getSimpleName() + " " + b.getClass().getSimpleName(), + Arrays.equals(ret1, ret2)); + } + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + public void containsValueWithReference(double value, double[] reference) { assertEquals(// a.containsValueWithReference(value, reference), // @@ -412,9 +608,37 @@ public void containsValueWithReference(double 
value, double[] reference) { } private static void compare(IDictionary a, IDictionary b, int nRow, int nCol) { - for(int i = 0; i < nRow; i++) - for(int j = 0; j < nCol; j++) - assertEquals(a.getValue(i, j, nCol), b.getValue(i, j, nCol), 0.0001); + try { + + String errorM = a.getClass().getSimpleName() + " " + b.getClass().getSimpleName(); + for(int i = 0; i < nRow; i++) + for(int j = 0; j < nCol; j++) + assertEquals(errorM, a.getValue(i, j, nCol), b.getValue(i, j, nCol), 0.0001); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void preaggValuesFromDense() { + try { + + final int nv = a.getNumberOfValues(nCol); + IColIndex idc = ColIndexFactory.create(0, nCol); + + double[] bv = TestUtils.generateTestVector(nCol * nCol, -1, 1, 1.0, 321521L); + + IDictionary aa = a.preaggValuesFromDense(nv, idc, idc, bv, nCol); + IDictionary bb = b.preaggValuesFromDense(nv, idc, idc, bv, nCol); + + compare(aa, bb, aa.getNumberOfValues(nCol), nCol); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } } public void productWithDefault(double retV, double[] def) { diff --git a/src/test/java/org/apache/sysds/test/component/matrix/SparseFactory.java b/src/test/java/org/apache/sysds/test/component/matrix/SparseFactory.java new file mode 100644 index 00000000000..6f80eb4f41a --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/matrix/SparseFactory.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.matrix; + +import static org.junit.Assert.assertEquals; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.data.SparseBlockFactory; +import org.junit.Test; + +public class SparseFactory { + protected static final Log LOG = LogFactory.getLog(SparseFactory.class.getName()); + + @Test + public void testCreateFromArray() { + double[] dense = new double[] {0, 0, 0, 1, 1, 1, 0, 0, 0}; + SparseBlock sb = SparseBlockFactory.createFromArray(dense, 3, 3); + + assertEquals(0, sb.get(0, 0), 0.0); + assertEquals(0, sb.get(1, 1), 1.0); + assertEquals(0, sb.get(2, 2), 0.0); + } +} From 3126e5f794ffc46ca66a61ebce28999fd952b09f Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 30 Oct 2023 14:49:01 +0100 Subject: [PATCH 21/28] [MINOR] Fix Empty Binary CLA Empty This commit fixes binary Matrix Vector/Matrix CLA operations to support empty sides in some edge case not supported yet, for instance <=. 
--- .../compress/lib/CLALibBinaryCellOp.java | 30 +++++++++++++++---- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java index 13e5e3c9381..ede9ca46aad 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java @@ -74,9 +74,14 @@ public static MatrixBlock binaryOperationsRight(BinaryOperator op, CompressedMat ScalarOperator sop = new RightScalarOperator(op.fn, that.getValue(0, 0), op.getNumThreads()); return CLALibScalar.scalarOperations(sop, m1, result); } - if(that.isEmpty()) + else if(that.isEmpty()) return binaryOperationsEmpty(op, m1, that, result); + else + return binaryOperationsRightFiltered(op, m1, that, result); + } + private static MatrixBlock binaryOperationsRightFiltered(BinaryOperator op, CompressedMatrixBlock m1, + MatrixBlock that, MatrixBlock result) { LibMatrixBincell.isValidDimensionsBinaryExtended(m1, that); BinaryAccessType atype = LibMatrixBincell.getBinaryAccessTypeExtended(m1, that); @@ -113,17 +118,16 @@ private static MatrixBlock binaryOperationsEmpty(BinaryOperator op, CompressedMa final ValueFunction fn = op.fn; if(fn instanceof Multiply) - result = CompressedMatrixBlockFactory.createConstant(m1Row, m1Col, 0); + return CompressedMatrixBlockFactory.createConstant(m1Row, m1Col, 0); else if(fn instanceof Minus1Multiply) - result = CompressedMatrixBlockFactory.createConstant(m1Row, m1Col, 1); + return CompressedMatrixBlockFactory.createConstant(m1Row, m1Col, 1); else if(fn instanceof Minus || fn instanceof Plus || fn instanceof MinusMultiply || fn instanceof PlusMultiply) { CompressedMatrixBlock ret = new CompressedMatrixBlock(); ret.copy(m1); return ret; } else - throw new NotImplementedException("Function Type: " + fn); - return result; + return 
binaryOperationsRightFiltered(op, m1, that, result); } private static MatrixBlock selectProcessingBasedOnAccessType(BinaryOperator op, CompressedMatrixBlock m1, @@ -612,8 +616,11 @@ private final void processLeftDense(final int rl, final int ru) { } private final void processRight(final int rl, final int ru) { + + if(_m2.isEmpty()) + processRightEmpty(rl, ru); // all exec should have ret on left side - if(_m2.isInSparseFormat()) + else if(_m2.isInSparseFormat()) processRightSparse(rl, ru); else processRightDense(rl, ru); @@ -662,6 +669,17 @@ private final void processRightDense(final int rl, final int ru) { retV[c] = _op.fn.execute(retV[c], m2V[c]); } } + + private final void processRightEmpty(final int rl, final int ru) { + final DenseBlock rv = _ret.getDenseBlock(); + final int cols = _ret.getNumColumns(); + for(int r = rl; r < ru; r++) { + final double[] retV = rv.values(r); + int off = rv.pos(r); + for(int c = off; c < cols + off; c++) + retV[c] = _op.fn.execute(retV[c], 0); + } + } } private static class BinaryMVColLeftTask implements Callable { From fb605775865d2ec0fbcc3aff81975576f8baa5e1 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 30 Oct 2023 15:05:17 +0100 Subject: [PATCH 22/28] [MINOR] Parallel Compressed LMM --- .../compress/lib/CLALibLeftMultBy.java | 96 +++++++++++++++++-- .../runtime/compress/lib/CLALibMMChain.java | 42 ++++++++ .../compress/lib/CLALibRightMultBy.java | 4 +- 3 files changed, 133 insertions(+), 9 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java index 6029a87d466..30c1109d3a4 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java @@ -32,11 +32,14 @@ import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.DMLRuntimeException; import 
org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.APreAgg; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Plus; +import org.apache.sysds.runtime.matrix.data.LibMatrixBincell; +import org.apache.sysds.runtime.matrix.data.LibMatrixMult; import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; @@ -45,7 +48,7 @@ public final class CLALibLeftMultBy { private static final Log LOG = LogFactory.getLog(CLALibLeftMultBy.class.getName()); - private CLALibLeftMultBy(){ + private CLALibLeftMultBy() { // private constructor } @@ -139,7 +142,15 @@ else if(!(ret.getNumColumns() == numColumnsOutput && ret.getNumRows() == numRows } private static MatrixBlock leftMultByCompressedTransposedMatrix(CompressedMatrixBlock right, - CompressedMatrixBlock left, MatrixBlock ret, int k) { + CompressedMatrixBlock left, final MatrixBlock ret, int k) { + if(k > 1 && ret.getInMemorySize() < 1000000) + return leftMultByCompressedTransposedMatrixParallel(right, left, ret, k); + else + return leftMultByCompressedTransposedMatrixSingleThread(right, left, ret); + } + + private static MatrixBlock leftMultByCompressedTransposedMatrixParallel(CompressedMatrixBlock right, + CompressedMatrixBlock left, final MatrixBlock ret, int k) { final int sd = right.getNumRows(); // shared dim final int cr = right.getNumColumns(); @@ -149,18 +160,88 @@ private static MatrixBlock leftMultByCompressedTransposedMatrix(CompressedMatrix final List leftCG = left.getColGroups(); final boolean containsRight = CLALibUtils.shouldPreFilter(rightCG); - double[] cR = containsRight ? 
new double[cr] : null; + final double[] cR = containsRight ? new double[cr] : null; final List fRight = CLALibUtils.filterGroups(rightCG, cR); final boolean containsLeft = CLALibUtils.shouldPreFilter(leftCG); - double[] cL = containsLeft ? new double[rl] : null; + final double[] cL = containsLeft ? new double[rl] : null; final List fLeft = CLALibUtils.filterGroups(leftCG, cL); + // Force dense output + ret.setNonZeros((long) ret.getNumRows() * ret.getNumColumns()); + ret.allocateDenseBlock(); + + final ExecutorService ex = CommonThreadPool.get(k); + final List> t = new ArrayList<>(); + + for(int j = 0; j < fLeft.size(); j++) { + final int jj = j; + t.add(ex.submit(() -> { + MatrixBlock retT = new MatrixBlock(ret.getNumRows(), ret.getNumColumns(), false); + retT.allocateDenseBlock(); + for(int i = 0; i < fRight.size(); i++) { + fRight.get(i).leftMultByAColGroup(fLeft.get(jj), retT, sd); + } + retT.examSparsity(true); + return retT; + })); + } + + try { + final double[] retV = ret.getDenseBlockValues(); + if(containsLeft && containsRight) + // if both -- multiply the left and right vectors scaling by number of shared dim + outerProductWithScaling(cL, cR, sd, retV); + if(containsLeft) // if left -- multiply left with right sum + outerProduct(cL, CLALibUtils.getColSum(fRight, cr, sd), retV); + if(containsRight)// if right -- multiply right with left sum + outerProduct(CLALibUtils.getColSum(fLeft, rl, sd), cR, retV); + for(Future f : t) { + MatrixBlock mb = f.get(); + if(!mb.isEmpty()) { + if(mb.isInSparseFormat()) + LibMatrixBincell.bincellOpInPlaceRight(ret, mb, new BinaryOperator(Plus.getPlusFnObject())); + else if(mb.getDenseBlock().isContiguous()) + LibMatrixMult.vectAdd(mb.getDenseBlockValues(), retV, 0, 0, retV.length); + else + LibMatrixBincell.bincellOpInPlaceRight(ret, mb, new BinaryOperator(Plus.getPlusFnObject())); + } + } + ret.recomputeNonZeros(k); + } + catch(Exception e) { + throw new DMLCompressionException("Failed parallel Left Compressed Mult", e); + 
} + finally { + ex.shutdown(); + } + return ret; + } + + private static MatrixBlock leftMultByCompressedTransposedMatrixSingleThread(CompressedMatrixBlock right, + CompressedMatrixBlock left, final MatrixBlock ret) { + final int sd = right.getNumRows(); // shared dim + final int cr = right.getNumColumns(); + final int rl = left.getNumColumns(); + + final List rightCG = right.getColGroups(); + final List leftCG = left.getColGroups(); + + final boolean containsRight = CLALibUtils.shouldPreFilter(rightCG); + final double[] cR = containsRight ? new double[cr] : null; + final List fRight = CLALibUtils.filterGroups(rightCG, cR); + + final boolean containsLeft = CLALibUtils.shouldPreFilter(leftCG); + final double[] cL = containsLeft ? new double[rl] : null; + final List fLeft = CLALibUtils.filterGroups(leftCG, cL); + + // Force dense output + ret.setNonZeros((long) ret.getNumRows() * ret.getNumColumns()); + ret.allocateDenseBlock(); for(int j = 0; j < fLeft.size(); j++) for(int i = 0; i < fRight.size(); i++) fRight.get(i).leftMultByAColGroup(fLeft.get(j), ret, sd); - - double[] retV = ret.getDenseBlockValues(); + final double[] retV = ret.getDenseBlockValues(); if(containsLeft && containsRight) // if both -- multiply the left and right vectors scaling by number of shared dim outerProductWithScaling(cL, cR, sd, retV); @@ -169,7 +250,6 @@ private static MatrixBlock leftMultByCompressedTransposedMatrix(CompressedMatrix if(containsRight)// if right -- multiply right with left sum outerProduct(CLALibUtils.getColSum(fLeft, rl, sd), cR, retV); ret.recomputeNonZeros(); - return ret; } @@ -218,7 +298,7 @@ private static MatrixBlock LMM(List colGroups, MatrixBlock that, Matr LMMParallel(noPreAggGroups, preAggGroups, that, ret, null, overlapping, k); } - ret.recomputeNonZeros(); + ret.recomputeNonZeros(k); ret.examSparsity(); return ret; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java index bc164a5e91b..060c7368717 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java @@ -35,6 +35,21 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +/** + * Support compressed MM chain operation to fuse the following cases : + * + *

+ * XtXv == (t(X) %*% (X %*% v)) + *

+ * + *

+ * XtwXv == (t(X) %*% (w * (X %*% v))) + *

+ * + *

+ * XtXvy == (t(X) %*% ((X %*% v) - y)) + *

+ */ public final class CLALibMMChain { static final Log LOG = LogFactory.getLog(CLALibMMChain.class.getName()); @@ -42,6 +57,33 @@ private CLALibMMChain() { // private constructor } + /** + * Support compressed MM chain operation to fuse the following cases : + * + *

+ * XtXv == (t(X) %*% (X %*% v)) + *

+ * + *

+ * XtwXv == (t(X) %*% (w * (X %*% v))) + *

+ * + *

+ * XtXvy == (t(X) %*% ((X %*% v) - y)) + *

+ * + * Note the point of this optimization is that v and w always are vectors. This means in practice the all the compute + * is faster if the intermediates are exploited. + * + * + * @param x Is the X part of the chain optimized kernel + * @param v Is the mandatory v part of the chain + * @param w Is the optional w port of t the chain + * @param out The output to put the result into. Can also be returned and in some cases will not be used. + * @param ctype either XtwXv, XtXv or XtXvy + * @param k the parallelization degree + * @return The result either in the given output or a new allocation + */ public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock v, MatrixBlock w, MatrixBlock out, ChainType ctype, int k) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java index 39468b0cab8..2eef5f9f3f8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java @@ -243,7 +243,9 @@ private static boolean RMMParallel(List filteredGroups, MatrixBlock t catch(InterruptedException | ExecutionException e) { throw new DMLRuntimeException(e); } - pool.shutdown(); + finally{ + pool.shutdown(); + } return containsNull; } From 7136a6aa922867aba3b047962e3931c820a66fac Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 30 Oct 2023 15:07:12 +0100 Subject: [PATCH 23/28] [MINOR] Workload Analyzer Warn on unknown The AWARE workload analyzer previously errored out on operations that are unknown, now instead we write a warning, and assume all unknown operations are decompressing the output. 
--- .../compress/workload/WorkloadAnalyzer.java | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java index 68b60438fa9..a4c15b2b533 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java +++ b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java @@ -60,7 +60,6 @@ import org.apache.sysds.parser.StatementBlock; import org.apache.sysds.parser.WhileStatement; import org.apache.sysds.parser.WhileStatementBlock; -import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.workload.AWTreeNode.WTNodeType; import org.apache.sysds.utils.Explain; @@ -68,7 +67,7 @@ public class WorkloadAnalyzer { private static final Log LOG = LogFactory.getLog(WorkloadAnalyzer.class.getName()); // indicator for more aggressive compression of intermediates public static boolean ALLOW_INTERMEDIATE_CANDIDATES = false; - // avoid wtree construction for assumptionly already compressed intermediates + // avoid w-tree construction for already compressed intermediates // (due to conditional control flow this might miss compression opportunities) public static boolean PRUNE_COMPRESSED_INTERMEDIATES = true; @@ -96,6 +95,7 @@ public static Map getAllCandidateWorkloads(DMLProgram prog) { // construct workload tree for candidate WorkloadAnalyzer wa = new WorkloadAnalyzer(prog); WTreeRoot tree = wa.createWorkloadTree(cand); + map.put(cand.getHopID(), tree); allWAs.add(wa); } @@ -337,6 +337,7 @@ private void createWorkloadTree(Hop hop, DMLProgram prog, AWTreeNode parent, Set } private void createOp(Hop hop, AWTreeNode parent) { + if(hop.getDataType().isMatrix()) { Op o = null; if(HopRewriteUtils.isData(hop, OpOpData.PERSISTENTREAD, OpOpData.TRANSIENTREAD)) @@ -425,7 +426,11 @@ else 
if(HopRewriteUtils.isBinary(hop, OpOp2.RBIND)) { o.setOverlapping(); } else if(ol) { - treeLookup.get(in.get(0).getHopID()).setDecompressing(); + if(in.get(0) != null) { + Op oo = treeLookup.get(in.get(0).getHopID()); + if(oo != null) + oo.setDecompressing(); + } return; } else { @@ -500,16 +505,15 @@ else if(isCompressed(o2)) { setDecompressionOnAllInputs(hop, parent); } } - else if(hop instanceof ParameterizedBuiltinOp) { + else if(hop instanceof ParameterizedBuiltinOp || hop instanceof NaryOp) { setDecompressionOnAllInputs(hop, parent); return; } - else if(hop instanceof NaryOp){ + else { + LOG.warn("Unknown Hop:" + hop.getClass().getSimpleName() + "\n" + Explain.explain(hop)); setDecompressionOnAllInputs(hop, parent); return; } - else - throw new DMLCompressionException("Unknown Hop:" +hop.getClass().getSimpleName() +"\n" + Explain.explain(hop)); o = o != null ? o : new OpNormal(hop, RewriteCompressedReblock.satisfiesSizeConstraintsForCompression(hop)); treeLookup.put(hop.getHopID(), o); From 948737390683c2a7b11e3f79d2a0303da4c77738 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 30 Oct 2023 15:33:43 +0100 Subject: [PATCH 24/28] [MINOR] fix empty nnz Compressed LLM --- .../apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java index 30c1109d3a4..d0983d4ae06 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java @@ -168,8 +168,8 @@ private static MatrixBlock leftMultByCompressedTransposedMatrixParallel(Compress final List fLeft = CLALibUtils.filterGroups(leftCG, cL); // Force dense output - ret.setNonZeros((long) ret.getNumRows() * ret.getNumColumns()); ret.allocateDenseBlock(); + ret.setNonZeros((long) ret.getNumRows() * 
ret.getNumColumns()); final ExecutorService ex = CommonThreadPool.get(k); final List> t = new ArrayList<>(); @@ -196,6 +196,7 @@ private static MatrixBlock leftMultByCompressedTransposedMatrixParallel(Compress outerProduct(cL, CLALibUtils.getColSum(fRight, cr, sd), retV); if(containsRight)// if right -- multiply right with left sum outerProduct(CLALibUtils.getColSum(fLeft, rl, sd), cR, retV); + for(Future f : t) { MatrixBlock mb = f.get(); if(!mb.isEmpty()) { From 798a0df3fc179b3a4d7a903fd3755b23f52828c2 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Fri, 20 Oct 2023 17:38:34 +0200 Subject: [PATCH 25/28] [MINOR] Performance tests for compressed behavior Closes #1928 --- .../org/apache/sysds/performance/Main.java | 8 +- .../apache/sysds/performance/PerfUtil.java | 12 +-- .../apache/sysds/performance/TimingUtils.java | 2 + .../performance/compression/Serialize.java | 79 +++++++++++++++---- .../compression/TransformPerf.java | 14 ++-- .../sysds/performance/generators/Const.java | 2 +- .../performance/generators/ConstFrame.java | 64 +++++++-------- .../performance/generators/FrameFile.java | 76 +++++++++--------- .../generators/FrameTransformFile.java | 78 ++++++++---------- .../performance/generators/MatrixFile.java | 46 +++++------ .../performance/simple/DetectTypeArray.java | 38 ++++----- .../apache/sysds/performance/simple/NNZ.java | 48 +++++------ 12 files changed, 253 insertions(+), 214 deletions(-) diff --git a/src/test/java/org/apache/sysds/performance/Main.java b/src/test/java/org/apache/sysds/performance/Main.java index 4e8f566a302..185a43e2c34 100644 --- a/src/test/java/org/apache/sysds/performance/Main.java +++ b/src/test/java/org/apache/sysds/performance/Main.java @@ -132,8 +132,10 @@ private static void run11(String[] args, int id) throws Exception { double sparsity = Double.parseDouble(args[4]); int k = Integer.parseInt(args[5]); int n = Integer.parseInt(args[6]); - - Serialize s = new Serialize(n, new ConstMatrix(rows, cols, unique, 
sparsity), k); + //args[7] is id + Serialize s = (args.length == 9) ? // + new Serialize(n, new ConstMatrix(rows, cols, unique, sparsity), k) : // + new Serialize(n, new ConstMatrix(rows, cols, unique, sparsity), k, args[7], args[8]); if(id == -1) s.run(); @@ -179,7 +181,7 @@ private static void run15(String[] args) throws Exception { private static void run16(String[] args) { int len = Integer.parseInt(args[1]); - MatrixBlock mb = TestUtils.ceil(TestUtils.generateTestMatrixBlock(len, len, 0, 100, 0.01, len +1)); + MatrixBlock mb = TestUtils.ceil(TestUtils.generateTestMatrixBlock(len, len, 0, 100, 0.01, len + 1)); System.out.println(mb); } diff --git a/src/test/java/org/apache/sysds/performance/PerfUtil.java b/src/test/java/org/apache/sysds/performance/PerfUtil.java index f93b03bdb39..9115bf5878f 100644 --- a/src/test/java/org/apache/sysds/performance/PerfUtil.java +++ b/src/test/java/org/apache/sysds/performance/PerfUtil.java @@ -25,10 +25,10 @@ public interface PerfUtil { - public static String readSpec(String path) throws IOException { - InputStream in = new FileInputStream(path); - String spec = new String(in.readAllBytes()); - in.close(); - return spec; - } + public static String readSpec(String path) throws IOException { + InputStream in = new FileInputStream(path); + String spec = new String(in.readAllBytes()); + in.close(); + return spec; + } } diff --git a/src/test/java/org/apache/sysds/performance/TimingUtils.java b/src/test/java/org/apache/sysds/performance/TimingUtils.java index 11e2c1dca52..0faf01c9b02 100644 --- a/src/test/java/org/apache/sysds/performance/TimingUtils.java +++ b/src/test/java/org/apache/sysds/performance/TimingUtils.java @@ -21,6 +21,7 @@ import java.util.Arrays; +import org.apache.sysds.api.DMLScript; import org.apache.sysds.performance.generators.IGenerate; import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; @@ -93,6 +94,7 @@ public static double[] time(F f, F c, F b, int rep, IGenerate bq) throws Inte b.run(); 
while(bq.isEmpty()) Thread.sleep(bq.defaultWaitTime()); + DMLScript.SEED = i + 1000; time(f, times, i); c.run(); } diff --git a/src/test/java/org/apache/sysds/performance/compression/Serialize.java b/src/test/java/org/apache/sysds/performance/compression/Serialize.java index 12316874c11..802e7f3a7b1 100644 --- a/src/test/java/org/apache/sysds/performance/compression/Serialize.java +++ b/src/test/java/org/apache/sysds/performance/compression/Serialize.java @@ -38,9 +38,13 @@ import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory; import org.apache.sysds.runtime.compress.CompressionStatistics; import org.apache.sysds.runtime.compress.colgroup.scheme.CompressionScheme; +import org.apache.sysds.runtime.compress.io.ReaderCompressed; import org.apache.sysds.runtime.compress.io.WriterCompressed; import org.apache.sysds.runtime.compress.lib.CLALibScheme; +import org.apache.sysds.runtime.io.MatrixReader; import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.ReaderBinaryBlock; +import org.apache.sysds.runtime.io.ReaderBinaryBlockParallel; import org.apache.sysds.runtime.io.WriterBinaryBlock; import org.apache.sysds.runtime.io.WriterBinaryBlockParallel; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -49,26 +53,37 @@ public class Serialize extends APerfTest { private final String file; private final int k; + private final String codec; public Serialize(int N, IGenerate gen) { super(N, gen); - file = "tmp/perf-tmp.bin"; + file = "./tmp/perf-tmp.bin"; k = 1; + codec = "none"; } public Serialize(int N, IGenerate gen, int k) { super(N, gen); - file = "tmp/perf-tmp.bin"; + file = "./tmp/perf-tmp.bin"; this.k = k; + codec = "none"; } public Serialize(int N, IGenerate gen, int k, String file) { super(N, gen); this.file = file; this.k = k; + codec = "none"; } + public Serialize(int N, IGenerate gen, int k, String file, String codec) { + super(N, gen); + this.file = file == null ? 
"tmp/perf-tmp.bin" : file; + this.k = k; + this.codec = codec; + } + public void run() throws Exception, InterruptedException { CompressedMatrixBlock.debug = true; CompressedMatrixBlock.debug = false; @@ -81,27 +96,35 @@ public void run() throws Exception, InterruptedException { if(k == 1) { ConfigurationManager.getCompilerConfig().set(ConfigType.PARALLEL_CP_WRITE_BINARYFORMATS, false); } - + + ConfigurationManager.getDMLConfig().setTextValue(DMLConfig.IO_COMPRESSION_CODEC, codec); + System.out.println(ConfigurationManager.getDMLConfig().getTextValue(DMLConfig.IO_COMPRESSION_CODEC)); warmup(() -> sumTask(k), N); + // execute(() -> writeUncompressed(k), "Serialize"); // execute(() -> diskUncompressed(k), "CustomDisk"); - - execute(() -> standardIO(k), () -> setFileSize(), () -> cleanup(), "StandardDisk"); - // execute(() -> compressTask(k), "Compress Normal"); // execute(() -> writeCompressTask(k), "Compress Normal Serialize"); // execute(() -> diskCompressTask(k), "Compress Normal CustomDisk"); - - execute(() -> standardCompressedIO(k), () -> setFileSize(), () -> cleanup(), "Compress StandardIO"); - - final CompressionScheme sch2 = CLALibScheme.getScheme(getC()); // execute(() -> updateAndApplySchemeFused(sch2, k), "Update&Apply Scheme Fused"); // execute(() -> writeUpdateAndApplySchemeFused(sch2, k), "Update&Apply Scheme Fused Serialize"); // execute(() -> diskUpdateAndApplySchemeFused(sch2, k), "Update&Apply Scheme Fused Disk"); - + + execute(() -> standardIO(k), () -> setFileSize(), () -> cleanup(), "StandardDisk"); + execute(() -> standardCompressedIO(k), () -> setFileSize(), () -> cleanup(), "Compress StandardIO"); + final CompressionScheme sch2 = CLALibScheme.getScheme(getC()); execute(() -> standardCompressedIOUpdateAndApply(sch2, k), () -> setFileSize(), () -> cleanup(), "Update&Apply Standard IO"); + + // write the input file to disk. 
+ standardIO(k); + execute(() -> standardIORead(k), "StandardRead"); + cleanup(); + // write compressed input file to disk + standardCompressedIOUpdateAndApply(sch2, k); + // standardCompressedIO( k); + execute(() -> standardCompressedRead(k), "StandardCompressedRead"); } public void run(int i) throws Exception, InterruptedException { @@ -177,6 +200,19 @@ private void standardIO(int k) { } } + private void standardIORead(int k) { + try { + MatrixBlock mb = gen.take(); + MatrixReader r = (k == 1) ? new ReaderBinaryBlock(false) : new ReaderBinaryBlockParallel(false); + MatrixBlock mbr = r.readMatrixFromHDFS(file, mb.getNumRows(), mb.getNumColumns(), ConfigurationManager.getBlocksize(), mb.getNonZeros()); + + ret.add(new InOut(mb.getInMemorySize(),mbr.getInMemorySize())); + } + catch(Exception e) { + throw new RuntimeException(e); + } + } + private void compressTask(int k) { MatrixBlock mb = gen.take(); long in = mb.getInMemorySize(); @@ -213,6 +249,20 @@ private void standardCompressedIO(int k) { } } + private void standardCompressedRead(int k) { + try { + MatrixBlock mb = gen.take(); + ReaderCompressed r = new ReaderCompressed(k); + MatrixBlock mbr = r.readMatrixFromHDFS(file, mb.getNumRows(), mb.getNumColumns(), ConfigurationManager.getBlocksize(), mb.getNonZeros()); + + ret.add(new InOut(mb.getInMemorySize(),mbr.getInMemorySize())); + } + catch(Exception e) { + throw new RuntimeException(e); + } + } + + // private void standardCompressedIOPipelined(int k) { // try { // // MatrixWriter w = new WriterBinaryBlockParallel(1); @@ -443,15 +493,14 @@ private void cleanup() { else fd.delete(); } + } private boolean deleteDirectory(File directoryToBeDeleted) { File[] allContents = directoryToBeDeleted.listFiles(); - if(allContents != null) { - for(File file : allContents) { + if(allContents != null) + for(File file : allContents) deleteDirectory(file); - } - } return directoryToBeDeleted.delete(); } diff --git 
a/src/test/java/org/apache/sysds/performance/compression/TransformPerf.java b/src/test/java/org/apache/sysds/performance/compression/TransformPerf.java index 454d7c7cfee..7d85a30b794 100644 --- a/src/test/java/org/apache/sysds/performance/compression/TransformPerf.java +++ b/src/test/java/org/apache/sysds/performance/compression/TransformPerf.java @@ -50,14 +50,18 @@ public void run() throws Exception { System.out.println(this); CompressedMatrixBlock.debug = true; - // execute(() -> detectSchema(k), "Detect Schema"); - // execute(() -> detectAndApply(k), "Detect&Apply Frame Schema"); + System.out.println(String.format("Unknown mem size: %30d", gen.take().getInMemorySize())); - updateGen(); + execute(() -> detectSchema(k), "Detect Schema"); + execute(() -> detectAndApply(k), "Detect&Apply Frame Schema"); + execute(() -> transformEncode(k), "TransformEncode Def"); + execute(() -> transformEncodeCompressed(k), "TransformEncode Comp"); - // execute(() -> detectAndApply(k), "Detect&Apply Frame Schema Known"); + updateGen(); - // execute(() -> transformEncode(k), "TransformEncode Def"); + System.out.println(String.format("Known mem size: %30d", gen.take().getInMemorySize())); + System.out.println(gen.take().slice(0, 10)); + execute(() -> transformEncode(k), "TransformEncode Def"); execute(() -> transformEncodeCompressed(k), "TransformEncode Comp"); } diff --git a/src/test/java/org/apache/sysds/performance/generators/Const.java b/src/test/java/org/apache/sysds/performance/generators/Const.java index 2d3adc1aced..7cb4abd40d5 100644 --- a/src/test/java/org/apache/sysds/performance/generators/Const.java +++ b/src/test/java/org/apache/sysds/performance/generators/Const.java @@ -20,5 +20,5 @@ package org.apache.sysds.performance.generators; public interface Const extends IGenerate { - public void change(T t); + public void change(T t); } diff --git a/src/test/java/org/apache/sysds/performance/generators/ConstFrame.java 
b/src/test/java/org/apache/sysds/performance/generators/ConstFrame.java index 13f7392380a..075fd3efc90 100644 --- a/src/test/java/org/apache/sysds/performance/generators/ConstFrame.java +++ b/src/test/java/org/apache/sysds/performance/generators/ConstFrame.java @@ -25,43 +25,43 @@ public class ConstFrame implements Const { - protected FrameBlock fb; + protected FrameBlock fb; - public ConstFrame(FrameBlock fb) { - this.fb = fb; - } + public ConstFrame(FrameBlock fb) { + this.fb = fb; + } - @Override - public FrameBlock take() { - return fb; - } + @Override + public FrameBlock take() { + return fb; + } - @Override - public void generate(int N) throws InterruptedException { - // do nothing - } + @Override + public void generate(int N) throws InterruptedException { + // do nothing + } - @Override - public final boolean isEmpty() { - return false; - } + @Override + public final boolean isEmpty() { + return false; + } - @Override - public final int defaultWaitTime() { - return 0; - } + @Override + public final int defaultWaitTime() { + return 0; + } - @Override - public String toString() { - StringBuilder sb = new StringBuilder(); - sb.append(this.getClass().getSimpleName()); - sb.append(" Schema:"); - sb.append(Arrays.toString(fb.getSchema())); - return sb.toString(); - } + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(this.getClass().getSimpleName()); + sb.append(" Schema:"); + sb.append(Arrays.toString(fb.getSchema())); + return sb.toString(); + } - @Override - public void change(FrameBlock t) { - fb = t; - } + @Override + public void change(FrameBlock t) { + fb = t; + } } diff --git a/src/test/java/org/apache/sysds/performance/generators/FrameFile.java b/src/test/java/org/apache/sysds/performance/generators/FrameFile.java index d89a2589d76..e106e7c063f 100644 --- a/src/test/java/org/apache/sysds/performance/generators/FrameFile.java +++ b/src/test/java/org/apache/sysds/performance/generators/FrameFile.java @@ -31,50 
+31,46 @@ public class FrameFile extends ConstFrame { - final private String path; + final private String path; - private FrameFile(String path, FrameBlock fb) { - super(fb); - this.path = path; - System.out.println("First 10 rows:"); - System.out.println(fb.slice(0, 10)); - } + private FrameFile(String path, FrameBlock fb) { + super(fb); + this.path = path; + } - public static FrameFile create(String path) throws Exception { + public static FrameFile create(String path) throws Exception { - MetaDataAll mba = new MetaDataAll(path + ".mtd", false, true); - if(mba.mtdExists()) { - LOG.error(mba); + MetaDataAll mba = new MetaDataAll(path + ".mtd", false, true); + if(mba.mtdExists()) { + FileFormat f = FileFormat.valueOf(mba.getFormatTypeString().toUpperCase()); + ValueType[] schema = FrameObject.parseSchema(mba.getSchema()); + FileFormatProperties p = null; + if(f.equals(FileFormat.CSV)) { + p = new FileFormatPropertiesCSV(); + ((FileFormatPropertiesCSV) p).setHeader(mba.getHasHeader()); + ((FileFormatPropertiesCSV) p).setDelim(mba.getDelim()); + } + FrameReader r = FrameReaderFactory.createFrameReader(f, p); + FrameBlock fb = r.readFrameFromHDFS(path, schema, mba.getDim1(), mba.getDim2()); + return new FrameFile(path, fb); + } + else { + LOG.error("No Mtd file found.. please add one. 
Fallback to CSV reading with header"); + // we assume csv + FrameReader r = FrameReaderFactory.createFrameReader(FileFormat.CSV); + FrameBlock fb = r.readFrameFromHDFS(path, -1, -1); + return new FrameFile(path, fb); + } - // DataCharacteristics ds = mba.getDataCharacteristics(); - FileFormat f = FileFormat.valueOf(mba.getFormatTypeString().toUpperCase()); - ValueType[] schema = FrameObject.parseSchema(mba.getSchema()); - FileFormatProperties p = null; - if(f.equals(FileFormat.CSV)){ - p = new FileFormatPropertiesCSV(); - ((FileFormatPropertiesCSV)p).setHeader(mba.getHasHeader()); - } - FrameReader r = FrameReaderFactory.createFrameReader(f, p); - FrameBlock fb = r.readFrameFromHDFS(path, schema, mba.getDim1(), mba.getDim2()); - return new FrameFile(path, fb); - } - else { - LOG.error("No Mtd file found.. please add one. Fallback to CSV reading with header"); - // we assume csv - FrameReader r = FrameReaderFactory.createFrameReader(FileFormat.CSV); - FrameBlock fb = r.readFrameFromHDFS(path, -1, -1); - return new FrameFile(path, fb); - } + } - } - - @Override - public String toString() { - StringBuilder sb = new StringBuilder(); - sb.append(super.toString()); - sb.append(" From file: "); - sb.append(path); - return sb.toString(); - } + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(super.toString()); + sb.append(" From file: "); + sb.append(path); + return sb.toString(); + } } diff --git a/src/test/java/org/apache/sysds/performance/generators/FrameTransformFile.java b/src/test/java/org/apache/sysds/performance/generators/FrameTransformFile.java index 359cbd23815..8f8130f81cc 100644 --- a/src/test/java/org/apache/sysds/performance/generators/FrameTransformFile.java +++ b/src/test/java/org/apache/sysds/performance/generators/FrameTransformFile.java @@ -19,65 +19,51 @@ package org.apache.sysds.performance.generators; -import java.io.IOException; - -import org.apache.sysds.common.Types.FileFormat; -import 
org.apache.sysds.common.Types.ValueType; import org.apache.sysds.performance.PerfUtil; import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysds.runtime.frame.data.FrameBlock; -import org.apache.sysds.runtime.io.FileFormatPropertiesCSV; -import org.apache.sysds.runtime.io.FrameReader; -import org.apache.sysds.runtime.io.FrameReaderFactory; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.transform.encode.EncoderFactory; import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; public class FrameTransformFile extends ConstMatrix { - final private String path; - final private String specPath; - - private FrameTransformFile(String path, String specPath, MatrixBlock mb) throws IOException { - super(mb); - this.path = path; - this.specPath = specPath; - } - - // example: - // src/test/resources/datasets/titanic/tfspec.json - // src/test/resources/datasets/titanic/titanic.csv - public static FrameTransformFile create(String path, String specPath) throws IOException { - // read spec - final String spec = PerfUtil.readSpec(specPath); + final private String path; + final private String specPath; - // MetaDataAll mba = new MetaDataAll(path + ".mtd", false, true); - // DataCharacteristics ds = mba.getDataCharacteristics(); - // FileFormat f = FileFormat.valueOf(mba.getFormatTypeString().toUpperCase()); + private FrameTransformFile(String path, String specPath, MatrixBlock mb) throws Exception { + super(mb); + this.path = path; + this.specPath = specPath; + } - FileFormatPropertiesCSV csvP = new FileFormatPropertiesCSV(); - csvP.setHeader(true); - FrameReader r = FrameReaderFactory.createFrameReader(FileFormat.CSV, csvP); - FrameBlock fb = r.readFrameFromHDFS(path, new ValueType[] {ValueType.STRING}, -1, -1); + // example: + // src/test/resources/datasets/titanic/tfspec.json + // src/test/resources/datasets/titanic/titanic.csv + public static FrameTransformFile create(String 
path, String specPath) throws Exception { + // read spec + final String spec = PerfUtil.readSpec(specPath); + final FrameFile fg = FrameFile.create(path); - int k = InfrastructureAnalyzer.getLocalParallelism(); - FrameBlock sc = fb.detectSchema(k); - fb = fb.applySchema(sc, k); - MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, fb.getColumnNames(), fb.getNumColumns(), null); - MatrixBlock mb = encoder.encode(fb, k); + FrameBlock fb = fg.take(); + int k = InfrastructureAnalyzer.getLocalParallelism(); + FrameBlock sc = fb.detectSchema(k); + fb = fb.applySchema(sc, k); + MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, fb.getColumnNames(), fb.getNumColumns(), null); + MatrixBlock mb = encoder.encode(fb, k); - return new FrameTransformFile(path, specPath, mb); - } + return new FrameTransformFile(path, specPath, mb); + } - @Override - public String toString() { - StringBuilder sb = new StringBuilder(); - sb.append(this.getClass().getSimpleName()); - sb.append(" From file: "); - sb.append(path); - sb.append(" -- Transformed with: "); - sb.append(specPath); - return sb.toString(); - } + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(this.getClass().getSimpleName()); + sb.append(" From file: "); + sb.append(path); + sb.append(" -- Transformed with: "); + sb.append(specPath); + return sb.toString(); + } } diff --git a/src/test/java/org/apache/sysds/performance/generators/MatrixFile.java b/src/test/java/org/apache/sysds/performance/generators/MatrixFile.java index 0f85528ad23..c23b5f472a9 100644 --- a/src/test/java/org/apache/sysds/performance/generators/MatrixFile.java +++ b/src/test/java/org/apache/sysds/performance/generators/MatrixFile.java @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,31 +28,31 @@ public class MatrixFile extends ConstMatrix { - final private String path; + final private String path; - private MatrixFile(String path, MatrixBlock mb) { - super(mb); - this.path = path; - } + private MatrixFile(String path, MatrixBlock mb) { + super(mb); + this.path = path; + } - public static MatrixFile create(String path) throws Exception { + public static MatrixFile create(String path) throws Exception { - MetaDataAll mba = new MetaDataAll(path + ".mtd", false, true); - DataCharacteristics ds = mba.getDataCharacteristics(); - FileFormat f = FileFormat.valueOf(mba.getFormatTypeString().toUpperCase()); + MetaDataAll mba = new MetaDataAll(path + ".mtd", false, true); + DataCharacteristics ds = mba.getDataCharacteristics(); + FileFormat f = FileFormat.valueOf(mba.getFormatTypeString().toUpperCase()); - MatrixReader r = MatrixReaderFactory.createMatrixReader(f); - MatrixBlock mb = r.readMatrixFromHDFS(path, ds.getRows(), ds.getCols(), ds.getBlocksize(), ds.getNonZeros()); - return new MatrixFile(path, mb); - } + MatrixReader r = MatrixReaderFactory.createMatrixReader(f); + MatrixBlock mb = r.readMatrixFromHDFS(path, ds.getRows(), ds.getCols(), ds.getBlocksize(), ds.getNonZeros()); + return new MatrixFile(path, mb); + } - @Override - public String toString() { - StringBuilder sb = new StringBuilder(); - sb.append(this.getClass().getSimpleName()); - sb.append(" From file: "); - sb.append(path); - return sb.toString(); - } + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(this.getClass().getSimpleName()); + sb.append(" From file: "); + sb.append(path); + return sb.toString(); + } } diff --git 
a/src/test/java/org/apache/sysds/performance/simple/DetectTypeArray.java b/src/test/java/org/apache/sysds/performance/simple/DetectTypeArray.java index f9fdf1b9547..6acf92da242 100644 --- a/src/test/java/org/apache/sysds/performance/simple/DetectTypeArray.java +++ b/src/test/java/org/apache/sysds/performance/simple/DetectTypeArray.java @@ -27,30 +27,30 @@ public class DetectTypeArray { - public static void main(String[] args) { - Array a = ArrayFactory.create(generateRandomFloatString(1000, 134)); + public static void main(String[] args) { + Array a = ArrayFactory.create(generateRandomFloatString(1000, 134)); - Timing t = new Timing(); - t.start(); - int N = 10000; - for(int i = 0; i < N; i++) - a.analyzeValueType(); + Timing t = new Timing(); + t.start(); + int N = 10000; + for(int i = 0; i < N; i++) + a.analyzeValueType(); - System.out.println(t.stop() / N); + System.out.println(t.stop() / N); - } + } - public static String[] generateRandomFloatString(int size, int seed) { - Random r = new Random(seed); - String[] ret = new String[size]; - for(int i = 0; i < size; i++) { - int e = r.nextInt(999); - int a = r.nextInt(999); + public static String[] generateRandomFloatString(int size, int seed) { + Random r = new Random(seed); + String[] ret = new String[size]; + for(int i = 0; i < size; i++) { + int e = r.nextInt(999); + int a = r.nextInt(999); - ret[i] = String.format("%d.%03d", e, a); - } + ret[i] = String.format("%d.%03d", e, a); + } - return ret; - } + return ret; + } } diff --git a/src/test/java/org/apache/sysds/performance/simple/NNZ.java b/src/test/java/org/apache/sysds/performance/simple/NNZ.java index 8ed77aea979..57407080c37 100644 --- a/src/test/java/org/apache/sysds/performance/simple/NNZ.java +++ b/src/test/java/org/apache/sysds/performance/simple/NNZ.java @@ -24,32 +24,32 @@ import org.apache.sysds.test.TestUtils; public class NNZ { - public static void main(String[] args) { - MatrixBlock mb = TestUtils.generateTestMatrixBlock(10000, 1000, 0, 103, 
0.7, 421); - Timing t = new Timing(); - t.start(); - for(int i = 0; i < 1000; i++) { - mb.recomputeNonZeros(); - } - System.out.println("single: " + t.stop()/ 1000); - t.start(); - for(int i = 0; i < 1000; i++) { + public static void main(String[] args) { + MatrixBlock mb = TestUtils.generateTestMatrixBlock(10000, 1000, 0, 103, 0.7, 421); + Timing t = new Timing(); + t.start(); + for(int i = 0; i < 1000; i++) { + mb.recomputeNonZeros(); + } + System.out.println("single: " + t.stop() / 1000); + t.start(); + for(int i = 0; i < 1000; i++) { - mb.recomputeNonZeros(16); - } + mb.recomputeNonZeros(16); + } - System.out.println("par: " + t.stop()/ 1000); - t.start(); - for(int i = 0; i < 1000; i++) { - mb.recomputeNonZeros(); - } - System.out.println("single: " + t.stop()/ 1000); - t.start(); - for(int i = 0; i < 1000; i++) { + System.out.println("par: " + t.stop() / 1000); + t.start(); + for(int i = 0; i < 1000; i++) { + mb.recomputeNonZeros(); + } + System.out.println("single: " + t.stop() / 1000); + t.start(); + for(int i = 0; i < 1000; i++) { - mb.recomputeNonZeros(16); - } + mb.recomputeNonZeros(16); + } - System.out.println("par: " + t.stop()/ 1000); - } + System.out.println("par: " + t.stop() / 1000); + } } From 2b8de1629b935d0b75caf38e4295c706980f0ce7 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Thu, 26 Oct 2023 18:24:54 +0200 Subject: [PATCH 26/28] [MINOR] JIT optimize LibMatrixBinCell This commit move some of the code inside LibMatrixBincell around to encourage jit compilation of some methods. In specific folloing methods have been introduced. 
- safeBinaryMvSparseRowVector - fillZeroValuesEmpty - fillZeroValuesDense - fillZeroValuesSparse - safeBinaryMMDenseDenseDensePM_Vec (Plus Multiply kernel vectorized) - safeBinaryMMDenseDenseDensePM (Plus Multiply kernel small input) - safeBinaryMMDenseDenseDenseContiguous (This one makes a big difference) - safeBinaryMMDenseDenseDenseGeneric In specific the safeBinaryMMDenseDenseDenseContiguous, safeBinaryMMDenseDenseDensePMm and safeBinaryMMDenseDenseDensePM_Vec improve the performance by much. In LM_cg the performance: Stats output: +* 3.123 3000 (Before) +* 1.991 3000 (After) + 1.125 2021 (Before) + 0.703 2015 (After) This is training on Criteo 100k rows. --- .../runtime/matrix/data/LibMatrixBincell.java | 430 +++++++++++------- .../runtime/matrix/data/LibMatrixMult.java | 2 +- 2 files changed, 269 insertions(+), 163 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java index e53f09a7f43..e5ec7a00209 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java @@ -851,85 +851,93 @@ private static void safeBinaryMVSparseDenseRow(MatrixBlock m1, MatrixBlock m2, M private static void safeBinaryMVSparse(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op) { boolean isMultiply = (op.fn instanceof Multiply); boolean skipEmpty = (isMultiply || isSparseSafeDivide(op, m2)); - - int rlen = m1.rlen; - int clen = m1.clen; - SparseBlock a = m1.sparseBlock; BinaryAccessType atype = getBinaryAccessType(m1, m2); - - //early abort on skip and empty - if( skipEmpty && (m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) ) + + // early abort on skip and empty + if(skipEmpty && (m1.isEmptyBlock(false) || m2.isEmptyBlock(false))) return; // skip entire empty block - - //allocate once in order to prevent repeated reallocation - if( ret.sparse ) + + 
// allocate once in order to prevent repeated reallocation + if(ret.sparse) ret.allocateSparseRowsBlock(); - - if( atype == BinaryAccessType.MATRIX_COL_VECTOR ) - { - for( int i=0; i aix[apos]){ - apos++; - } - // for each point in the sparse range - for(; apos < alen && aix[apos] < len; apos++){ - if(!zeroIsZero){ - while(cpos < len && cpos < aix[apos]){ - ret.appendValue(rpos, cpos++, zero); - } - } - cpos = aix[apos]; - final double v = op.fn.execute(0, vals[apos]); - ret.appendValue(rpos, aix[apos], v); - // cpos++; - } - // process tail. + } + else { + // def + for(int k = cpos; k < len; k++) { + ret.appendValue(rpos, k, op.fn.execute(0, vals[k])); + } + } + } + + private static void fillZeroValuesSparse(BinaryOperator op, MatrixBlock m2, MatrixBlock ret, boolean skipEmpty, + int rpos, int cpos, int len) { + + final double zero = op.fn.execute(0.0, 0.0); + final boolean zeroIsZero = zero == 0.0; + final SparseBlock sb = m2.getSparseBlock(); + if(sb.isEmpty(0)) { + if(!zeroIsZero) { + while(cpos < len) + ret.appendValue(rpos, cpos++, zero); + } + } + else { + int apos = sb.pos(0); + final int alen = sb.size(0) + apos; + final int[] aix = sb.indexes(0); + final double[] vals = sb.values(0); + // skip aix pos until inside range of cpos and len + while(apos < alen && aix[apos] < len && cpos > aix[apos]) { + apos++; + } + // for each point in the sparse range + for(; apos < alen && aix[apos] < len; apos++) { if(!zeroIsZero) { - while(cpos < len) { + while(cpos < len && cpos < aix[apos]) { ret.appendValue(rpos, cpos++, zero); } } - } - } - else { - final DenseBlock db = m2.getDenseBlock(); - final double[] vals = db.values(0); - for(int k = cpos; k < len; k++){ - ret.appendValue(rpos, k, op.fn.execute(0, vals[k])); + cpos = aix[apos]; + final double v = op.fn.execute(0, vals[apos]); + ret.appendValue(rpos, aix[apos], v); + // cpos++; + } + // process tail. 
+ if(!zeroIsZero) { + while(cpos < len) { + ret.appendValue(rpos, cpos++, zero); + } } } } @@ -1313,40 +1347,86 @@ else if(op.fn instanceof Multiply) } private static long safeBinaryMMDenseDenseDense(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, - BinaryOperator op, int rl, int ru) - { - boolean isPM = m1.clen >= 512 & (op.fn instanceof PlusMultiply | op.fn instanceof MinusMultiply); - double cntPM = !isPM ? Double.NaN : (op.fn instanceof PlusMultiply ? - ((PlusMultiply)op.fn).getConstant() : -1d * ((MinusMultiply)op.fn).getConstant()); + BinaryOperator op, int rl, int ru){ + final int clen = m1.clen; + final boolean isPM = (op.fn instanceof PlusMultiply || op.fn instanceof MinusMultiply); //guard for postponed allocation in single-threaded exec - if( !ret.isAllocated() ) + if(!ret.isAllocated()) ret.allocateDenseBlock(); - DenseBlock da = m1.getDenseBlock(); - DenseBlock db = m2.getDenseBlock(); - DenseBlock dc = ret.getDenseBlock(); - ValueFunction fn = op.fn; - int clen = m1.clen; + final DenseBlock da = m1.getDenseBlock(); + final DenseBlock db = m2.getDenseBlock(); + final DenseBlock dc = ret.getDenseBlock(); - //compute dense-dense binary, maintain nnz on-the-fly + if(isPM && clen >= 64) + return safeBinaryMMDenseDenseDensePM_Vec(da, db, dc, op, rl, ru, clen); + else if(da.isContiguous() && db.isContiguous() && dc.isContiguous()) { + if(op.fn instanceof PlusMultiply) + return safeBinaryMMDenseDenseDensePM(da, db, dc, op, rl, ru, clen); + else + return safeBinaryMMDenseDenseDenseContiguous(da, db, dc, op, rl, ru, clen); + } + else + return safeBinaryMMDenseDenseDenseGeneric(da, db, dc, op, rl, ru, clen); + } + + private static final long safeBinaryMMDenseDenseDensePM_Vec(DenseBlock da, DenseBlock db, DenseBlock dc, BinaryOperator op, + int rl, int ru, int clen) { + final double cntPM = (op.fn instanceof PlusMultiply ? 
((PlusMultiply) op.fn).getConstant() : -1d * + ((MinusMultiply) op.fn).getConstant()); long lnnz = 0; - for(int i=rl; i Date: Tue, 31 Oct 2023 12:31:27 +0100 Subject: [PATCH 27/28] [SYSTEMDS-3636] Improved ultra-sparse tsmm right w/ sparse output This patch consolidates previous patches on ultra-sparse tsmm and significantly improves performance for ultra-sparse outputs. We now dispatch the concrete sparse kernel depending on the estimated number of non-zeros. The new ultra-sparse kernel uses sparse updates on the output (via binary search) instead of dense buffering. On the germany_osm (open street map) graph, this patch improved performance from >1h to 12s (SciPy takes ~60s) and thus yields very competitive performance. --- .../runtime/matrix/data/LibMatrixMult.java | 54 +++++++++++-------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java index 3fa91bb31e4..3df09cbc61d 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java @@ -437,9 +437,11 @@ public static void matrixMultTransposeSelf(MatrixBlock m1, MatrixBlock ret, bool //pre-processing ret.sparse = isSparseOutputTSMM(m1, leftTranspose); ret.allocateBlock(); - + MatrixBlock m1t = isSparseOutputTSMM(m1, leftTranspose, true) ? + LibMatrixReorg.transpose(m1) : null; + //core tsmm operation - matrixMultTransposeSelf(m1, ret, leftTranspose, 0, ret.rlen); + matrixMultTransposeSelf(m1, m1t, ret, leftTranspose, 0, ret.rlen); //post-processing if(copyToLowerTriangle){ @@ -477,7 +479,9 @@ public static void matrixMultTransposeSelf(MatrixBlock m1, MatrixBlock ret, bool //pre-processing (no need to check isThreadSafe) ret.sparse = isSparseOutputTSMM(m1, leftTranspose); ret.allocateBlock(); - + MatrixBlock m1t = isSparseOutputTSMM(m1, leftTranspose, true) ? 
+ LibMatrixReorg.transpose(m1) : null; + //core multi-threaded matrix mult computation ExecutorService pool = CommonThreadPool.get(k); try { @@ -485,7 +489,7 @@ public static void matrixMultTransposeSelf(MatrixBlock m1, MatrixBlock ret, bool //load balance via #tasks=4k due to triangular shape int blklen = (int)(Math.ceil((double)ret.rlen / (4 * k))); for(int i = 0; i < ret.rlen; i += blklen) - tasks.add(new MatrixMultTransposeTask(m1, ret, leftTranspose, i, Math.min(i+blklen, ret.rlen))); + tasks.add(new MatrixMultTransposeTask(m1, m1t, ret, leftTranspose, i, Math.min(i+blklen, ret.rlen))); for( Future rtask : pool.invokeAll(tasks) ) rtask.get(); } @@ -2235,9 +2239,13 @@ private static void matrixMultTransposeSelfDense( MatrixBlock m1, MatrixBlock re } } - private static void matrixMultTransposeSelf(MatrixBlock m1, MatrixBlock ret, boolean leftTranspose, int rl, int ru) { - if(m1.sparse && ret.sparse) - matrixMultTransposeSelfUltraSparse(m1, ret, leftTranspose, rl, ru); + private static void matrixMultTransposeSelf(MatrixBlock m1, MatrixBlock m1t, MatrixBlock ret, boolean leftTranspose, int rl, int ru) { + if(m1.sparse && ret.sparse) { + if( m1t == null ) + matrixMultTransposeSelfUltraSparse(m1, ret, leftTranspose, rl, ru); + else + matrixMultTransposeSelfUltraSparse2(m1, m1t, ret, leftTranspose, rl, ru); + } else if( m1.sparse ) matrixMultTransposeSelfSparse(m1, ret, leftTranspose, rl, ru); else @@ -2406,9 +2414,7 @@ private static void matrixMultTransposeSelfUltraSparse( MatrixBlock m1, MatrixBl } } - //alternative matrixMultTransposeSelfUltraSparse2 w/ IKJ iteration order and dense buffering - //(for moderately large graphs 4x improvement compared to above, but for large graphs slower -> non-conclusive) - @SuppressWarnings("unused") + //alternative matrixMultTransposeSelfUltraSparse2 w/ IKJ iteration order and sparse updates private static void matrixMultTransposeSelfUltraSparse2( MatrixBlock m1, MatrixBlock m1t, MatrixBlock ret, boolean leftTranspose, int rl, 
int ru ) { if( leftTranspose ) throw new DMLRuntimeException("Left tsmm with sparse output not supported"); @@ -2417,17 +2423,13 @@ private static void matrixMultTransposeSelfUltraSparse2( MatrixBlock m1, MatrixB SparseBlock a = m1.sparseBlock; SparseBlock b = m1t.sparseBlock; SparseBlock c = ret.sparseBlock; - int m = m1.rlen; - double[] tmp = new double[m]; - for(int i=rl; i { private final MatrixBlock _m1; + private final MatrixBlock _m1t; private final MatrixBlock _ret; private final boolean _left; private final int _rl; private final int _ru; - protected MatrixMultTransposeTask( MatrixBlock m1, MatrixBlock ret, boolean left, int rl, int ru ) + protected MatrixMultTransposeTask( MatrixBlock m1, MatrixBlock m1t, MatrixBlock ret, boolean left, int rl, int ru ) { _m1 = m1; + _m1t = m1t; _ret = ret; _left = left; _rl = rl; @@ -4593,7 +4601,7 @@ protected MatrixMultTransposeTask( MatrixBlock m1, MatrixBlock ret, boolean left @Override public Object call() { - matrixMultTransposeSelf(_m1, _ret, _left, _rl, _ru); + matrixMultTransposeSelf(_m1, _m1t, _ret, _left, _rl, _ru); return null; } } From d9e4f213fa41e6698a23b07d4faebcc5f32f1bc0 Mon Sep 17 00:00:00 2001 From: Matthias Boehm Date: Tue, 31 Oct 2023 15:22:37 +0100 Subject: [PATCH 28/28] [SYSTEMDS-3636] Ultra-sparse tsmm right w/ multi-threaded transpose Further improvement of ultra-sparse tsmm right: w/ multi-threaded transpose. On a scenario of 10 times G %*% t(G) with G being germany_osm, the runtime changes as follows OLD: matrix mult: 13.346214013s matrix mult: 5.498598342s matrix mult: 5.11548485s matrix mult: 5.573473983s matrix mult: 5.673529942s matrix mult: 6.08607291s matrix mult: 6.244303553s matrix mult: 6.422722927s matrix mult: 4.995632087s matrix mult: 9.085500786s SystemDS Statistics: Total elapsed time: 71.007 sec. Total compilation time: 0.792 sec. Total execution time: 70.215 sec. Cache hits (Mem/Li/WB/FS/HDFS): 11/0/0/0/1. Cache writes (Li/WB/FS/HDFS): 0/11/0/0. 
Cache times (ACQr/m, RLS, EXP): 4.696/0.000/9.415/0.000 sec. HOP DAGs recompiled (PRED, SB): 0/0. HOP DAGs recompile time: 0.000 sec. Total JIT compile time: 5.987 sec. Total JVM GC count: 1. Total JVM GC time: 0.059 sec. Heavy hitter instructions: 1 tsmm 68.039 10 2 != 1.577 1 3 uak+ 0.585 1 4 + 0.027 22 5 print 0.008 12 6 mvvar 0.001 31 7 createvar 0.001 12 8 rmvar 0.000 45 9 time 0.000 20 10 - 0.000 10 NEW matrix mult: 12.17142539s matrix mult: 5.063393773s matrix mult: 4.764698928s matrix mult: 4.771695393s matrix mult: 5.434539822s matrix mult: 4.640708695s matrix mult: 4.967180443s matrix mult: 5.156199379s matrix mult: 5.472330144s matrix mult: 5.310449401s SystemDS Statistics: Total elapsed time: 60.405 sec. Total compilation time: 0.880 sec. Total execution time: 59.525 sec. Cache hits (Mem/Li/WB/FS/HDFS): 11/0/0/0/1. Cache writes (Li/WB/FS/HDFS): 0/11/0/0. Cache times (ACQr/m, RLS, EXP): 4.223/0.000/8.626/0.000 sec. HOP DAGs recompiled (PRED, SB): 0/0. HOP DAGs recompile time: 0.000 sec. Total JIT compile time: 6.956 sec. Total JVM GC count: 1. Total JVM GC time: 0.062 sec. Heavy hitter instructions: 1 tsmm 57.751 10 2 != 1.296 1 3 uak+ 0.465 1 4 + 0.029 22 5 print 0.008 12 6 mvvar 0.001 31 7 createvar 0.001 12 8 rmvar 0.000 45 9 time 0.000 20 10 - 0.000 10 --- .../org/apache/sysds/runtime/matrix/data/LibMatrixMult.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java index 3df09cbc61d..98b3eaa1bba 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java @@ -480,7 +480,7 @@ public static void matrixMultTransposeSelf(MatrixBlock m1, MatrixBlock ret, bool ret.sparse = isSparseOutputTSMM(m1, leftTranspose); ret.allocateBlock(); MatrixBlock m1t = isSparseOutputTSMM(m1, leftTranspose, true) ? 
- LibMatrixReorg.transpose(m1) : null; + LibMatrixReorg.transpose(m1, k) : null; //core multi-threaded matrix mult computation ExecutorService pool = CommonThreadPool.get(k);