From dfec8b9ab4376950ae5eaaccf986dbd1bb3eb78e Mon Sep 17 00:00:00 2001 From: Jianfeng Mao <4297243+jmao-denver@users.noreply.github.com> Date: Fri, 1 Nov 2024 14:59:32 -0600 Subject: [PATCH] feat: TableDataService API integration via Python (#6175) Fixes #6171 Co-authored-by: Nathaniel Bauernfeind Co-authored-by: Ryan Caudy --- .../table/impl/chunkboxer/ChunkBoxer.java | 128 ++- .../engine/table/impl/SourceTable.java | 5 +- .../table/impl/locations/TableLocation.java | 4 +- .../impl/locations/TableLocationProvider.java | 4 +- .../impl/TableLocationSubscriptionBuffer.java | 2 +- ...TableLocationUpdateSubscriptionBuffer.java | 2 +- .../regioned/RegionedColumnSourceManager.java | 21 +- .../impl/TestPartitionAwareSourceTable.java | 18 + .../TestRegionedColumnSourceManager.java | 17 + extensions/barrage/build.gradle | 1 + .../barrage/util/ArrowToTableConverter.java | 82 +- .../extensions/barrage/util/BarrageUtil.java | 26 + .../barrage/util/PythonTableDataService.java | 999 ++++++++++++++++++ .../experimental/table_data_service.py | 455 ++++++++ py/server/tests/test_table_data_service.py | 331 ++++++ .../server/arrow/ArrowFlightUtil.java | 2 +- 16 files changed, 2000 insertions(+), 97 deletions(-) create mode 100644 extensions/barrage/src/main/java/io/deephaven/extensions/barrage/util/PythonTableDataService.java create mode 100644 py/server/deephaven/experimental/table_data_service.py create mode 100644 py/server/tests/test_table_data_service.py diff --git a/engine/base/src/main/java/io/deephaven/engine/table/impl/chunkboxer/ChunkBoxer.java b/engine/base/src/main/java/io/deephaven/engine/table/impl/chunkboxer/ChunkBoxer.java index 0f295838f73..73c2799de7c 100644 --- a/engine/base/src/main/java/io/deephaven/engine/table/impl/chunkboxer/ChunkBoxer.java +++ b/engine/base/src/main/java/io/deephaven/engine/table/impl/chunkboxer/ChunkBoxer.java @@ -7,6 +7,7 @@ import io.deephaven.chunk.attributes.Values; import io.deephaven.engine.table.Context; import io.deephaven.util.type.TypeUtils; +import org.jetbrains.annotations.NotNull; /** * Convert an arbitrary chunk to a chunk of boxed objects. @@ -14,17 +15,54 @@ public class ChunkBoxer { /** - * Return a chunk that contains boxed Objects representing the primitive values in primitives. + * Return a chunk that contains boxed {@link Object Objects} representing the primitive values in {@code values}. */ public interface BoxerKernel extends Context { /** - * Convert all primitives to an object. + * Box all values into {@link Object Objects} if they are not already {@code Objects}. * - * @param primitives the primitives to convert + * @param values the values to box * - * @return a chunk containing primitives as an object + * @return a chunk containing values as {@code Objects} (not owned by the caller) */ - ObjectChunk box(Chunk primitives); + ObjectChunk box(Chunk values); + } + + /** + * Box the value at {@code offset} in {@code values}. + *

+ * Please use a {@link #getBoxer(ChunkType, int) ChunkBoxer} when boxing multiple values in order to amortize the + * cost of implementation lookup and avoid virtual dispatch. + * + * @param values The chunk containing the value to box + * @param offset The offset of the value to box + * @return The boxed value + * @param The type of the boxed value + */ + @SuppressWarnings("unchecked") + public static BOXED_TYPE boxedGet(@NotNull final Chunk values, int offset) { + final ChunkType type = values.getChunkType(); + switch (type) { + case Boolean: + return (BOXED_TYPE) Boolean.valueOf(values.asBooleanChunk().get(offset)); + case Char: + return (BOXED_TYPE) TypeUtils.box(values.asCharChunk().get(offset)); + case Byte: + return (BOXED_TYPE) TypeUtils.box(values.asByteChunk().get(offset)); + case Short: + return (BOXED_TYPE) TypeUtils.box(values.asShortChunk().get(offset)); + case Int: + return (BOXED_TYPE) TypeUtils.box(values.asIntChunk().get(offset)); + case Long: + return (BOXED_TYPE) TypeUtils.box(values.asLongChunk().get(offset)); + case Float: + return (BOXED_TYPE) TypeUtils.box(values.asFloatChunk().get(offset)); + case Double: + return (BOXED_TYPE) TypeUtils.box(values.asDoubleChunk().get(offset)); + case Object: + return (BOXED_TYPE) values.asObjectChunk().get(offset); + } + throw new IllegalArgumentException("Unknown type: " + type); } public static BoxerKernel getBoxer(ChunkType type, int capacity) { @@ -55,8 +93,8 @@ public static BoxerKernel getBoxer(ChunkType type, int capacity) { private static class ObjectBoxer implements BoxerKernel { @Override - public ObjectChunk box(Chunk primitives) { - return primitives.asObjectChunk(); + public ObjectChunk box(Chunk values) { + return values.asObjectChunk(); } } @@ -79,13 +117,13 @@ private static class BooleanBoxer extends BoxerCommon { } @Override - public ObjectChunk box(Chunk primitives) { - final BooleanChunk booleanChunk = primitives.asBooleanChunk(); - for (int ii = 0; ii < primitives.size(); ++ii) { + public ObjectChunk box(Chunk values) { + final BooleanChunk booleanChunk = values.asBooleanChunk(); + for (int ii = 0; ii < values.size(); ++ii) { // noinspection UnnecessaryBoxing objectChunk.set(ii, Boolean.valueOf(booleanChunk.get(ii))); } - objectChunk.setSize(primitives.size()); + objectChunk.setSize(values.size()); return objectChunk; } } @@ -96,12 +134,12 @@ private static class CharBoxer extends BoxerCommon { } @Override - public ObjectChunk box(Chunk primitives) { - final CharChunk charChunk = primitives.asCharChunk(); - for (int ii = 0; ii < primitives.size(); ++ii) { - objectChunk.set(ii, io.deephaven.util.type.TypeUtils.box(charChunk.get(ii))); + public ObjectChunk box(Chunk values) { + final CharChunk charChunk = values.asCharChunk(); + for (int ii = 0; ii < values.size(); ++ii) { + objectChunk.set(ii, TypeUtils.box(charChunk.get(ii))); } - objectChunk.setSize(primitives.size()); + objectChunk.setSize(values.size()); return objectChunk; } } @@ -112,12 +150,12 @@ private static class ByteBoxer extends BoxerCommon { } @Override - public ObjectChunk box(Chunk primitives) { - final ByteChunk byteChunk = primitives.asByteChunk(); - for (int ii = 0; ii < primitives.size(); ++ii) { - objectChunk.set(ii, io.deephaven.util.type.TypeUtils.box(byteChunk.get(ii))); + public ObjectChunk box(Chunk values) { + final ByteChunk byteChunk = values.asByteChunk(); + for (int ii = 0; ii < values.size(); ++ii) { + objectChunk.set(ii, TypeUtils.box(byteChunk.get(ii))); } - objectChunk.setSize(primitives.size()); + 
objectChunk.setSize(values.size()); return objectChunk; } } @@ -128,12 +166,12 @@ private static class ShortBoxer extends BoxerCommon { } @Override - public ObjectChunk box(Chunk primitives) { - final ShortChunk shortChunk = primitives.asShortChunk(); - for (int ii = 0; ii < primitives.size(); ++ii) { - objectChunk.set(ii, io.deephaven.util.type.TypeUtils.box(shortChunk.get(ii))); + public ObjectChunk box(Chunk values) { + final ShortChunk shortChunk = values.asShortChunk(); + for (int ii = 0; ii < values.size(); ++ii) { + objectChunk.set(ii, TypeUtils.box(shortChunk.get(ii))); } - objectChunk.setSize(primitives.size()); + objectChunk.setSize(values.size()); return objectChunk; } } @@ -144,12 +182,12 @@ private static class IntBoxer extends BoxerCommon { } @Override - public ObjectChunk box(Chunk primitives) { - final IntChunk intChunk = primitives.asIntChunk(); - for (int ii = 0; ii < primitives.size(); ++ii) { - objectChunk.set(ii, io.deephaven.util.type.TypeUtils.box(intChunk.get(ii))); + public ObjectChunk box(Chunk values) { + final IntChunk intChunk = values.asIntChunk(); + for (int ii = 0; ii < values.size(); ++ii) { + objectChunk.set(ii, TypeUtils.box(intChunk.get(ii))); } - objectChunk.setSize(primitives.size()); + objectChunk.setSize(values.size()); return objectChunk; } } @@ -160,12 +198,12 @@ private static class LongBoxer extends BoxerCommon { } @Override - public ObjectChunk box(Chunk primitives) { - final LongChunk longChunk = primitives.asLongChunk(); - for (int ii = 0; ii < primitives.size(); ++ii) { - objectChunk.set(ii, io.deephaven.util.type.TypeUtils.box(longChunk.get(ii))); + public ObjectChunk box(Chunk values) { + final LongChunk longChunk = values.asLongChunk(); + for (int ii = 0; ii < values.size(); ++ii) { + objectChunk.set(ii, TypeUtils.box(longChunk.get(ii))); } - objectChunk.setSize(primitives.size()); + objectChunk.setSize(values.size()); return objectChunk; } } @@ -176,12 +214,12 @@ private static class FloatBoxer extends BoxerCommon { } @Override - public ObjectChunk box(Chunk primitives) { - final FloatChunk floatChunk = primitives.asFloatChunk(); - for (int ii = 0; ii < primitives.size(); ++ii) { - objectChunk.set(ii, io.deephaven.util.type.TypeUtils.box(floatChunk.get(ii))); + public ObjectChunk box(Chunk values) { + final FloatChunk floatChunk = values.asFloatChunk(); + for (int ii = 0; ii < values.size(); ++ii) { + objectChunk.set(ii, TypeUtils.box(floatChunk.get(ii))); } - objectChunk.setSize(primitives.size()); + objectChunk.setSize(values.size()); return objectChunk; } } @@ -192,12 +230,12 @@ private static class DoubleBoxer extends BoxerCommon { } @Override - public ObjectChunk box(Chunk primitives) { - final DoubleChunk doubleChunk = primitives.asDoubleChunk(); - for (int ii = 0; ii < primitives.size(); ++ii) { + public ObjectChunk box(Chunk values) { + final DoubleChunk doubleChunk = values.asDoubleChunk(); + for (int ii = 0; ii < values.size(); ++ii) { objectChunk.set(ii, TypeUtils.box(doubleChunk.get(ii))); } - objectChunk.setSize(primitives.size()); + objectChunk.setSize(values.size()); return objectChunk; } } diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/SourceTable.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/SourceTable.java index 28d8f75822d..bab3e560b05 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/SourceTable.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/SourceTable.java @@ -71,7 +71,7 @@ public abstract class SourceTable> exte /** * The update 
source object for refreshing locations and location sizes. */ - private Runnable locationChangePoller; + private LocationChangePoller locationChangePoller; /** * Construct a new disk-backed table. @@ -336,6 +336,9 @@ protected void destroy() { if (updateSourceRegistrar != null) { if (locationChangePoller != null) { updateSourceRegistrar.removeSource(locationChangePoller); + // NB: we do not want to null out any locationChangePoller.locationBuffer here, as they may still be in + // use by a notification delivery running currently with this destroy. + locationChangePoller.locationBuffer.reset(); } } } diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/locations/TableLocation.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/locations/TableLocation.java index 29aba5a011c..63ab9e03d1b 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/locations/TableLocation.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/locations/TableLocation.java @@ -65,8 +65,8 @@ interface Listener extends BasicTableDataListener { * or 1 handleException callbacks during invocation and continuing after completion, on a thread determined by the * implementation. Don't hold a lock that prevents notification delivery while subscribing! *

- * This method only guarantees eventually consistent state. To force a state update, use run() after subscription - * completes. + * This method only guarantees eventually consistent state. To force a state update, use refresh() after + * subscription completes. * * @param listener A listener */ diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/locations/TableLocationProvider.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/locations/TableLocationProvider.java index cda4c542264..b6d7c94439e 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/locations/TableLocationProvider.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/locations/TableLocationProvider.java @@ -102,8 +102,8 @@ default void handleTableLocationKeysUpdate( * must not hold any lock that prevents notification delivery while subscribing. Callers must guard * against duplicate notifications. *

- * This method only guarantees eventually consistent state. To force a state update, use run() after subscription - * completes. + * This method only guarantees eventually consistent state. To force a state update, use refresh() after + * subscription completes. * * @param listener A listener. */ diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/locations/impl/TableLocationSubscriptionBuffer.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/locations/impl/TableLocationSubscriptionBuffer.java index e7a5fab9a59..3a59197a265 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/locations/impl/TableLocationSubscriptionBuffer.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/locations/impl/TableLocationSubscriptionBuffer.java @@ -81,7 +81,7 @@ public synchronized LocationUpdate processPending() { if (tableLocationProvider.supportsSubscriptions()) { tableLocationProvider.subscribe(this); } else { - // NB: Providers that don't support subscriptions don't tick - this single call to run is + // NB: Providers that don't support subscriptions don't tick - this single call to refresh is // sufficient. tableLocationProvider.refresh(); final Collection> tableLocationKeys = new ArrayList<>(); diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/locations/impl/TableLocationUpdateSubscriptionBuffer.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/locations/impl/TableLocationUpdateSubscriptionBuffer.java index 0422d3d703d..a0777bd362a 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/locations/impl/TableLocationUpdateSubscriptionBuffer.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/locations/impl/TableLocationUpdateSubscriptionBuffer.java @@ -40,7 +40,7 @@ public synchronized boolean processPending() { if (tableLocation.supportsSubscriptions()) { tableLocation.subscribe(this); } else { - // NB: Locations that don't support subscriptions don't tick - this single call to run is + // NB: Locations that don't support subscriptions don't tick - this single call to refresh is // sufficient. tableLocation.refresh(); handleUpdate(); diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/sources/regioned/RegionedColumnSourceManager.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/sources/regioned/RegionedColumnSourceManager.java index 14cace8f619..0a131275a7b 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/sources/regioned/RegionedColumnSourceManager.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/sources/regioned/RegionedColumnSourceManager.java @@ -180,7 +180,25 @@ public class RegionedColumnSourceManager implements ColumnSourceManager, Delegat : TableDefinition.inferFrom(columnSourceMap); if (isRefreshing) { - livenessNode = new LivenessArtifact() {}; + livenessNode = new LivenessArtifact() { + @Override + protected void destroy() { + super.destroy(); + // NB: we do not want to null out any subscriptionBuffers here, as they may still be in use by a + // notification delivery running currently with this destroy. We also do not want to clear the table + // location maps as these locations may still be useful for static tables. 
+ for (final EmptyTableLocationEntry entry : emptyTableLocations.values()) { + if (entry.subscriptionBuffer != null) { + entry.subscriptionBuffer.reset(); + } + } + for (final IncludedTableLocationEntry entry : includedTableLocations.values()) { + if (entry.subscriptionBuffer != null) { + entry.subscriptionBuffer.reset(); + } + } + } + }; } else { // This RCSM wil be managing table locations to prevent them from being de-scoped but will not otherwise // participate in the liveness management process. @@ -519,7 +537,6 @@ public final synchronized boolean isEmpty() { return sharedColumnSources; } - @Override public LivenessNode asLivenessNode() { return livenessNode; } diff --git a/engine/table/src/test/java/io/deephaven/engine/table/impl/TestPartitionAwareSourceTable.java b/engine/table/src/test/java/io/deephaven/engine/table/impl/TestPartitionAwareSourceTable.java index d3fc36d80b4..9446b49be8a 100644 --- a/engine/table/src/test/java/io/deephaven/engine/table/impl/TestPartitionAwareSourceTable.java +++ b/engine/table/src/test/java/io/deephaven/engine/table/impl/TestPartitionAwareSourceTable.java @@ -208,6 +208,7 @@ public void setUp() throws Exception { @Override public void tearDown() throws Exception { try { + allowLivenessRelease(); super.tearDown(); } finally { if (coalesced != null) { @@ -217,6 +218,22 @@ public void tearDown() throws Exception { } } + private void allowLivenessRelease() { + checking(new Expectations() { + { + allowing(locationProvider).supportsSubscriptions(); + allowing(locationProvider).unsubscribe(with(any(TableLocationProvider.Listener.class))); + will(returnValue(true)); + for (int li = 0; li < tableLocations.length; ++li) { + final TableLocation tableLocation = tableLocations[li]; + allowing(tableLocation).supportsSubscriptions(); + will(returnValue(true)); + allowing(tableLocation).unsubscribe(with(any(TableLocation.Listener.class))); + } + } + }); + } + private Map> getIncludedColumnsMap(final int... 
indices) { return IntStream.of(indices) .mapToObj(ci -> new Pair<>(TABLE_DEFINITION.getColumns().get(ci).getName(), columnSources[ci])) @@ -443,6 +460,7 @@ public Object invoke(Invocation invocation) { errorNotification.reset(); final ControlledUpdateGraph updateGraph = ExecutionContext.getContext().getUpdateGraph().cast(); updateGraph.runWithinUnitTestCycle(() -> { + allowLivenessRelease(); SUT.refresh(); updateGraph.markSourcesRefreshedForUnitTests(); }, false); diff --git a/engine/table/src/test/java/io/deephaven/engine/table/impl/sources/regioned/TestRegionedColumnSourceManager.java b/engine/table/src/test/java/io/deephaven/engine/table/impl/sources/regioned/TestRegionedColumnSourceManager.java index ca531a2ec2d..0aeb1b05a07 100644 --- a/engine/table/src/test/java/io/deephaven/engine/table/impl/sources/regioned/TestRegionedColumnSourceManager.java +++ b/engine/table/src/test/java/io/deephaven/engine/table/impl/sources/regioned/TestRegionedColumnSourceManager.java @@ -662,6 +662,23 @@ public void testRefreshing() { checkIndexes(); assertEquals(Arrays.asList(tableLocation0A, tableLocation1A, tableLocation0B, tableLocation1B), SUT.includedLocations()); + + // expect table locations to be cleaned up via LivenessScope release as the test exits + IntStream.range(0, tableLocations.length).forEachOrdered(li -> { + final TableLocation tl = tableLocations[li]; + checking(new Expectations() { + { + oneOf(tl).supportsSubscriptions(); + if (li % 2 == 0) { + // Even locations don't support subscriptions + will(returnValue(false)); + } else { + will(returnValue(true)); + oneOf(tl).unsubscribe(with(subscriptionBuffers[li])); + } + } + }); + }); } private static void maybePrintStackTrace(@NotNull final Exception e) { diff --git a/extensions/barrage/build.gradle b/extensions/barrage/build.gradle index 82db148d212..f57cdcadbeb 100644 --- a/extensions/barrage/build.gradle +++ b/extensions/barrage/build.gradle @@ -22,6 +22,7 @@ dependencies { implementation libs.arrow.vector implementation libs.arrow.format + implementation project(path: ':extensions-source-support') compileOnly project(':util-immutables') annotationProcessor libs.immutables.value diff --git a/extensions/barrage/src/main/java/io/deephaven/extensions/barrage/util/ArrowToTableConverter.java b/extensions/barrage/src/main/java/io/deephaven/extensions/barrage/util/ArrowToTableConverter.java index c57c2111a17..2c8388ad9d1 100644 --- a/extensions/barrage/src/main/java/io/deephaven/extensions/barrage/util/ArrowToTableConverter.java +++ b/extensions/barrage/src/main/java/io/deephaven/extensions/barrage/util/ArrowToTableConverter.java @@ -24,6 +24,7 @@ import org.apache.arrow.flatbuf.MessageHeader; import org.apache.arrow.flatbuf.RecordBatch; import org.apache.arrow.flatbuf.Schema; +import org.jetbrains.annotations.NotNull; import java.io.IOException; import java.nio.ByteBuffer; @@ -51,7 +52,7 @@ public class ArrowToTableConverter { private volatile boolean completed = false; - private static BarrageProtoUtil.MessageInfo parseArrowIpcMessage(final ByteBuffer bb) throws IOException { + public static BarrageProtoUtil.MessageInfo parseArrowIpcMessage(final ByteBuffer bb) { final BarrageProtoUtil.MessageInfo mi = new BarrageProtoUtil.MessageInfo(); bb.order(ByteOrder.LITTLE_ENDIAN); @@ -64,13 +65,45 @@ private static BarrageProtoUtil.MessageInfo parseArrowIpcMessage(final ByteBuffe final ByteBuffer bodyBB = bb.slice(); final ByteBufferInputStream bbis = new ByteBufferInputStream(bodyBB); final CodedInputStream decoder = 
CodedInputStream.newInstance(bbis); - // noinspection UnstableApiUsage mi.inputStream = new LittleEndianDataInputStream( new BarrageProtoUtil.ObjectInputStreamAdapter(decoder, bodyBB.remaining())); } return mi; } + public static Schema parseArrowSchema(final BarrageProtoUtil.MessageInfo mi) { + if (mi.header.headerType() != MessageHeader.Schema) { + throw new IllegalArgumentException("The input is not a valid Arrow Schema IPC message"); + } + + // The Schema instance (especially originated from Python) can't be assumed to be valid after the return + // of this method. Until https://github.com/jpy-consortium/jpy/issues/126 is resolved, we need to make a copy of + // the header to use after the return of this method. + ByteBuffer original = mi.header.getByteBuffer(); + ByteBuffer copy = ByteBuffer.allocate(original.remaining()).put(original).rewind(); + Schema schema = new Schema(); + Message.getRootAsMessage(copy).header(schema); + + return schema; + } + + public static PrimitiveIterator.OfLong extractBufferInfo(@NotNull final RecordBatch batch) { + final long[] bufferInfo = new long[batch.buffersLength()]; + for (int bi = 0; bi < batch.buffersLength(); ++bi) { + int offset = LongSizedDataStructure.intSize("BufferInfo", batch.buffers(bi).offset()); + int length = LongSizedDataStructure.intSize("BufferInfo", batch.buffers(bi).length()); + + if (bi < batch.buffersLength() - 1) { + final int nextOffset = + LongSizedDataStructure.intSize("BufferInfo", batch.buffers(bi + 1).offset()); + // our parsers handle overhanging buffers + length += Math.max(0, nextOffset - offset - length); + } + bufferInfo[bi] = length; + } + return Arrays.stream(bufferInfo).iterator(); + } + @ScriptApi public synchronized void setSchema(final ByteBuffer ipcMessage) { // The input ByteBuffer instance (especially originated from Python) can't be assumed to be valid after the @@ -79,11 +112,8 @@ public synchronized void setSchema(final ByteBuffer ipcMessage) { if (completed) { throw new IllegalStateException("Conversion is complete; cannot process additional messages"); } - final BarrageProtoUtil.MessageInfo mi = getMessageInfo(ipcMessage); - if (mi.header.headerType() != MessageHeader.Schema) { - throw new IllegalArgumentException("The input is not a valid Arrow Schema IPC message"); - } - parseSchema(mi.header); + final BarrageProtoUtil.MessageInfo mi = parseArrowIpcMessage(ipcMessage); + configureWithSchema(parseArrowSchema(mi)); } @ScriptApi @@ -108,7 +138,7 @@ public synchronized void addRecordBatch(final ByteBuffer ipcMessage) { throw new IllegalStateException("Arrow schema must be provided before record batches can be added"); } - final BarrageProtoUtil.MessageInfo mi = getMessageInfo(ipcMessage); + final BarrageProtoUtil.MessageInfo mi = parseArrowIpcMessage(ipcMessage); if (mi.header.headerType() != MessageHeader.RecordBatch) { throw new IllegalArgumentException("The input is not a valid Arrow RecordBatch IPC message"); } @@ -138,14 +168,7 @@ public synchronized void onCompleted() throws InterruptedException { completed = true; } - protected void parseSchema(final Message message) { - // The Schema instance (especially originated from Python) can't be assumed to be valid after the return - // of this method. Until https://github.com/jpy-consortium/jpy/issues/126 is resolved, we need to make a copy of - // the header to use after the return of this method. 
- ByteBuffer original = message.getByteBuffer(); - ByteBuffer copy = ByteBuffer.allocate(original.remaining()).put(original).rewind(); - Schema schema = new Schema(); - Message.getRootAsMessage(copy).header(schema); + protected void configureWithSchema(final Schema schema) { if (resultTable != null) { throw Exceptions.statusRuntimeException(Code.INVALID_ARGUMENT, "Schema evolution not supported"); } @@ -179,20 +202,7 @@ protected BarrageMessage createBarrageMessage(BarrageProtoUtil.MessageInfo mi, i new FlatBufferIteratorAdapter<>(batch.nodesLength(), i -> new ChunkInputStreamGenerator.FieldNodeInfo(batch.nodes(i))); - final long[] bufferInfo = new long[batch.buffersLength()]; - for (int bi = 0; bi < batch.buffersLength(); ++bi) { - int offset = LongSizedDataStructure.intSize("BufferInfo", batch.buffers(bi).offset()); - int length = LongSizedDataStructure.intSize("BufferInfo", batch.buffers(bi).length()); - - if (bi < batch.buffersLength() - 1) { - final int nextOffset = - LongSizedDataStructure.intSize("BufferInfo", batch.buffers(bi + 1).offset()); - // our parsers handle overhanging buffers - length += Math.max(0, nextOffset - offset - length); - } - bufferInfo[bi] = length; - } - final PrimitiveIterator.OfLong bufferInfoIter = Arrays.stream(bufferInfo).iterator(); + final PrimitiveIterator.OfLong bufferInfoIter = extractBufferInfo(batch); msg.rowsRemoved = RowSetFactory.empty(); msg.shifted = RowSetShiftData.EMPTY; @@ -221,16 +231,4 @@ protected BarrageMessage createBarrageMessage(BarrageProtoUtil.MessageInfo mi, i msg.length = numRowsAdded; return msg; } - - private BarrageProtoUtil.MessageInfo getMessageInfo(ByteBuffer ipcMessage) { - final BarrageProtoUtil.MessageInfo mi; - try { - mi = parseArrowIpcMessage(ipcMessage); - } catch (IOException unexpected) { - throw new UncheckedDeephavenException(unexpected); - } - return mi; - } - - } diff --git a/extensions/barrage/src/main/java/io/deephaven/extensions/barrage/util/BarrageUtil.java b/extensions/barrage/src/main/java/io/deephaven/extensions/barrage/util/BarrageUtil.java index 8c5abd669ee..3d83db05833 100755 --- a/extensions/barrage/src/main/java/io/deephaven/extensions/barrage/util/BarrageUtil.java +++ b/extensions/barrage/src/main/java/io/deephaven/extensions/barrage/util/BarrageUtil.java @@ -30,6 +30,8 @@ import io.deephaven.extensions.barrage.BarragePerformanceLog; import io.deephaven.extensions.barrage.BarrageSnapshotOptions; import io.deephaven.extensions.barrage.BarrageStreamGenerator; +import io.deephaven.extensions.barrage.chunk.ChunkReader; +import io.deephaven.extensions.barrage.chunk.DefaultChunkReadingFactory; import io.deephaven.extensions.barrage.chunk.vector.VectorExpansionKernel; import io.deephaven.internal.log.LoggerFactory; import io.deephaven.io.logger.Logger; @@ -41,6 +43,7 @@ import io.deephaven.engine.util.input.InputTableUpdater; import io.deephaven.chunk.ChunkType; import io.deephaven.proto.backplane.grpc.ExportedTableCreationResponse; +import io.deephaven.qst.column.Column; import io.deephaven.util.type.TypeUtils; import io.deephaven.vector.Vector; import io.grpc.stub.StreamObserver; @@ -73,6 +76,8 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static io.deephaven.extensions.barrage.chunk.ChunkReader.typeInfo; + public class BarrageUtil { public static final BarrageSnapshotOptions DEFAULT_SNAPSHOT_DESER_OPTIONS = BarrageSnapshotOptions.builder().build(); @@ -511,6 +516,27 @@ public Class[] computeWireComponentTypes() { return tableDef.getColumnStream() 
.map(ColumnDefinition::getComponentType).toArray(Class[]::new); } + + public ChunkReader[] computeChunkReaders( + @NotNull final ChunkReader.Factory chunkReaderFactory, + @NotNull final org.apache.arrow.flatbuf.Schema schema, + @NotNull final StreamReaderOptions barrageOptions) { + final ChunkReader[] readers = new ChunkReader[tableDef.numColumns()]; + + final List> columns = tableDef.getColumns(); + for (int ii = 0; ii < tableDef.numColumns(); ++ii) { + final ColumnDefinition columnDefinition = columns.get(ii); + final int factor = (conversionFactors == null) ? 1 : conversionFactors[ii]; + final ChunkReader.TypeInfo typeInfo = typeInfo( + ReinterpretUtils.maybeConvertToWritablePrimitiveChunkType(columnDefinition.getDataType()), + columnDefinition.getDataType(), + columnDefinition.getComponentType(), + schema.fields(ii)); + readers[ii] = DefaultChunkReadingFactory.INSTANCE.getReader(barrageOptions, factor, typeInfo); + } + + return readers; + } } private static void setConversionFactor( diff --git a/extensions/barrage/src/main/java/io/deephaven/extensions/barrage/util/PythonTableDataService.java b/extensions/barrage/src/main/java/io/deephaven/extensions/barrage/util/PythonTableDataService.java new file mode 100644 index 00000000000..14ecfdc1c09 --- /dev/null +++ b/extensions/barrage/src/main/java/io/deephaven/extensions/barrage/util/PythonTableDataService.java @@ -0,0 +1,999 @@ +// +// Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending +// +package io.deephaven.extensions.barrage.util; + +import io.deephaven.UncheckedDeephavenException; +import io.deephaven.api.SortColumn; +import io.deephaven.base.log.LogOutput; +import io.deephaven.chunk.Chunk; +import io.deephaven.chunk.WritableChunk; +import io.deephaven.chunk.attributes.Values; +import io.deephaven.configuration.Configuration; +import io.deephaven.engine.context.ExecutionContext; +import io.deephaven.engine.rowset.RowSetFactory; +import io.deephaven.engine.table.BasicDataIndex; +import io.deephaven.engine.table.ColumnDefinition; +import io.deephaven.engine.table.Table; +import io.deephaven.engine.table.TableDefinition; +import io.deephaven.engine.table.impl.PartitionAwareSourceTable; +import io.deephaven.engine.table.impl.TableUpdateMode; +import io.deephaven.engine.table.impl.chunkboxer.ChunkBoxer; +import io.deephaven.engine.table.impl.locations.*; +import io.deephaven.engine.table.impl.locations.impl.*; +import io.deephaven.engine.table.impl.sources.regioned.*; +import io.deephaven.extensions.barrage.chunk.ChunkInputStreamGenerator; +import io.deephaven.extensions.barrage.chunk.ChunkReader; +import io.deephaven.extensions.barrage.chunk.DefaultChunkReadingFactory; +import io.deephaven.generic.region.*; +import io.deephaven.io.log.impl.LogOutputStringImpl; +import io.deephaven.util.SafeCloseable; +import io.deephaven.util.annotations.ScriptApi; +import org.apache.arrow.flatbuf.MessageHeader; +import org.apache.arrow.flatbuf.RecordBatch; +import org.apache.arrow.flatbuf.Schema; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; +import org.jpy.PyObject; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.*; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.LongConsumer; +import java.util.stream.IntStream; + +import static 
io.deephaven.extensions.barrage.util.ArrowToTableConverter.parseArrowIpcMessage; + +@ScriptApi +public class PythonTableDataService extends AbstractTableDataService { + + private static final int DEFAULT_PAGE_SIZE = Configuration.getInstance() + .getIntegerForClassWithDefault(PythonTableDataService.class, "defaultPageSize", 1 << 16); + private static final long REGION_MASK = RegionedColumnSource.ROW_KEY_TO_SUB_REGION_ROW_INDEX_MASK; + + private final BackendAccessor backend; + private final ChunkReader.Factory chunkReaderFactory; + private final StreamReaderOptions streamReaderOptions; + private final int pageSize; + + @ScriptApi + public static PythonTableDataService create( + @NotNull final PyObject pyTableDataService, + @Nullable final ChunkReader.Factory chunkReaderFactory, + @Nullable final StreamReaderOptions streamReaderOptions, + final int pageSize) { + return new PythonTableDataService( + pyTableDataService, + chunkReaderFactory == null ? DefaultChunkReadingFactory.INSTANCE : chunkReaderFactory, + streamReaderOptions == null ? BarrageUtil.DEFAULT_SNAPSHOT_DESER_OPTIONS : streamReaderOptions, + pageSize <= 0 ? DEFAULT_PAGE_SIZE : pageSize); + } + + /** + * Construct a Deephaven {@link io.deephaven.engine.table.impl.locations.TableDataService TableDataService} wrapping + * the provided Python TableDataServiceBackend. + * + * @param pyTableDataService The Python TableDataService + * @param pageSize The page size to use for all regions + */ + private PythonTableDataService( + @NotNull final PyObject pyTableDataService, + @NotNull final ChunkReader.Factory chunkReaderFactory, + @NotNull final StreamReaderOptions streamReaderOptions, + final int pageSize) { + super("PythonTableDataService"); + this.backend = new BackendAccessor(pyTableDataService); + this.chunkReaderFactory = chunkReaderFactory; + this.streamReaderOptions = streamReaderOptions; + this.pageSize = pageSize <= 0 ? DEFAULT_PAGE_SIZE : pageSize; + } + + /** + * Get a Deephaven {@link Table} for the supplied {@link TableKey}. + * + * @param tableKey The table key + * @param live Whether the table should update as new data becomes available + * @return The {@link Table} + */ + @ScriptApi + public Table makeTable(@NotNull final TableKeyImpl tableKey, final boolean live) { + final TableLocationProviderImpl tableLocationProvider = + (TableLocationProviderImpl) getTableLocationProvider(tableKey); + return new PartitionAwareSourceTable( + tableLocationProvider.tableDefinition, + tableKey.toString(), + RegionedTableComponentFactoryImpl.INSTANCE, + tableLocationProvider, + live ? ExecutionContext.getContext().getUpdateGraph() : null); + } + + /** + * This Backend impl marries the Python TableDataService with the Deephaven TableDataService. By performing the + * object translation here, we can keep the Python TableDataService implementation simple and focused on the Python + * side of the implementation. + */ + private class BackendAccessor { + private final PyObject pyTableDataService; + + private BackendAccessor( + @NotNull final PyObject pyTableDataService) { + this.pyTableDataService = pyTableDataService; + } + + /** + * Get two schemas, the first for partitioning columns whose values will be derived from TableLocationKey and + * applied to all rows in the associated TableLocation, and the second specifying the table data to be read + * chunk-wise (in columnar fashion) from the TableLocations. 
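+         * <p>
+         * For example, the two returned schemas are consumed as follows by {@code TableLocationProviderImpl} further
+         * below (partitioning schema first, then the data schema):
+         * <pre>{@code
+         * final BarrageUtil.ConvertedArrowSchema[] schemas = backend.getTableSchema(tableKey);
+         * final TableDefinition partitioningDef = schemas[0].tableDef;
+         * final TableDefinition tableDataDef = schemas[1].tableDef;
+         * }</pre>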
+ * + * @param tableKey the table key + * @return the schemas + */ + public BarrageUtil.ConvertedArrowSchema[] getTableSchema( + @NotNull final TableKeyImpl tableKey) { + final AsyncState asyncState = new AsyncState<>(); + + final Consumer onRawSchemas = byteBuffers -> { + final BarrageUtil.ConvertedArrowSchema[] schemas = new BarrageUtil.ConvertedArrowSchema[2]; + + if (byteBuffers.length != schemas.length) { + asyncState.setError(new IllegalArgumentException(String.format( + "Provided too many IPC messages. Expected %d, received %d.", + schemas.length, byteBuffers.length))); + return; + } + + for (int ii = 0; ii < schemas.length; ++ii) { + try { + schemas[ii] = BarrageUtil.convertArrowSchema( + ArrowToTableConverter.parseArrowSchema( + ArrowToTableConverter.parseArrowIpcMessage( + byteBuffers[ii]))); + } catch (final Exception e) { + final String schemaType = ii % 2 == 0 ? "data table" : "partitioning column"; + asyncState.setError(new IllegalArgumentException(String.format( + "Failed to parse %s schema message", schemaType), e)); + return; + } + } + + asyncState.setResult(schemas); + }; + + final Consumer onFailure = errorString -> asyncState.setError( + new UncheckedDeephavenException(errorString)); + + pyTableDataService.call("_table_schema", tableKey.key, onRawSchemas, onFailure); + + return asyncState.awaitResult(err -> new TableDataException(String.format( + "%s: table_schema failed", tableKey), err)); + } + + /** + * Get the existing table locations for the provided {@code tableKey}. + * + * @param definition the table definition to validate partitioning columns against + * @param tableKey the table key + * @param listener the listener to call with each existing table location key + */ + public void getTableLocations( + @NotNull final TableDefinition definition, + @NotNull final TableKeyImpl tableKey, + @NotNull final Consumer listener) { + final AsyncState asyncState = new AsyncState<>(); + + final BiConsumer convertingListener = + (tableLocationKey, byteBuffers) -> { + try { + processTableLocationKey(definition, tableKey, listener, tableLocationKey, byteBuffers); + } catch (final RuntimeException e) { + asyncState.setError(e); + } + }; + + final Runnable onSuccess = () -> asyncState.setResult(true); + final Consumer onFailure = errorString -> asyncState.setError( + new UncheckedDeephavenException(errorString)); + + pyTableDataService.call("_table_locations", tableKey.key, convertingListener, onSuccess, onFailure); + asyncState.awaitResult(err -> new TableDataException(String.format( + "%s: table_locations failed", tableKey), err)); + } + + /** + * Subscribe to table location updates for the provided {@code tableKey}. + *

+ * The {@code tableLocationListener} should be invoked with all existing table locations. Any asynchronous calls + * to {@code tableLocationListener}, {@code successCallback}, or {@code failureCallback} will block until this + * method has completed. + * + * @param definition the table definition to validate partitioning columns against + * @param tableKey the table key + * @param tableLocationListener the tableLocationListener to call with each table location key + * @param successCallback the success callback; called when the subscription is successfully established and the + * tableLocationListener has been called with all existing table locations + * @param failureCallback the failure callback; called to deliver an exception triggered while activating or + * maintaining the underlying data source + * @return a {@link SafeCloseable} that can be used to cancel the subscription + */ + public SafeCloseable subscribeToTableLocations( + @NotNull final TableDefinition definition, + @NotNull final TableKeyImpl tableKey, + @NotNull final Consumer tableLocationListener, + @NotNull final Runnable successCallback, + @NotNull final Consumer failureCallback) { + final AtomicBoolean locationProcessingFailed = new AtomicBoolean(); + final AtomicReference cancellationCallbackRef = new AtomicReference<>(); + + final BiConsumer convertingListener = + (tableLocationKey, byteBuffers) -> { + if (locationProcessingFailed.get()) { + return; + } + try { + processTableLocationKey( + definition, tableKey, tableLocationListener, tableLocationKey, byteBuffers); + } catch (final RuntimeException e) { + failureCallback.accept(e); + // we must also cancel the python subscription + locationProcessingFailed.set(true); + final SafeCloseable localCancellationCallback = cancellationCallbackRef.get(); + if (localCancellationCallback != null) { + localCancellationCallback.close(); + } + } + }; + + final Consumer onFailure = errorString -> failureCallback.accept( + new UncheckedDeephavenException(errorString)); + + final PyObject cancellationCallback = pyTableDataService.call( + "_subscribe_to_table_locations", tableKey.key, + convertingListener, successCallback, onFailure); + final SafeCloseable cancellationCallbackOnce = new SafeCloseable() { + + private final AtomicBoolean invoked = new AtomicBoolean(); + + @Override + public void close() { + if (invoked.compareAndSet(false, true)) { + cancellationCallback.call("__call__"); + cancellationCallbackRef.set(null); + } + } + }; + cancellationCallbackRef.set(cancellationCallbackOnce); + if (locationProcessingFailed.get()) { + cancellationCallbackOnce.close(); + } + return cancellationCallbackOnce; + } + + private void processTableLocationKey( + @NotNull final TableDefinition definition, + @NotNull final TableKeyImpl tableKey, + @NotNull final Consumer listener, + @NotNull final TableLocationKeyImpl tableLocationKey, + @NotNull final ByteBuffer[] byteBuffers) { + if (byteBuffers.length == 0) { + if (!definition.getPartitioningColumns().isEmpty()) { + throw new IllegalArgumentException(String.format("%s:%s: table_location_key callback expected " + + "partitioned column values but none were provided", tableKey, tableLocationKey)); + } + listener.accept(tableLocationKey); + return; + } + + if (byteBuffers.length != 2) { + throw new IllegalArgumentException(String.format("%s:%s: table_location_key callback expected 2 IPC " + + "messages describing the wire format of the partitioning columns followed by partitioning " + + "values, but received %d messages", tableKey, 
tableLocationKey, byteBuffers.length)); + } + + // partitioning values must be in the same order as the partitioning keys, so we'll prepare an ordered map + // with null values for each key so that we may fill them in out of order + final Map> partitioningValues = new LinkedHashMap<>( + definition.getPartitioningColumns().size()); + definition.getPartitioningColumns().forEach(column -> partitioningValues.put(column.getName(), null)); + + // note that we recompute chunk readers for each location to support some schema evolution + final Schema partitioningValuesSchema = ArrowToTableConverter.parseArrowSchema( + ArrowToTableConverter.parseArrowIpcMessage(byteBuffers[0])); + final BarrageUtil.ConvertedArrowSchema schemaPlus = + BarrageUtil.convertArrowSchema(partitioningValuesSchema); + + try { + definition.checkCompatibility(schemaPlus.tableDef); + } catch (TableDefinition.IncompatibleTableDefinitionException err) { + throw new IllegalArgumentException(String.format("%s:%s: table_location_key callback received " + + "partitioning schema that is incompatible with table definition", tableKey, tableLocationKey), + err); + } + + final ChunkReader[] readers = schemaPlus.computeChunkReaders( + chunkReaderFactory, + partitioningValuesSchema, + streamReaderOptions); + + final BarrageProtoUtil.MessageInfo recordBatchMessageInfo = parseArrowIpcMessage(byteBuffers[1]); + if (recordBatchMessageInfo.header.headerType() != MessageHeader.RecordBatch) { + throw new IllegalArgumentException(String.format("%s:%s: table_location_key callback received 2nd IPC " + + "message that is not a valid Arrow RecordBatch", tableKey, tableLocationKey)); + } + final RecordBatch batch = (RecordBatch) recordBatchMessageInfo.header.header(new RecordBatch()); + + final Iterator fieldNodeIter = + new FlatBufferIteratorAdapter<>(batch.nodesLength(), + i -> new ChunkInputStreamGenerator.FieldNodeInfo(batch.nodes(i))); + + final PrimitiveIterator.OfLong bufferInfoIter = ArrowToTableConverter.extractBufferInfo(batch); + + // extract partitioning values and box them to be used as Comparable in the map + for (int ci = 0; ci < partitioningValuesSchema.fieldsLength(); ++ci) { + final String columnName = partitioningValuesSchema.fields(ci).name(); + try (final WritableChunk columnValues = readers[ci].readChunk( + fieldNodeIter, bufferInfoIter, recordBatchMessageInfo.inputStream, null, 0, 0)) { + + if (columnValues.size() != 1) { + throw new IllegalArgumentException(String.format("%s:%s: table_location_key callback received " + + "%d rows for partitioning column %s; expected 1", tableKey, tableLocationKey, + columnValues.size(), columnName)); + } + + partitioningValues.put(columnName, ChunkBoxer.boxedGet(columnValues, 0)); + } catch (final IOException ioe) { + throw new UncheckedDeephavenException(String.format( + "%s:%s: table_location_key callback failed to read partitioning column %s", tableKey, + tableLocationKey, columnName), ioe); + } + } + + listener.accept(new TableLocationKeyImpl(tableLocationKey.locationKey, partitioningValues)); + } + + /** + * Get the size of the given {@code tableLocationKey}. 
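+         * <p>
+         * Note that the {@code listener} is invoked synchronously, before this method returns, once the asynchronous
+         * backend call has completed.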
+ * + * @param tableKey the table key + * @param tableLocationKey the table location key + * @param listener the listener to call with the table location size + */ + public void getTableLocationSize( + @NotNull final TableKeyImpl tableKey, + @NotNull final TableLocationKeyImpl tableLocationKey, + @NotNull final LongConsumer listener) { + final AsyncState asyncState = new AsyncState<>(); + + final LongConsumer onSize = asyncState::setResult; + final Consumer onFailure = errorString -> asyncState.setError( + new UncheckedDeephavenException(errorString)); + + pyTableDataService.call("_table_location_size", tableKey.key, tableLocationKey.locationKey, + onSize, onFailure); + + listener.accept(asyncState.awaitResult(err -> new TableDataException(String.format( + "%s:%s: table_location_size failed", tableKey, tableLocationKey), err))); + } + + /** + * Subscribe to the existing size and future size changes of a table location. + * + * @param tableKey the table key + * @param tableLocationKey the table location key + * @param sizeListener the sizeListener to call with the partition size + * @param successCallback the success callback; called when the subscription is successfully established and the + * sizeListener has been called with the initial size + * @param failureCallback the failure callback; called to deliver an exception triggered while activating or + * maintaining the underlying data source + * @return a {@link SafeCloseable} that can be used to cancel the subscription + */ + public SafeCloseable subscribeToTableLocationSize( + @NotNull final TableKeyImpl tableKey, + @NotNull final TableLocationKeyImpl tableLocationKey, + @NotNull final LongConsumer sizeListener, + @NotNull final Runnable successCallback, + @NotNull final Consumer failureCallback) { + + final PyObject cancellationCallback = pyTableDataService.call( + "_subscribe_to_table_location_size", tableKey.key, tableLocationKey.locationKey, + sizeListener, successCallback, failureCallback); + + return () -> cancellationCallback.call("__call__"); + } + + /** + * Get a range of data for a column. + * + * @param tableKey the table key + * @param tableLocationKey the table location key + * @param columnDefinition the column definition + * @param firstRowPosition the first row position + * @param minimumSize the minimum size + * @param maximumSize the maximum size + * @return the column values + */ + public List> getColumnValues( + @NotNull final TableKeyImpl tableKey, + @NotNull final TableLocationKeyImpl tableLocationKey, + @NotNull final ColumnDefinition columnDefinition, + final long firstRowPosition, + final int minimumSize, + final int maximumSize) { + + final AsyncState>> asyncState = new AsyncState<>(); + + final String columnName = columnDefinition.getName(); + final Consumer onMessages = messages -> { + if (messages.length < 2) { + asyncState.setError(new IllegalArgumentException(String.format( + "expected at least 2 IPC messages describing the wire format of the column followed by " + + "column values, but received %d messages", + messages.length))); + return; + } + final Schema schema = ArrowToTableConverter.parseArrowSchema( + ArrowToTableConverter.parseArrowIpcMessage(messages[0])); + final BarrageUtil.ConvertedArrowSchema schemaPlus = BarrageUtil.convertArrowSchema(schema); + + if (schema.fieldsLength() > 1) { + asyncState.setError(new IllegalArgumentException(String.format( + "Received more than one field. 
Received %d fields for columns %s.", + schema.fieldsLength(), + IntStream.range(0, schema.fieldsLength()) + .mapToObj(ci -> schema.fields(ci).name()) + .reduce((a, b) -> a + ", " + b).orElse("")))); + return; + } + if (!columnDefinition.isCompatible(schemaPlus.tableDef.getColumns().get(0))) { + asyncState.setError(new IllegalArgumentException(String.format( + "Received incompatible column definition. Expected %s, but received %s.", + columnDefinition, schemaPlus.tableDef.getColumns().get(0)))); + return; + } + + final ArrayList> resultChunks = new ArrayList<>(messages.length - 1); + final ChunkReader reader = schemaPlus.computeChunkReaders( + chunkReaderFactory, schema, streamReaderOptions)[0]; + int mi = 1; + try { + for (; mi < messages.length; ++mi) { + final BarrageProtoUtil.MessageInfo recordBatchMessageInfo = parseArrowIpcMessage(messages[mi]); + if (recordBatchMessageInfo.header.headerType() != MessageHeader.RecordBatch) { + throw new IllegalArgumentException(String.format( + "IPC message %d is not a valid Arrow RecordBatch IPC message", mi)); + } + final RecordBatch batch = (RecordBatch) recordBatchMessageInfo.header.header(new RecordBatch()); + + final Iterator fieldNodeIter = + new FlatBufferIteratorAdapter<>(batch.nodesLength(), + i -> new ChunkInputStreamGenerator.FieldNodeInfo(batch.nodes(i))); + + final PrimitiveIterator.OfLong bufferInfoIter = ArrowToTableConverter.extractBufferInfo(batch); + + resultChunks.add(reader.readChunk( + fieldNodeIter, bufferInfoIter, recordBatchMessageInfo.inputStream, null, 0, 0)); + } + + asyncState.setResult(resultChunks); + } catch (final IOException ioe) { + SafeCloseable.closeAll(resultChunks.iterator()); + asyncState.setError(new UncheckedDeephavenException(String.format( + "failed to read IPC message %d", mi), ioe)); + } catch (final RuntimeException e) { + SafeCloseable.closeAll(resultChunks.iterator()); + asyncState.setError(e); + } + }; + + final Consumer onFailure = errorString -> asyncState.setError( + new UncheckedDeephavenException(errorString)); + + pyTableDataService.call("_column_values", + tableKey.key, tableLocationKey.locationKey, columnName, firstRowPosition, + minimumSize, maximumSize, onMessages, onFailure); + + return asyncState.awaitResult(err -> new TableDataException(String.format( + "%s:%s: column_values(%s, %d, %d, %d) failed", + tableKey, tableLocationKey, columnName, firstRowPosition, minimumSize, maximumSize), err)); + } + } + + @Override + protected @NotNull TableLocationProvider makeTableLocationProvider(@NotNull final TableKey tableKey) { + if (!(tableKey instanceof TableKeyImpl)) { + throw new IllegalArgumentException(String.format("%s: unsupported TableKey %s", this, tableKey)); + } + return new TableLocationProviderImpl((TableKeyImpl) tableKey); + } + + /** + * {@link TableKey} implementation for TableService. 
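+     * <p>
+     * Illustrative use, where {@code pyTableDataService} and {@code pyKey} are hypothetical {@code PyObject}s
+     * supplied from Python:
+     * <pre>{@code
+     * final PythonTableDataService service = PythonTableDataService.create(pyTableDataService, null, null, 0);
+     * final Table table = service.makeTable(new TableKeyImpl(pyKey), true);
+     * }</pre>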
+ */ + public static class TableKeyImpl implements ImmutableTableKey { + + private final PyObject key; + private int cachedHashCode; + + public TableKeyImpl(@NotNull final PyObject key) { + this.key = key; + } + + @Override + public boolean equals(final Object other) { + if (this == other) { + return true; + } + if (!(other instanceof TableKeyImpl)) { + return false; + } + final TableKeyImpl otherTableKey = (TableKeyImpl) other; + return this.key.equals(otherTableKey.key); + } + + @Override + public int hashCode() { + if (cachedHashCode == 0) { + final int computedHashCode = Long.hashCode(key.call("__hash__").getLongValue()); + // Don't use 0; that's used by StandaloneTableKey, and also our sentinel for the need to compute + if (computedHashCode == 0) { + final int fallbackHashCode = TableKeyImpl.class.hashCode(); + cachedHashCode = fallbackHashCode == 0 ? 1 : fallbackHashCode; + } else { + cachedHashCode = computedHashCode; + } + } + return cachedHashCode; + } + + @Override + public LogOutput append(@NotNull final LogOutput logOutput) { + return logOutput.append(getImplementationName()) + .append("[key=").append(key.toString()).append(']'); + } + + @Override + public String toString() { + return new LogOutputStringImpl().append(this).toString(); + } + + @Override + public String getImplementationName() { + return "PythonTableDataService.TableKeyImpl"; + } + } + + /** + * {@link TableLocationProvider} implementation for TableService. + */ + private class TableLocationProviderImpl extends AbstractTableLocationProvider { + + private final TableDefinition tableDefinition; + + private Subscription subscription = null; + + private TableLocationProviderImpl(@NotNull final TableKeyImpl tableKey) { + super(tableKey, true, TableUpdateMode.APPEND_ONLY, TableUpdateMode.APPEND_ONLY); + final BarrageUtil.ConvertedArrowSchema[] schemas = backend.getTableSchema(tableKey); + + final TableDefinition partitioningDef = schemas[0].tableDef; + final TableDefinition tableDataDef = schemas[1].tableDef; + final Map> columns = new LinkedHashMap<>( + partitioningDef.numColumns() + tableDataDef.numColumns()); + + // all partitioning columns default to the front + for (final ColumnDefinition column : partitioningDef.getColumns()) { + columns.put(column.getName(), column.withPartitioning()); + } + + for (final ColumnDefinition column : tableDataDef.getColumns()) { + final ColumnDefinition existingDef = columns.get(column.getName()); + + if (existingDef == null) { + columns.put(column.getName(), column); + } else if (!existingDef.isCompatible(column)) { + // validate that both definitions are the same + throw new IllegalArgumentException(String.format("%s: column %s has conflicting definitions in " + + "partitioning and table data schemas: %s vs %s", tableKey, column.getName(), + existingDef, column)); + } + } + + tableDefinition = TableDefinition.of(columns.values()); + } + + @Override + protected @NotNull TableLocation makeTableLocation(@NotNull final TableLocationKey locationKey) { + if (!(locationKey instanceof TableLocationKeyImpl)) { + throw new IllegalArgumentException(String.format("%s: Unsupported TableLocationKey %s", this, + locationKey)); + } + return new TableLocationImpl((TableKeyImpl) getKey(), (TableLocationKeyImpl) locationKey); + } + + @Override + public void refresh() { + TableKeyImpl key = (TableKeyImpl) getKey(); + backend.getTableLocations(tableDefinition, key, this::handleTableLocationKeyAdded); + } + + @Override + protected void activateUnderlyingDataSource() { + TableKeyImpl key = 
(TableKeyImpl) getKey(); + final Subscription localSubscription = subscription = new Subscription(); + localSubscription.cancellationCallback = backend.subscribeToTableLocations( + tableDefinition, key, this::handleTableLocationKeyAdded, + () -> activationSuccessful(localSubscription), + error -> activationFailed(localSubscription, new TableDataException( + String.format("%s: subscribe_to_table_locations failed", key), error))); + } + + @Override + protected void deactivateUnderlyingDataSource() { + final Subscription localSubscription = subscription; + subscription = null; + if (localSubscription != null) { + localSubscription.cancellationCallback.close(); + } + } + + @Override + protected boolean matchSubscriptionToken(final T token) { + return token == subscription; + } + + @Override + public String getImplementationName() { + return "PythonTableDataService.TableLocationProvider"; + } + } + + /** + * {@link TableLocationKey} implementation for TableService. + */ + public static class TableLocationKeyImpl extends PartitionedTableLocationKey { + + private final PyObject locationKey; + private int cachedHashCode; + + /** + * Construct a TableLocationKeyImpl. Used by the Python adapter. + * + * @param locationKey the location key + */ + @ScriptApi + public TableLocationKeyImpl(@NotNull final PyObject locationKey) { + this(locationKey, Map.of()); + } + + private TableLocationKeyImpl( + @NotNull final PyObject locationKey, + @NotNull final Map> partitionValues) { + super(partitionValues); + this.locationKey = locationKey; + } + + @Override + public boolean equals(final Object other) { + if (this == other) { + return true; + } + if (!(other instanceof TableLocationKeyImpl)) { + return false; + } + final TableLocationKeyImpl otherTyped = (TableLocationKeyImpl) other; + return partitions.equals((otherTyped).partitions) && locationKey.equals(otherTyped.locationKey); + } + + @Override + public int hashCode() { + if (cachedHashCode == 0) { + final int computedHashCode = + 31 * partitions.hashCode() + Long.hashCode(locationKey.call("__hash__").getLongValue()); + // Don't use 0; that's used by StandaloneTableLocationKey, and also our sentinel for the need to compute + if (computedHashCode == 0) { + final int fallbackHashCode = TableLocationKeyImpl.class.hashCode(); + cachedHashCode = fallbackHashCode == 0 ? 1 : fallbackHashCode; + } else { + cachedHashCode = computedHashCode; + } + } + return cachedHashCode; + } + + @Override + public int compareTo(@NotNull final TableLocationKey other) { + if (getClass() != other.getClass()) { + throw new ClassCastException(String.format("Cannot compare %s to %s", getClass(), other.getClass())); + } + final TableLocationKeyImpl otherTableLocationKey = (TableLocationKeyImpl) other; + return PartitionsComparator.INSTANCE.compare(partitions, otherTableLocationKey.partitions); + } + + @Override + public LogOutput append(@NotNull final LogOutput logOutput) { + return logOutput.append(getImplementationName()) + .append(":[key=").append(locationKey.toString()) + .append(", partitions=").append(PartitionsFormatter.INSTANCE, partitions) + .append(']'); + } + + @Override + public String toString() { + return new LogOutputStringImpl().append(this).toString(); + } + + @Override + public String getImplementationName() { + return "PythonTableDataService.TableLocationKeyImpl"; + } + } + + /** + * {@link TableLocation} implementation for TableService. 
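+     * <p>
+     * Reported sizes are treated as append-only: {@code onSizeChanged} ignores any size that is not greater than
+     * the currently known size.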
+ */ + public class TableLocationImpl extends AbstractTableLocation { + + volatile Subscription subscription = null; + + private long size; + + private TableLocationImpl( + @NotNull final TableKeyImpl tableKey, + @NotNull final TableLocationKeyImpl locationKey) { + super(tableKey, locationKey, true); + } + + private synchronized void onSizeChanged(final long newSize) { + if (size >= newSize) { + return; + } + size = newSize; + handleUpdate(RowSetFactory.flat(size), System.currentTimeMillis()); + } + + @Override + protected @NotNull ColumnLocation makeColumnLocation(@NotNull final String name) { + return new ColumnLocationImpl(this, name); + } + + @Override + public void refresh() { + final TableKeyImpl key = (TableKeyImpl) getTableKey(); + final TableLocationKeyImpl location = (TableLocationKeyImpl) getKey(); + backend.getTableLocationSize(key, location, this::onSizeChanged); + } + + @Override + public @NotNull List getSortedColumns() { + return List.of(); + } + + @Override + public @NotNull List getDataIndexColumns() { + return List.of(); + } + + @Override + public boolean hasDataIndex(@NotNull final String... columns) { + return false; + } + + @Override + public @Nullable BasicDataIndex loadDataIndex(@NotNull final String... columns) { + return null; + } + + @Override + protected void activateUnderlyingDataSource() { + final TableKeyImpl key = (TableKeyImpl) getTableKey(); + final TableLocationKeyImpl location = (TableLocationKeyImpl) getKey(); + + final Subscription localSubscription = subscription = new Subscription(); + final LongConsumer subscriptionFilter = newSize -> { + if (localSubscription != subscription) { + // we've been cancelled and/or replaced + return; + } + + onSizeChanged(newSize); + }; + localSubscription.cancellationCallback = backend.subscribeToTableLocationSize( + key, location, subscriptionFilter, () -> activationSuccessful(localSubscription), + errorString -> activationFailed(localSubscription, new TableDataException(errorString))); + } + + @Override + protected void deactivateUnderlyingDataSource() { + final Subscription localSubscription = subscription; + subscription = null; + if (localSubscription != null) { + localSubscription.cancellationCallback.close(); + } + } + + @Override + protected boolean matchSubscriptionToken(final T token) { + return token == subscription; + } + + @Override + public String getImplementationName() { + return "PythonTableDataService.TableLocationImpl"; + } + } + + /** + * {@link ColumnLocation} implementation for TableService. + */ + public class ColumnLocationImpl extends AbstractColumnLocation { + + protected ColumnLocationImpl( + @NotNull final PythonTableDataService.TableLocationImpl tableLocation, + @NotNull final String name) { + super(tableLocation, name); + } + + @Override + public boolean exists() { + // Schema is consistent across all column locations with the same segment ID. This implementation should be + // changed when/if we add support for rich schema evolution. 
+ return true; + } + + @Override + public ColumnRegionChar makeColumnRegionChar( + @NotNull final ColumnDefinition columnDefinition) { + return new AppendOnlyFixedSizePageRegionChar<>(REGION_MASK, pageSize, + new TableServiceGetRangeAdapter(columnDefinition)); + } + + @Override + public ColumnRegionByte makeColumnRegionByte( + @NotNull final ColumnDefinition columnDefinition) { + return new AppendOnlyFixedSizePageRegionByte<>(REGION_MASK, pageSize, + new TableServiceGetRangeAdapter(columnDefinition)); + } + + @Override + public ColumnRegionShort makeColumnRegionShort( + @NotNull final ColumnDefinition columnDefinition) { + return new AppendOnlyFixedSizePageRegionShort<>(REGION_MASK, pageSize, + new TableServiceGetRangeAdapter(columnDefinition)); + } + + @Override + public ColumnRegionInt makeColumnRegionInt( + @NotNull final ColumnDefinition columnDefinition) { + return new AppendOnlyFixedSizePageRegionInt<>(REGION_MASK, pageSize, + new TableServiceGetRangeAdapter(columnDefinition)); + + } + + @Override + public ColumnRegionLong makeColumnRegionLong( + @NotNull final ColumnDefinition columnDefinition) { + return new AppendOnlyFixedSizePageRegionLong<>(REGION_MASK, pageSize, + new TableServiceGetRangeAdapter(columnDefinition)); + + } + + @Override + public ColumnRegionFloat makeColumnRegionFloat( + @NotNull final ColumnDefinition columnDefinition) { + return new AppendOnlyFixedSizePageRegionFloat<>(REGION_MASK, pageSize, + new TableServiceGetRangeAdapter(columnDefinition)); + } + + @Override + public ColumnRegionDouble makeColumnRegionDouble( + @NotNull final ColumnDefinition columnDefinition) { + return new AppendOnlyFixedSizePageRegionDouble<>(REGION_MASK, pageSize, + new TableServiceGetRangeAdapter(columnDefinition)); + } + + @Override + public ColumnRegionObject makeColumnRegionObject( + @NotNull final ColumnDefinition columnDefinition) { + return new AppendOnlyFixedSizePageRegionObject<>(REGION_MASK, pageSize, + new TableServiceGetRangeAdapter(columnDefinition)); + } + + private class TableServiceGetRangeAdapter implements AppendOnlyRegionAccessor { + private final @NotNull ColumnDefinition columnDefinition; + + public TableServiceGetRangeAdapter(@NotNull ColumnDefinition columnDefinition) { + this.columnDefinition = columnDefinition; + } + + @Override + public void readChunkPage( + final long firstRowPosition, + final int minimumSize, + @NotNull final WritableChunk destination) { + final TableLocationImpl location = (TableLocationImpl) getTableLocation(); + final TableKeyImpl key = (TableKeyImpl) location.getTableKey(); + + final List> values = backend.getColumnValues( + key, (TableLocationKeyImpl) location.getKey(), columnDefinition, + firstRowPosition, minimumSize, destination.capacity()); + + final int numRows = values.stream().mapToInt(WritableChunk::size).sum(); + + if (numRows < minimumSize) { + throw new TableDataException(String.format("%s:%s: column_values(%s, %d, %d, %d) did not return " + + "enough data. Read %d rows but expected row range was %d to %d.", + key, location, columnDefinition.getName(), firstRowPosition, minimumSize, + destination.capacity(), numRows, minimumSize, destination.capacity())); + } + if (numRows > destination.capacity()) { + throw new TableDataException(String.format("%s:%s: column_values(%s, %d, %d, %d) returned too much " + + "data. 
Read %d rows but maximum allowed is %d.", key, location, + columnDefinition.getName(), firstRowPosition, minimumSize, destination.capacity(), numRows, + destination.capacity())); + } + + int offset = 0; + for (final Chunk rbChunk : values) { + final int length = Math.min(destination.capacity() - offset, rbChunk.size()); + destination.copyFromChunk(rbChunk, 0, offset, length); + offset += length; + } + destination.setSize(offset); + } + + @Override + public long size() { + return getTableLocation().getSize(); + } + } + } + + private static class Subscription { + SafeCloseable cancellationCallback; + } + + /** + * Helper used to simplify backend asynchronous RPC patterns for otherwise synchronous operations. + */ + private static class AsyncState { + private T result; + private Exception error; + + public synchronized void setResult(final T result) { + if (this.result != null) { + throw new IllegalStateException("Callback can only be called once"); + } + if (result == null) { + throw new IllegalArgumentException("Callback invoked with null result"); + } + this.result = result; + notifyAll(); + } + + public synchronized void setError(@NotNull final Exception error) { + if (this.error == null) { + this.error = error; + } + notifyAll(); + } + + public synchronized T awaitResult(@NotNull final Function errorMapper) { + while (result == null && error == null) { + try { + wait(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw errorMapper.apply(e); + } + } + if (error != null) { + throw errorMapper.apply(error); + } + return result; + } + } +} diff --git a/py/server/deephaven/experimental/table_data_service.py b/py/server/deephaven/experimental/table_data_service.py new file mode 100644 index 00000000000..5613d0903b2 --- /dev/null +++ b/py/server/deephaven/experimental/table_data_service.py @@ -0,0 +1,455 @@ +# +# Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending +# +"""This module defines a table service backend interface TableDataServiceBackend that users can implement to provide +external data in the format of pyarrow Table to Deephaven tables. The backend service implementation should be passed +to the TableDataService constructor to create a new TableDataService instance. The TableDataService instance can then +be used to create Deephaven tables backed by the backend service.""" +import traceback +from abc import ABC, abstractmethod +from typing import Optional, Callable + +import jpy + +import pyarrow as pa + +from deephaven.dherror import DHError +from deephaven._wrapper import JObjectWrapper +from deephaven.table import Table + +_JPythonTableDataService = jpy.get_type("io.deephaven.extensions.barrage.util.PythonTableDataService") +_JTableKeyImpl = jpy.get_type("io.deephaven.extensions.barrage.util.PythonTableDataService$TableKeyImpl") +_JTableLocationKeyImpl = jpy.get_type("io.deephaven.extensions.barrage.util.PythonTableDataService$TableLocationKeyImpl") + + +class TableKey(ABC): + """A key that identifies a table. The key should be unique for each table. The key can be any Python object and + should include sufficient information to uniquely identify the table for the backend service. The __hash__ method + must be implemented to ensure that the key is hashable. + """ + + @abstractmethod + def __hash__(self): + pass + + +class TableLocationKey(ABC): + """A key that identifies a specific location of a table. The key should be unique for each table location of the + table. 
The key can be any Python object and should include sufficient information to uniquely identify the location + for the backend service to fetch the data values and data size. The __hash__ method must be implemented to ensure + that the key is hashable. + """ + + @abstractmethod + def __hash__(self): + pass + + +class TableDataServiceBackend(ABC): + """An interface for a backend service that provides access to table data.""" + + @abstractmethod + def table_schema(self, table_key: TableKey, + schema_cb: Callable[[pa.Schema, Optional[pa.Schema]], None], + failure_cb: Callable[[Exception], None]) -> None: + """ Provides the table data schema and the partitioning column schema for the table with the given table key via + the schema_cb callback. The table data schema is not required to include the partitioning columns defined in + the partitioning column schema. + + The failure callback should be invoked when a failure to provide the schemas occurs. + + The table_schema caller will block until one of the schema or failure callbacks is called. + + Note that asynchronous calls to any callback may block until this method has returned. + + Args: + table_key (TableKey): the table key + schema_cb (Callable[[pa.Schema, Optional[pa.Schema]], None]): the callback function with two arguments: the + table data schema and the optional partitioning column schema + failure_cb (Callable[[Exception], None]): the failure callback function + """ + pass + + @abstractmethod + def table_locations(self, table_key: TableKey, + location_cb: Callable[[TableLocationKey, Optional[pa.Table]], None], + success_cb: Callable[[], None], + failure_cb: Callable[[Exception], None]) -> None: + """ Provides the existing table locations for the table with the given table via the location_cb callback. + + The location callback should be called with the table location key and an optional pyarrow.Table that contains + the partitioning values for the location. The schema of the table must match the optional partitioning column + schema returned by :meth:`table_schema` for the table_key. The table must have a single row for the particular + table location key provided in the 1st argument, with values for each partitioning column in the row. + + The success callback should be called when all existing table locations have been delivered to the table + location callback. + + The failure callback should be invoked when failure to provide existing table locations occurs. + + The table_locations caller will block until one of the success or failure callbacks is called. + + This is called for tables created when :meth:`TableDataService.make_table` is called with refreshing=False + + Note that asynchronous calls to any callback may block until this method has returned. + + Args: + table_key (TableKey): the table key + location_cb (Callable[[TableLocationKey, Optional[pa.Table]], None]): the callback function + success_cb (Callable[[], None]): the success callback function + failure_cb (Callable[[Exception], None]): the failure callback function + """ + pass + + @abstractmethod + def subscribe_to_table_locations(self, table_key: TableKey, + location_cb: Callable[[TableLocationKey, Optional[pa.Table]], None], + success_cb: Callable[[], None], + failure_cb: Callable[[Exception], None]) -> Callable[[], None]: + """ Provides the table locations, existing and new, for the table with the given table key via the location_cb + callback. 
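A minimal backend sketch may help ground the key and schema contracts above. Everything below is illustrative and assumed, not part of the shipped module: MyTableKey, MyLocationKey, InMemoryBackend, and its in-memory table dict are made-up names; only the abstract methods defined in this file are the real contract.

from deephaven.experimental.table_data_service import (
    TableDataServiceBackend, TableKey, TableLocationKey)


class MyTableKey(TableKey):
    """Illustrative key: a table is identified by a plain name."""
    def __init__(self, name: str):
        self.name = name

    def __hash__(self):
        # must be stable and unique per table for the backend
        return hash(self.name)


class MyLocationKey(TableLocationKey):
    """Illustrative key: a location is identified by a partition string."""
    def __init__(self, partition: str):
        self.partition = partition

    def __hash__(self):
        return hash(self.partition)


class InMemoryBackend(TableDataServiceBackend):
    """Illustrative backend serving pyarrow Tables held in a dict keyed by table name."""

    def __init__(self, tables: dict):
        self._tables = tables  # assumed mapping: table name -> pyarrow.Table

    def table_schema(self, table_key, schema_cb, failure_cb) -> None:
        pa_table = self._tables.get(table_key.name)
        if pa_table is None:
            failure_cb(Exception(f"unknown table: {table_key.name}"))
            return
        # no partitioning columns in this sketch, so no partitioning schema is reported
        schema_cb(pa_table.schema, None)

    # the remaining abstract methods are sketched after the docstrings that describe them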
+ + The location callback should be called with the table location key and an optional pyarrow.Table that contains + the partitioning values for the location. The schema of the table must match the optional partitioning column + schema returned by :meth:`table_schema` for the table_key. The table must have a single row for the particular + table location key provided in the 1st argument, with values for each partitioning column in the row. + + The success callback should be called when the subscription is established successfully and after all existing + table locations have been delivered to the table location callback. + + The failure callback should be invoked at initial failure to establish a subscription, or on a permanent failure + to keep the subscription active (e.g. failure with no reconnection possible, or failure to reconnect/resubscribe + before a timeout). + + This is called for tables created when :meth:`TableDataService.make_table` is called with refreshing=True. + + Note that asynchronous calls to any callback will block until this method has returned. + + Args: + table_key (TableKey): the table key + location_cb (Callable[[TableLocationKey, Optional[pa.Table]], None]): the table location callback function + success_cb (Callable[[], None]): the success callback function + failure_cb (Callable[[Exception], None]): the failure callback function + + + Returns: + Callable[[], None]: a function that can be called to unsubscribe from this subscription + """ + pass + + @abstractmethod + def table_location_size(self, table_key: TableKey, table_location_key: TableLocationKey, + size_cb: Callable[[int], None], + failure_cb: Callable[[Exception], None]) -> None: + """ Provides the size of the table location with the given table key and table location key via the size_cb + callback. The size is the number of rows in the table location. + + The failure callback should be invoked when a failure to provide the table location size occurs. + + The table_location_size caller will block until one of the size or failure callbacks is called. + + This is called for tables created when :meth:`TableDataService.make_table` is called with refreshing=False. + + Note that asynchronous calls to any callback may block until this method has returned. + + Args: + table_key (TableKey): the table key + table_location_key (TableLocationKey): the table location key + size_cb (Callable[[int], None]): the callback function + """ + pass + + @abstractmethod + def subscribe_to_table_location_size(self, table_key: TableKey, table_location_key: TableLocationKey, + size_cb: Callable[[int], None], + success_cb: Callable[[], None], + failure_cb: Callable[[Exception], None]) -> Callable[[], None]: + """ Provides the current and future sizes of the table location with the given table key and table location + key via the size_cb callback. The size is the number of rows in the table location. + + The success callback should be called when the subscription is established successfully and after the current + table location size has been delivered to the size callback. + + The failure callback should be invoked at initial failure to establish a subscription, or on a permanent failure + to keep the subscription active (e.g. failure with no reconnection possible, or failure to reconnect/resubscribe + before a timeout). + + This is called for tables created when :meth:``TableDataService.make_table` is called with refreshing=True + + Note that asynchronous calls to any callback will block until this method has returned. 
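One way to satisfy the location-subscription contract described above is sketched here, continuing the illustrative InMemoryBackend: deliver the existing locations, call success_cb, then push new locations from a background thread until the returned cancellation callable is invoked. The polling loop and the _existing_locations/_drain_new_locations helpers are assumptions for illustration, not part of the API.

import threading
import time

from deephaven.experimental.table_data_service import TableDataServiceBackend


class InMemoryBackend(TableDataServiceBackend):
    ...

    def subscribe_to_table_locations(self, table_key, location_cb, success_cb, failure_cb):
        cancelled = threading.Event()

        # deliver the locations that already exist, then signal that the subscription is live
        for loc_key, partition_values in self._existing_locations(table_key):  # assumed helper
            location_cb(loc_key, partition_values)  # partition_values may be None without partitioning columns
        success_cb()

        def poll_new_locations():
            while not cancelled.is_set():
                try:
                    for loc_key, partition_values in self._drain_new_locations(table_key):  # assumed helper
                        location_cb(loc_key, partition_values)
                except Exception as e:
                    failure_cb(e)  # a permanent failure should be reported exactly once
                    return
                time.sleep(0.1)

        threading.Thread(target=poll_new_locations, daemon=True).start()

        # invoked by the engine when the last interested table is released
        return cancelled.set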
+ + Args: + table_key (TableKey): the table key + table_location_key (TableLocationKey): the table location key + size_cb (Callable[[int], None]): the table location size callback function + success_cb (Callable[[], None]): the success callback function + failure_cb (Callable[[Exception], None]): the failure callback function + + Returns: + Callable[[], None]: a function that can be called to unsubscribe from this subscription + """ + pass + + @abstractmethod + def column_values(self, table_key: TableKey, table_location_key: TableLocationKey, col: str, offset: int, + min_rows: int, max_rows: int, + values_cb: Callable[[pa.Table], None], + failure_cb: Callable[[Exception], None]) -> None: + """ Provides the data values for the column with the given name for the table location with the given table key + and table location key via the values_cb callback. The column values are provided as a pyarrow.Table that + contains the data values for the column within the specified range requirement. The values_cb callback should be + called with a single column pyarrow.Table that contains the data values for the given column within the + specified range requirement. + + The failure callback should be invoked when a failure to provide the column values occurs. + + The column_values caller will block until one of the values or failure callbacks is called. + + Note that asynchronous calls to any callback may block until this method has returned. + + Args: + table_key (TableKey): the table key + table_location_key (TableLocationKey): the table location key + col (str): the column name + offset (int): the starting row index + min_rows (int): the minimum number of rows to return, min_rows is always <= page size + max_rows (int): the maximum number of rows to return + values_cb (Callable[[pa.Table], None]): the callback function with one argument: the pyarrow.Table that + contains the data values for the column within the specified range + failure_cb (Callable[[Exception], None]): the failure callback function + """ + pass + + +class TableDataService(JObjectWrapper): + """ A TableDataService serves as a wrapper around a tightly-coupled Deephaven TableDataService implementation + (Java class PythonTableDataService) that delegates to a Python TableDataServiceBackend for TableKey creation, + TableLocationKey discovery, and data subscription/retrieval operations. It supports the creation of Deephaven tables + from the Python backend service that provides table data and table data locations to the Deephaven tables. + """ + j_object_type = _JPythonTableDataService + _backend: TableDataServiceBackend + + def __init__(self, backend: TableDataServiceBackend, *, chunk_reader_factory: Optional[jpy.JType] = None, + stream_reader_options: Optional[jpy.JType] = None, page_size: Optional[int] = None): + """ Creates a new TableDataService with the given user-implemented backend service. + + Args: + backend (TableDataServiceBackend): the user-implemented backend service implementation + chunk_reader_factory (Optional[jpy.JType]): the Barrage chunk reader factory, default is None + stream_reader_options (Optional[jpy.JType]): the Barrage stream reader options, default is None + page_size (int): the page size for the table service, default is None, meaning to use the configurable + jvm property: PythonTableDataService.defaultPageSize which defaults to 64K. 
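A hedged sketch of the two remaining data-access methods follows, again continuing the illustrative InMemoryBackend. It leans on pyarrow's select/slice to honor the offset and max_rows bounds; the _location_table helper is an assumption for illustration.

from deephaven.experimental.table_data_service import TableDataServiceBackend


class InMemoryBackend(TableDataServiceBackend):
    ...

    def table_location_size(self, table_key, table_location_key, size_cb, failure_cb) -> None:
        try:
            size_cb(self._location_table(table_key, table_location_key).num_rows)  # assumed helper
        except Exception as e:
            failure_cb(e)

    def column_values(self, table_key, table_location_key, col, offset, min_rows, max_rows,
                      values_cb, failure_cb) -> None:
        try:
            pa_table = self._location_table(table_key, table_location_key)  # assumed helper
            # a single-column table; slice never yields more than max_rows rows starting at offset
            values_cb(pa_table.select([col]).slice(offset, max_rows))
        except Exception as e:
            failure_cb(e)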
+ """ + self._backend = backend + + if page_size is None: + page_size = 0 + elif page_size < 0: + raise ValueError("The page size must be non-negative") + + self._j_tbl_service = _JPythonTableDataService.create( + self, chunk_reader_factory, stream_reader_options, page_size) + + @property + def j_object(self): + return self._j_tbl_service + + def make_table(self, table_key: TableKey, *, refreshing: bool) -> Table: + """ Creates a Table backed by the backend service with the given table key. + + Args: + table_key (TableKey): the table key + refreshing (bool): whether the table is live or static + + Returns: + Table: a new table + + Raises: + DHError + """ + j_table_key = _JTableKeyImpl(table_key) + try: + return Table(self._j_tbl_service.makeTable(j_table_key, refreshing)) + except Exception as e: + raise DHError(e, message=f"failed to make a table for the key {table_key}") from e + + def _table_schema(self, table_key: TableKey, schema_cb: jpy.JType, failure_cb: jpy.JType) -> None: + """ Provides the table data schema and the partitioning values schema for the table with the given table key as + two serialized byte buffers to the PythonTableDataService (Java) via callbacks. Only called by the + PythonTableDataService. + + Args: + table_key (TableKey): the table key + schema_cb (jpy.JType): the Java callback function with one argument: an array of byte buffers that contain + the serialized table data arrow and partitioning values schemas + failure_cb (jpy.JType): the failure Java callback function with one argument: an exception stringyy + """ + def schema_cb_proxy(dt_schema: pa.Schema, pc_schema: Optional[pa.Schema] = None): + j_dt_schema_bb = jpy.byte_buffer(dt_schema.serialize()) + pc_schema = pc_schema if pc_schema is not None else pa.schema([]) + j_pc_schema_bb = jpy.byte_buffer(pc_schema.serialize()) + schema_cb.accept(jpy.array("java.nio.ByteBuffer", [j_pc_schema_bb, j_dt_schema_bb])) + + def failure_cb_proxy(error: Exception): + message = error.getMessage() if hasattr(error, "getMessage") else str(error) + tb_str = traceback.format_exc() + failure_cb.accept("\n".join([message, tb_str])) + + self._backend.table_schema(table_key, schema_cb_proxy, failure_cb_proxy) + + def _table_locations(self, table_key: TableKey, location_cb: jpy.JType, success_cb: jpy.JType, + failure_cb: jpy.JType) -> None: + """ Provides the existing table locations for the table with the given table key to the PythonTableDataService + (Java) via callbacks. Only called by the PythonTableDataService. 
+ + Args: + table_key (TableKey): the table key + location_cb (jpy.JType): the Java callback function with two arguments: a table location key and an array of + byte buffers that contain the serialized arrow schema and a record batch of the partitioning values + success_cb (jpy.JType): the success Java callback function with no arguments + failure_cb (jpy.JType): the failure Java callback function with one argument: an exception string + """ + def location_cb_proxy(pt_location_key: TableLocationKey, pt_table: pa.Table): + j_tbl_location_key = _JTableLocationKeyImpl(pt_location_key) + if pt_table is None or pt_table.to_batches() is None: + location_cb.apply(j_tbl_location_key, jpy.array("java.nio.ByteBuffer", [])) + else: + if pt_table.num_rows != 1: + raise ValueError("The number of rows in the pyarrow table for partitioning values must be 1") + bb_list = [jpy.byte_buffer(rb.serialize()) for rb in pt_table.to_batches()] + bb_list.insert(0, jpy.byte_buffer(pt_table.schema.serialize())) + location_cb.accept(j_tbl_location_key, jpy.array("java.nio.ByteBuffer", bb_list)) + + def success_cb_proxy(): + success_cb.run() + + def failure_cb_proxy(error: Exception): + message = error.getMessage() if hasattr(error, "getMessage") else str(error) + tb_str = traceback.format_exc() + failure_cb.accept("\n".join([message, tb_str])) + + self._backend.table_locations(table_key, location_cb_proxy, success_cb_proxy, failure_cb_proxy) + + def _subscribe_to_table_locations(self, table_key: TableKey, location_cb: jpy.JType, success_cb: jpy.JType, + failure_cb: jpy.JType) -> Callable[[], None]: + """ Provides the table locations, existing and new, for the table with the given table key to the + PythonTableDataService (Java) via callbacks. Only called by the PythonTableDataService. 
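Constructing the one-row partitioning-values table that location_cb expects can be done directly with pyarrow. In the sketch below the column names are assumptions chosen for illustration; in a real backend they must match the partitioning schema reported by table_schema.

import pyarrow as pa

# assumed partitioning schema: one string column per partitioning key
partitioning_schema = pa.schema([pa.field("Ticker", pa.string()), pa.field("Exchange", pa.string())])

def partition_values_for(ticker: str, exchange: str) -> pa.Table:
    # exactly one row, with a value for every partitioning column
    return pa.Table.from_pydict({"Ticker": [ticker], "Exchange": [exchange]}, schema=partitioning_schema)

# e.g. location_cb(MyLocationKey(f"{ticker}/NYSE"), partition_values_for(ticker, "NYSE"))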
+ + Args: + table_key (TableKey): the table key + location_cb (jpy.JType): the Java callback function with two arguments: a table location key of the new + location and an array of byte buffers that contain the partitioning arrow schema and the serialized + record batches of the partitioning values + success_cb (jpy.JType): the success Java callback function with no arguments + failure_cb (jpy.JType): the failure Java callback function with one argument: an exception string + + Returns: + Callable[[], None]: a function that can be called to unsubscribe from this subscription + """ + def location_cb_proxy(pt_location_key: TableLocationKey, pt_table: pa.Table): + j_tbl_location_key = _JTableLocationKeyImpl(pt_location_key) + if pt_table is None: + location_cb.apply(j_tbl_location_key, jpy.array("java.nio.ByteBuffer", [])) + else: + if pt_table.num_rows != 1: + raise ValueError("The number of rows in the pyarrow table for partitioning column values must be 1") + bb_list = [jpy.byte_buffer(rb.serialize()) for rb in pt_table.to_batches()] + bb_list.insert(0, jpy.byte_buffer(pt_table.schema.serialize())) + location_cb.accept(j_tbl_location_key, jpy.array("java.nio.ByteBuffer", bb_list)) + + def success_cb_proxy(): + success_cb.run() + + def failure_cb_proxy(error: Exception): + message = error.getMessage() if hasattr(error, "getMessage") else str(error) + tb_str = traceback.format_exc() + failure_cb.accept("\n".join([message, tb_str])) + + return self._backend.subscribe_to_table_locations(table_key, location_cb_proxy, success_cb_proxy, + failure_cb_proxy) + + def _table_location_size(self, table_key: TableKey, table_location_key: TableLocationKey, size_cb: jpy.JType, + failure_cb: jpy.JType) -> None: + """ Provides the size of the table location with the given table key and table location key to the + PythonTableDataService (Java) via callbacks. Only called by the PythonTableDataService. + + Args: + table_key (TableKey): the table key + table_location_key (TableLocationKey): the table location key + size_cb (jpy.JType): the Java callback function with one argument: the size of the table location in number + of rows + failure_cb (jpy.JType): the failure Java callback function with one argument: an exception string + """ + def size_cb_proxy(size: int): + size_cb.accept(size) + + def failure_cb_proxy(error: Exception): + message = error.getMessage() if hasattr(error, "getMessage") else str(error) + tb_str = traceback.format_exc() + failure_cb.accept("\n".join([message, tb_str])) + + self._backend.table_location_size(table_key, table_location_key, size_cb_proxy, failure_cb_proxy) + + def _subscribe_to_table_location_size(self, table_key: TableKey, table_location_key: TableLocationKey, + size_cb: jpy.JType, success_cb: jpy.JType, failure_cb: jpy.JType) -> Callable[[], None]: + """ Provides the current and future sizes of the table location with the given table key and table location key + to the PythonTableDataService (Java) via callbacks. Only called by the PythonTableDataService. 
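The size-subscription contract can be met with the same pattern as the location subscription: report the current size, signal success, then push growth until cancelled. The polling loop and the _current_size helper below are illustrative assumptions, continuing the InMemoryBackend sketch.

import threading
import time

from deephaven.experimental.table_data_service import TableDataServiceBackend


class InMemoryBackend(TableDataServiceBackend):
    ...

    def subscribe_to_table_location_size(self, table_key, table_location_key,
                                         size_cb, success_cb, failure_cb):
        cancelled = threading.Event()

        last_size = self._current_size(table_key, table_location_key)  # assumed helper
        size_cb(last_size)  # the current size must be delivered before success is signaled
        success_cb()

        def poll_size():
            nonlocal last_size
            while not cancelled.is_set():
                try:
                    current = self._current_size(table_key, table_location_key)
                except Exception as e:
                    failure_cb(e)  # a permanent failure ends the subscription
                    return
                if current > last_size:  # locations are append-only, so sizes only grow
                    size_cb(current)
                    last_size = current
                time.sleep(0.1)

        threading.Thread(target=poll_size, daemon=True).start()
        return cancelled.set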
+ + Args: + table_key (TableKey): the table key + table_location_key (TableLocationKey): the table location key + size_cb (jpy.JType): the Java callback function with one argument: the size of the location in number of + rows + success_cb (jpy.JType): the success Java callback function with no arguments + failure_cb (jpy.JType): the failure Java callback function with one argument: an exception string + + Returns: + Callable[[], None]: a function that can be called to unsubscribe from this subscription + """ + def size_cb_proxy(size: int): + size_cb.accept(size) + + def success_cb_proxy(): + success_cb.run() + + def failure_cb_proxy(error: Exception): + message = error.getMessage() if hasattr(error, "getMessage") else str(error) + tb_str = traceback.format_exc() + failure_cb.accept("\n".join([message, tb_str])) + + return self._backend.subscribe_to_table_location_size(table_key, table_location_key, size_cb_proxy, + success_cb_proxy, failure_cb_proxy) + + def _column_values(self, table_key: TableKey, table_location_key: TableLocationKey, col: str, offset: int, + min_rows: int, max_rows: int, values_cb: jpy.JType, failure_cb: jpy.JType) -> None: + """ Provides the data values for the column with the given name for the table column with the given table key + and table location key to the PythonTableDataService (Java) via callbacks. Only called by the + PythonTableDataService. + + Args: + table_key (TableKey): the table key + table_location_key (TableLocationKey): the table location key + col (str): the column name + offset (int): the starting row index + min_rows (int): the minimum number of rows to return, min_rows is always <= page size + max_rows (int): the maximum number of rows to return + values_cb (jpy.JType): the Java callback function with one argument: an array of byte buffers that contain + the arrow schema and the serialized record batches for the given column + failure_cb (jpy.JType): the failure Java callback function with one argument: an exception string + """ + def values_cb_proxy(pt_table: pa.Table): + if len(pt_table) < min_rows or len(pt_table) > max_rows: + raise ValueError("The number of rows in the pyarrow table for column values must be in the range of " + f"{min_rows} to {max_rows}") + bb_list = [jpy.byte_buffer(rb.serialize()) for rb in pt_table.to_batches()] + bb_list.insert(0, jpy.byte_buffer(pt_table.schema.serialize())) + values_cb.accept(jpy.array("java.nio.ByteBuffer", bb_list)) + + def failure_cb_proxy(error: Exception): + message = error.getMessage() if hasattr(error, "getMessage") else str(error) + tb_str = traceback.format_exc() + failure_cb.accept("\n".join([message, tb_str])) + + self._backend.column_values(table_key, table_location_key, col, offset, min_rows, max_rows, values_cb_proxy, + failure_cb_proxy) diff --git a/py/server/tests/test_table_data_service.py b/py/server/tests/test_table_data_service.py new file mode 100644 index 00000000000..6c524edf263 --- /dev/null +++ b/py/server/tests/test_table_data_service.py @@ -0,0 +1,331 @@ +# +# Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending +# + +import threading +import time +import unittest +from typing import Callable, Optional, Generator, Dict + +import numpy as np +import pyarrow as pa +import pyarrow.compute as pc + +from deephaven import new_table +from deephaven.column import byte_col, char_col, short_col, int_col, long_col, float_col, double_col, string_col, \ + datetime_col, bool_col, ColumnType +from deephaven.execution_context import get_exec_ctx, ExecutionContext +from 
deephaven.experimental.table_data_service import TableDataServiceBackend, TableKey, \ + TableLocationKey, TableDataService +import deephaven.arrow as dharrow +from deephaven.liveness_scope import liveness_scope + +from tests.testbase import BaseTestCase + + +class TableKeyImpl(TableKey): + def __init__(self, key: str): + self.key = key + + def __hash__(self): + return hash(self.key) + + +class TableLocationKeyImpl(TableLocationKey): + def __init__(self, key: str): + self.key = key + + def __hash__(self): + return hash(self.key) + + +class TestBackend(TableDataServiceBackend): + def __init__(self, gen_pa_table: Generator[pa.Table, None, None], pt_schema: pa.Schema, + pc_schema: Optional[pa.Schema] = None): + self.pt_schema: pa.Schema = pt_schema + self.pc_schema: pa.Schema = pc_schema + self.gen_pa_table: Generator = gen_pa_table + self.subscriptions_enabled_for_test: bool = True + self.sub_new_partition_cancelled: bool = False + self.sub_new_partition_fail_test: bool = False + self.sub_partition_size_fail_test: bool = False + self.partitions: Dict[TableLocationKey, pa.Table] = {} + self.partitions_size_subscriptions: Dict[TableLocationKey, bool] = {} + self.existing_partitions_called: int = 0 + self.partition_size_called: int = 0 + + def table_schema(self, table_key: TableKeyImpl, + schema_cb: Callable[[pa.Schema, Optional[pa.Schema]], None], + failure_cb: Callable[[str], None]) -> None: + if table_key.key == "test": + schema_cb(self.pt_schema, self.pc_schema) + else: + failure_cb("table key not found") + + def table_locations(self, table_key: TableKeyImpl, + location_cb: Callable[[TableLocationKeyImpl, Optional[pa.Table]], None], + success_cb: Callable[[], None], + failure_cb: Callable[[str], None]) -> None: + pa_table = next(self.gen_pa_table) + if table_key.key == "test": + ticker = str(pa_table.column("Ticker")[0]) + + partition_key = TableLocationKeyImpl(f"{ticker}/NYSE") + self.partitions[partition_key] = pa_table + + expr = ((pc.field("Ticker") == f"{ticker}") & (pc.field("Exchange") == "NYSE")) + location_cb(partition_key, pa_table.filter(expr).select(["Ticker", "Exchange"]).slice(0, 1)) + self.existing_partitions_called += 1 + + # indicate that we've finished notifying existing table locations + success_cb() + else: + failure_cb("table key not found") + + def table_location_size(self, table_key: TableKeyImpl, table_location_key: TableLocationKeyImpl, + size_cb: Callable[[int], None], + failure_cb: Callable[[str], None]) -> None: + size_cb(self.partitions[table_location_key].num_rows) + self.partition_size_called += 1 + + def column_values(self, table_key: TableKeyImpl, table_location_key: TableLocationKeyImpl, + col: str, offset: int, min_rows: int, max_rows: int, + values_cb: Callable[[pa.Table], None], + failure_cb: Callable[[str], None]) -> None: + if table_key.key == "test": + values_cb(self.partitions[table_location_key].select([col]).slice(offset, max_rows)) + else: + failure_cb("table key not found") + + def _th_new_partitions(self, table_key: TableKeyImpl, exec_ctx: ExecutionContext, + location_cb: Callable[[TableLocationKeyImpl, Optional[pa.Table]], None], + failure_cb: Callable[[Exception], None]) -> None: + if table_key.key != "test": + return + + while not self.sub_new_partition_cancelled and self.subscriptions_enabled_for_test: + try: + with exec_ctx: + pa_table = next(self.gen_pa_table) + except StopIteration: + break + + ticker = str(pa_table.column("Ticker")[0]) + partition_key = TableLocationKeyImpl(f"{ticker}/NYSE") + self.partitions[partition_key] = pa_table + + 
expr = ((pc.field("Ticker") == f"{ticker}") & (pc.field("Exchange") == "NYSE")) + location_cb(partition_key, pa_table.filter(expr).select(["Ticker", "Exchange"]).slice(0, 1)) + if self.sub_new_partition_fail_test: + failure_cb(Exception("table location subscription failure")) + return + time.sleep(0.1) + + def subscribe_to_table_locations(self, table_key: TableKeyImpl, + location_cb: Callable[[TableLocationKeyImpl, Optional[pa.Table]], None], + success_cb: Callable[[], None], failure_cb: Callable[[str], None]) -> Callable[[], None]: + if table_key.key != "test": + return lambda: None + + # simulate an existing partition + pa_table = next(self.gen_pa_table) + if table_key.key == "test": + ticker = str(pa_table.column("Ticker")[0]) + + partition_key = TableLocationKeyImpl(f"{ticker}/NYSE") + self.partitions[partition_key] = pa_table + + expr = ((pc.field("Ticker") == f"{ticker}") & (pc.field("Exchange") == "NYSE")) + location_cb(partition_key, pa_table.filter(expr).select(["Ticker", "Exchange"]).slice(0, 1)) + + exec_ctx = get_exec_ctx() + th = threading.Thread(target=self._th_new_partitions, args=(table_key, exec_ctx, location_cb, failure_cb)) + th.start() + + def _cancellation_callback(): + self.sub_new_partition_cancelled = True + + success_cb() + return _cancellation_callback + + def _th_partition_size_changes(self, table_key: TableKeyImpl, table_location_key: TableLocationKeyImpl, + size_cb: Callable[[int], None], + failure_cb: Callable[[Exception], None] + ) -> None: + if table_key.key != "test": + return + + if table_location_key not in self.partitions_size_subscriptions: + return + + while self.subscriptions_enabled_for_test and self.partitions_size_subscriptions[table_location_key]: + pa_table = self.partitions[table_location_key] + rbs = pa_table.to_batches() + rbs.append(pa_table.to_batches()[0]) + new_pa_table = pa.Table.from_batches(rbs) + self.partitions[table_location_key] = new_pa_table + size_cb(new_pa_table.num_rows) + if self.sub_partition_size_fail_test: + failure_cb(Exception("table location size subscription failure")) + return + time.sleep(0.1) + + def subscribe_to_table_location_size(self, table_key: TableKeyImpl, + table_location_key: TableLocationKeyImpl, + size_cb: Callable[[int], None], + success_cb: Callable[[], None], failure_cb: Callable[[str], None] + ) -> Callable[[], None]: + if table_key.key != "test": + return lambda: None + + if table_location_key not in self.partitions: + return lambda: None + + # need to initial size + size_cb(self.partitions[table_location_key].num_rows) + + self.partitions_size_subscriptions[table_location_key] = True + th = threading.Thread(target=self._th_partition_size_changes, args=(table_key, table_location_key, size_cb, + failure_cb)) + th.start() + + def _cancellation_callback(): + self.partitions_size_subscriptions[table_location_key] = False + + success_cb() + return _cancellation_callback + + +class TableDataServiceTestCase(BaseTestCase): + tickers = ["AAPL", "FB", "GOOG", "MSFT", "NVDA", "TMSC", "TSLA", "VZ", "WMT", "XOM"] + + def gen_pa_table(self) -> Generator[pa.Table, None, None]: + for tikcer in self.tickers: + cols = [ + string_col(name="Ticker", data=[tikcer, tikcer]), + string_col(name="Exchange", data=["NYSE", "NYSE"]), + bool_col(name="Boolean", data=[True, False]), + byte_col(name="Byte", data=(1, -1)), + char_col(name="Char", data='-1'), + short_col(name="Short", data=[1, -1]), + int_col(name="Int", data=[1, -1]), + long_col(name="Long", data=[1, -1]), + long_col(name="NPLong", data=np.array([1, -1], 
dtype=np.int8)), + float_col(name="Float", data=[1.01, -1.01]), + double_col(name="Double", data=[1.01, -1.01]), + string_col(name="String", data=["foo", "bar"]), + datetime_col(name="Datetime", data=[1, -1]), + ] + yield dharrow.to_arrow(new_table(cols=cols)) + + def setUp(self) -> None: + super().setUp() + self.pa_table = next(self.gen_pa_table()) + self.test_table = dharrow.to_table(self.pa_table) + + def test_make_table_without_partition_schema(self): + backend = TestBackend(self.gen_pa_table(), pt_schema=self.pa_table.schema) + data_service = TableDataService(backend) + table = data_service.make_table(TableKeyImpl("test"), refreshing=False) + self.assertIsNotNone(table) + self.assertEqual(table.columns, self.test_table.columns) + + def test_make_static_table_with_partition_schema(self): + pc_schema = pa.schema( + [pa.field(name="Ticker", type=pa.string()), pa.field(name="Exchange", type=pa.string())]) + backend = TestBackend(self.gen_pa_table(), pt_schema=self.pa_table.schema, pc_schema=pc_schema) + data_service = TableDataService(backend) + table = data_service.make_table(TableKeyImpl("test"), refreshing=False) + self.assertIsNotNone(table) + self.assertTrue(table.columns[0].column_type == ColumnType.PARTITIONING) + self.assertTrue(table.columns[1].column_type == ColumnType.PARTITIONING) + self.assertEqual(table.columns[2:], self.test_table.columns[2:]) + self.assertEqual(table.size, 2) + self.assertEqual(backend.existing_partitions_called, 1) + self.assertEqual(backend.partition_size_called, 1) + + def test_make_live_table_with_partition_schema(self): + pc_schema = pa.schema( + [pa.field(name="Ticker", type=pa.string()), pa.field(name="Exchange", type=pa.string())]) + backend = TestBackend(self.gen_pa_table(), pt_schema=self.pa_table.schema, pc_schema=pc_schema) + data_service = TableDataService(backend) + table = data_service.make_table(TableKeyImpl("test"), refreshing=True) + self.assertIsNotNone(table) + self.assertTrue(table.columns[0].column_type == ColumnType.PARTITIONING) + self.assertTrue(table.columns[1].column_type == ColumnType.PARTITIONING) + self.assertEqual(table.columns[2:], self.test_table.columns[2:]) + + self.wait_ticking_table_update(table, 20, 5) + + self.assertGreaterEqual(table.size, 20) + self.assertEqual(backend.existing_partitions_called, 0) + self.assertEqual(backend.partition_size_called, 0) + + def test_make_live_table_with_partition_schema_ops(self): + pc_schema = pa.schema( + [pa.field(name="Ticker", type=pa.string()), pa.field(name="Exchange", type=pa.string())]) + backend = TestBackend(self.gen_pa_table(), pt_schema=self.pa_table.schema, pc_schema=pc_schema) + data_service = TableDataService(backend) + table = data_service.make_table(TableKeyImpl("test"), refreshing=True) + self.assertIsNotNone(table) + self.assertTrue(table.columns[0].column_type == ColumnType.PARTITIONING) + self.assertTrue(table.columns[1].column_type == ColumnType.PARTITIONING) + self.assertEqual(table.columns[2:], self.test_table.columns[2:]) + self.wait_ticking_table_update(table, 100, 5) + self.assertGreaterEqual(table.size, 100) + + t = table.select_distinct([c.name for c in table.columns]) + self.assertGreaterEqual(t.size, len(self.tickers)) + # t doesn't have the partitioning columns + self.assertEqual(t.columns, self.test_table.columns) + + def test_make_live_table_observe_subscription_cancellations(self): + pc_schema = pa.schema( + [pa.field(name="Ticker", type=pa.string()), pa.field(name="Exchange", type=pa.string())]) + backend = TestBackend(self.gen_pa_table(), 
pt_schema=self.pa_table.schema, pc_schema=pc_schema) + data_service = TableDataService(backend) + with liveness_scope(): + table = data_service.make_table(TableKeyImpl("test"), refreshing=True) + self.wait_ticking_table_update(table, 100, 5) + self.assertTrue(backend.sub_new_partition_cancelled) + self.assertFalse(all(backend.partitions_size_subscriptions.values())) + + def test_make_live_table_ensure_initial_partitions_exist(self): + pc_schema = pa.schema( + [pa.field(name="Ticker", type=pa.string()), pa.field(name="Exchange", type=pa.string())]) + backend = TestBackend(self.gen_pa_table(), pt_schema=self.pa_table.schema, pc_schema=pc_schema) + backend.subscriptions_enabled_for_test = False + data_service = TableDataService(backend) + table = data_service.make_table(TableKeyImpl("test"), refreshing=True) + table.coalesce() + # the initial partitions should be created + self.assertEqual(table.size, 2) + + def test_partition_sub_failure(self): + pc_schema = pa.schema( + [pa.field(name="Ticker", type=pa.string()), pa.field(name="Exchange", type=pa.string())]) + backend = TestBackend(self.gen_pa_table(), pt_schema=self.pa_table.schema, pc_schema=pc_schema) + data_service = TableDataService(backend) + backend.sub_new_partition_fail_test = True + table = data_service.make_table(TableKeyImpl("test"), refreshing=True) + with self.assertRaises(Exception) as cm: + # failure_cb will be called in the background thread after 2 PUG cycles + self.wait_ticking_table_update(table, 600, 2) + self.assertTrue(table.j_table.isFailed()) + + def test_partition_size_sub_failure(self): + pc_schema = pa.schema( + [pa.field(name="Ticker", type=pa.string()), pa.field(name="Exchange", type=pa.string())]) + backend = TestBackend(self.gen_pa_table(), pt_schema=self.pa_table.schema, pc_schema=pc_schema) + data_service = TableDataService(backend) + backend.sub_partition_size_fail_test = True + table = data_service.make_table(TableKeyImpl("test"), refreshing=True) + with self.assertRaises(Exception) as cm: + # failure_cb will be called in the background thread after 2 PUG cycles + self.wait_ticking_table_update(table, 600, 2) + + self.assertTrue(table.j_table.isFailed()) + + +if __name__ == '__main__': + unittest.main() diff --git a/server/src/main/java/io/deephaven/server/arrow/ArrowFlightUtil.java b/server/src/main/java/io/deephaven/server/arrow/ArrowFlightUtil.java index 275c0fd650e..56f225ea3c7 100644 --- a/server/src/main/java/io/deephaven/server/arrow/ArrowFlightUtil.java +++ b/server/src/main/java/io/deephaven/server/arrow/ArrowFlightUtil.java @@ -216,7 +216,7 @@ public void onNext(final InputStream request) { } if (mi.header.headerType() == MessageHeader.Schema) { - parseSchema(mi.header); + configureWithSchema(parseArrowSchema(mi)); return; }