From cd205b51fd0f17fab50e813ab07486bb59d4242b Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 22 Oct 2024 01:28:56 +0200 Subject: [PATCH] call down --- .../runtime/frame/data/columns/Array.java | 2 +- .../frame/data/columns/HashIntegerArray.java | 23 +++++++++++ .../frame/data/columns/HashLongArray.java | 22 ++++++++++ .../frame/data/columns/OptionalArray.java | 40 ++++++++++++++++++- 4 files changed, 84 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java index 5a08c5bcf58..8f653ddad13 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java @@ -128,7 +128,7 @@ protected Map createRecodeMap(int estimate) { return map; } - private long addValRecodeMap(Map map, long id, int i) { + protected long addValRecodeMap(Map map, long id, int i) { T val = get(i); if(val != null) { Long v = map.putIfAbsent(val, id); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashIntegerArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashIntegerArray.java index 131036d2085..1b5a63e76b9 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashIntegerArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashIntegerArray.java @@ -23,10 +23,12 @@ import java.io.DataOutput; import java.io.IOException; import java.util.Arrays; +import java.util.Map; import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; import org.apache.sysds.runtime.matrix.data.Pair; import org.apache.sysds.runtime.util.UtilFunctions; @@ -435,6 +437,27 
@@ public boolean possiblyContainsNaN() { return false; } + @Override + protected long addValRecodeMap(Map map, long id, int i) { + Integer val = getInt(i); /* Integer key, matching the map.get(getInt(i)) lookups in setM below */ + Long v = map.putIfAbsent(val, id); + if(v == null) + id++; + return id; + } + + @Override + public void setM(Map map, AMapToData m, int i){ + m.set(i, map.get(getInt(i)).intValue() - 1); + } + + @Override + public void setM(Map map, int si, AMapToData m, int i) { + final Integer v = getInt(i); + m.set(i, map.get(v).intValue() - 1); + } + + @Override public String toString() { StringBuilder sb = new StringBuilder(_size * 5 + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java index 3c802d3267c..85f237309bc 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java @@ -23,10 +23,12 @@ import java.io.DataOutput; import java.io.IOException; import java.util.Arrays; +import java.util.Map; import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; import org.apache.sysds.runtime.matrix.data.Pair; import org.apache.sysds.runtime.util.UtilFunctions; @@ -432,6 +434,26 @@ public boolean possiblyContainsNaN() { return false; } + @Override + protected long addValRecodeMap(Map map, long id, int i) { + Long val = getLong(i); /* Long key, matching the map.get(getLong(i)) lookups in setM below */ + Long v = map.putIfAbsent(val, id); + if(v == null) + id++; + + return id; + } + + @Override + public void setM(Map map, AMapToData m, int i){ + m.set(i, map.get(getLong(i)).intValue() - 1); + } + + @Override + public void setM(Map map, int si, AMapToData m, int i) { + m.set(i, map.get(getLong(i)).intValue() - 1); + } + 
@Override public String toString() { StringBuilder sb = new StringBuilder(_size * 5 + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java index 02620f27fbf..eb14b20f1f9 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java @@ -27,6 +27,7 @@ import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; import org.apache.sysds.runtime.matrix.data.Pair; import org.apache.sysds.runtime.util.UtilFunctions; @@ -473,6 +474,7 @@ public boolean possiblyContainsNaN() { } @Override + @SuppressWarnings("unchecked") protected Map createRecodeMap(int estimate) { if(getValueType() == ValueType.BOOLEAN) { // shortcut for boolean arrays, since we only @@ -484,11 +486,46 @@ protected Map createRecodeMap(int estimate) { return map; } + else if(getValueType() == ValueType.HASH32){ /* HASH32: build the map via the wrapped HashIntegerArray; ids start at 1 */ + Map map = new HashMap<>(estimate); + HashIntegerArray b = (HashIntegerArray)_a; + long id = 1; + for(int i = 0; i < size(); i++){ + if(_n.get(i)) /* only rows where _n.get(i) is true are added to the map */ + id = b.addValRecodeMap(map, id, i); + } + return (Map)map; /* unchecked cast, covered by the @SuppressWarnings on this method */ + } + else if(getValueType() == ValueType.HASH64){ /* HASH64: same procedure via the wrapped HashLongArray */ + Map map = new HashMap<>(estimate); + HashLongArray b = (HashLongArray)_a; + long id = 1; + for(int i = 0; i < size(); i++){ + if(_n.get(i)) /* only rows where _n.get(i) is true are added to the map */ + id = b.addValRecodeMap(map, id, i); + } + return (Map)map; /* unchecked cast, covered by the @SuppressWarnings on this method */ + } else return super.createRecodeMap(estimate); } - private long addValRecodeMap(Map map, long id, int i) { + @Override + public void setM(Map map, AMapToData m, int i){ /* delegates directly to the wrapped array. NOTE(review): no _n check here, unlike the si overload - confirm callers only use this when _n.get(i) is true */ + _a.setM(map,m,i); + } + + @Override + public void setM(Map map, int si, AMapToData m, int i) { /* rows where _n.get(i) is false are written as si; all other rows delegate to the wrapped array */ + if(_n.get(i)) + _a.setM(map,si,m,i); + else + m.set(i, si); + } 
+ + + @Override + protected long addValRecodeMap(Map map, long id, int i) { T val = get(i); if(val != null) { Long v = map.putIfAbsent(val, id); @@ -498,7 +535,6 @@ private long addValRecodeMap(Map map, long id, int i) { return id; } - @Override public String toString() { StringBuilder sb = new StringBuilder(_size + 2);