diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java index 50e327ba5a2..ba5f00d8cc0 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java @@ -168,7 +168,7 @@ protected static MapToChar readFields(DataInput in) throws IOException { final int length = in.readInt(); final char[] data = new char[length]; for(int i = 0; i < length; i++) - data[i] = in.readChar(); + data[i] = (char)in.readUnsignedShort(); return new MapToChar(unique, data); } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java index fd77370f47d..00afcec68f0 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java @@ -20,20 +20,27 @@ package org.apache.sysds.runtime.frame.data.columns; import java.lang.ref.SoftReference; +import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.Iterator; +import java.util.List; import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.io.Writable; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; import org.apache.sysds.runtime.compress.estim.sample.SampleEstimatorFactory; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; import org.apache.sysds.runtime.frame.data.compress.ArrayCompressionStatistics; import org.apache.sysds.runtime.matrix.data.Pair; +import org.apache.sysds.runtime.util.CommonThreadPool; /** * Generic, resizable native arrays for the internal representation of the columns in the FrameBlock. We use this custom @@ -119,12 +126,58 @@ public synchronized final Map getRecodeMap(int estimate) { * @return The recode map */ protected Map createRecodeMap(int estimate) { - final Map map = new HashMap<>((int)Math.min((long)estimate *2, size())); + // final Map map = new HashMap<>((int)Math.min((long)estimate *2, size())); + final int s = size(); + int k = OptimizerUtils.getTransformNumThreads(); + if(k <= 1 || s < 10000) + return createRecodeMap(estimate, 0, s); + else + return parallelCreateRecodeMap(estimate, s, k); + } + + private Map parallelCreateRecodeMap(int estimate, final int s, int k) { + final ExecutorService pool = CommonThreadPool.get(k); + try{ + + final int blk = Math.max(10000, (s + k) / k); + List>> tasks = new ArrayList<>(); + for(int i = 0; i < s; i+= blk){ + final int start = i; + final int end = Math.min(i + blk, s); + tasks.add(pool.submit(() -> createRecodeMap(estimate, start, end))); + } + final Map map = tasks.get(0).get(); + for(int i = 1; i < tasks.size(); i++){ + final Map map2 = tasks.get(i).get(); + mergeRecodeMaps(map, map2); + } + return map; + } + catch(Exception e){ + throw new RuntimeException(e); + } + finally{ + pool.shutdown(); + } + } + + private void mergeRecodeMaps( Map target, Map from){ + List fromEntriesOrdered = new ArrayList<>(Collections.nCopies(from.size(), null)); + for(Map.Entry e : from.entrySet()) + fromEntriesOrdered.set(e.getValue().intValue(), e.getKey()); + long id = target.size(); + for(T e : fromEntriesOrdered){ + Long v = target.putIfAbsent(e, id ); + if(v == null) + id++; + } + } + + private Map createRecodeMap(final int estimate, final int s, final int e) { + final Map map = new HashMap<>((int)Math.min((long)estimate *2, e-s)); long id = 1; - final int s = size(); - for(int i = 0; i < s; i++) + for(int i = s; i < e; i++) id = addValRecodeMap(map, id, i); - return map; } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java index ba0197a6aa0..ebdf6327654 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java @@ -524,10 +524,8 @@ public void setM(Map map, int si, AMapToData m, int i) { @Override protected long addValRecodeMap(Map map, long id, int i) { - - if(_n.get(i)) { + if(_n.get(i)) id = _a.addValRecodeMap(map, id, i); - } return id; }