From 07f753fba13fd1076872007edb1055919a082096 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 24 Sep 2024 16:17:52 +0200 Subject: [PATCH] [MINOR] Frame tests improvement 2 Add tests for Frame append and detect schema. --- .../frame/data/lib/FrameLibAppend.java | 10 +++++-- .../frame/data/lib/FrameLibDetectSchema.java | 21 +++++++------- .../test/component/frame/FrameCustomTest.java | 18 ++++++++++++ .../sysds/test/component/frame/FrameTest.java | 28 +++++++++++++++++++ 4 files changed, 64 insertions(+), 13 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java index 78177b1d2f0..60c61e8f4af 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java @@ -33,8 +33,12 @@ import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; public class FrameLibAppend { - protected static final Log LOG = LogFactory.getLog(FrameLibAppend.class.getName()); + + private FrameLibAppend(){ + // private constructor. + } + /** * Appends the given argument FrameBlock 'that' to this FrameBlock by creating a deep copy to prevent side effects. * For cbind, the frames are appended column-wise (same number of rows), while for rbind the frames are appended @@ -50,7 +54,7 @@ public static FrameBlock append(FrameBlock a, FrameBlock b, boolean cbind) { return ret; } - public static FrameBlock appendCbind(FrameBlock a, FrameBlock b) { + private static FrameBlock appendCbind(FrameBlock a, FrameBlock b) { final int nRow = a.getNumRows(); final int nRowB = b.getNumRows(); @@ -73,7 +77,7 @@ else if(b.getNumColumns() == 0) return new FrameBlock(_schema, _colnames, _colmeta, _coldata); } - public static FrameBlock appendRbind(FrameBlock a, FrameBlock b) { + private static FrameBlock appendRbind(FrameBlock a, FrameBlock b) { final int nCol = a.getNumColumns(); final int nColB = b.getNumColumns(); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java index 889fc853f91..2e8e2ba1060 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java @@ -22,7 +22,6 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.Callable; -import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; @@ -67,11 +66,16 @@ public static FrameBlock detectSchema(FrameBlock in, double sampleFraction, int } private FrameBlock apply() { - final int cols = in.getNumColumns(); - final FrameBlock fb = new FrameBlock(UtilFunctions.nCopies(cols, ValueType.STRING)); - String[] schemaInfo = (k == 1) ? singleThreadApply() : parallelApply(); - fb.appendRow(schemaInfo); - return fb; + try{ + final int cols = in.getNumColumns(); + final FrameBlock fb = new FrameBlock(UtilFunctions.nCopies(cols, ValueType.STRING)); + String[] schemaInfo = (k == 1) ? singleThreadApply() : parallelApply(); + fb.appendRow(schemaInfo); + return fb; + } + catch(Exception e){ + throw new DMLRuntimeException("Failed to detect schema", e); + } } private String[] singleThreadApply() { @@ -84,7 +88,7 @@ private String[] singleThreadApply() { return schemaInfo; } - private String[] parallelApply() { + private String[] parallelApply() throws Exception { final ExecutorService pool = CommonThreadPool.get(k); try { final int cols = in.getNumColumns(); @@ -99,9 +103,6 @@ private String[] parallelApply() { return schemaInfo; } - catch(ExecutionException | InterruptedException e) { - throw new DMLRuntimeException("Exception interrupted or exception thrown in detectSchema", e); - } finally{ pool.shutdown(); } diff --git a/src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java b/src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java index 5e482a53698..476ce57d3aa 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java +++ b/src/test/java/org/apache/sysds/test/component/frame/FrameCustomTest.java @@ -19,10 +19,16 @@ package org.apache.sysds.test.component.frame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.lib.FrameLibDetectSchema; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.DataConverter; import org.apache.sysds.test.TestUtils; @@ -61,4 +67,16 @@ public void castToFrame2() { assertTrue(f.getSchema()[0] == ValueType.FP64); } + + @Test + public void detectSchemaError(){ + FrameBlock f = TestUtils.generateRandomFrameBlock(10, 10, 23); + FrameBlock spy = spy(f); + when(spy.getColumn(anyInt())).thenThrow(new RuntimeException()); + + Exception e = assertThrows(DMLRuntimeException.class, () -> FrameLibDetectSchema.detectSchema(spy, 3)); + + assertTrue(e.getMessage().contains("Failed to detect schema")); + } + } diff --git a/src/test/java/org/apache/sysds/test/component/frame/FrameTest.java b/src/test/java/org/apache/sysds/test/component/frame/FrameTest.java index bf848983448..88f99339a1c 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/FrameTest.java +++ b/src/test/java/org/apache/sysds/test/component/frame/FrameTest.java @@ -173,6 +173,34 @@ public void cBindEmptyAfter() { f.append(b, true); } + @Test(expected = DMLRuntimeException.class) + public void cBindEmptyCols() { + // must have same number of rows. + FrameBlock b = new FrameBlock(); + b.append(f, false); + } + + @Test(expected = DMLRuntimeException.class) + public void cBindEmptyAfterCols() { + // must have same number of rows. + FrameBlock b = new FrameBlock(); + f.append(b, false); + } + + @Test + public void cBindEmptyR() { + // must have same number of rows. + FrameBlock b = new FrameBlock(new ValueType[0], f.getNumRows() ); + b.append(f, true); + } + + @Test + public void cBindEmptyAfterR() { + // must have same number of rows. + FrameBlock b = new FrameBlock(new ValueType[0], f.getNumRows() ); + f.append(b, true); + } + @Test public void cBindStringColAfter() { // must have same number of rows.