From 060e365b57a54b25f336e45036785351cd79bb15 Mon Sep 17 00:00:00 2001 From: Vanathi Ganeshraj Date: Mon, 29 Apr 2024 16:23:37 +0000 Subject: [PATCH] Add schema handling to remote task execution --- .../wrangler/api/RemoteDirectiveResponse.java | 50 +++++++++++++++++++ .../wrangler/schema/TransientStoreKeys.java | 1 + ...RowSerializer.java => KryoSerializer.java} | 24 +++++---- .../cdap/wrangler/utils/SchemaConverter.java | 3 +- ...lizerTest.java => KryoSerializerTest.java} | 43 +++++++++++++--- .../cdap/wrangler/utils/ObjectSerDeTest.java | 18 +++++++ .../directive/RemoteDirectiveRequest.java | 9 +++- .../directive/RemoteExecutionTask.java | 33 +++++++++--- .../service/directive/WorkspaceHandler.java | 18 +++++-- 9 files changed, 169 insertions(+), 30 deletions(-) create mode 100644 wrangler-api/src/main/java/io/cdap/wrangler/api/RemoteDirectiveResponse.java rename wrangler-core/src/main/java/io/cdap/wrangler/utils/{RowSerializer.java => KryoSerializer.java} (76%) rename wrangler-core/src/test/java/io/cdap/wrangler/utils/{RowSerializerTest.java => KryoSerializerTest.java} (64%) diff --git a/wrangler-api/src/main/java/io/cdap/wrangler/api/RemoteDirectiveResponse.java b/wrangler-api/src/main/java/io/cdap/wrangler/api/RemoteDirectiveResponse.java new file mode 100644 index 000000000..0627c28f5 --- /dev/null +++ b/wrangler-api/src/main/java/io/cdap/wrangler/api/RemoteDirectiveResponse.java @@ -0,0 +1,50 @@ +/* + * Copyright © 2024 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package io.cdap.wrangler.api; + +import io.cdap.cdap.api.data.schema.Schema; + +import java.io.Serializable; +import java.util.List; + +/** + * Response after executing directives remotely + * Please make sure all fields are registered with {@link io.cdap.wrangler.utils.KryoSerializer} + */ +public class RemoteDirectiveResponse implements Serializable { + private final List rows; + private final Schema outputSchema; + + /** + * Only used by {@link io.cdap.wrangler.utils.KryoSerializer} + **/ + private RemoteDirectiveResponse() { + this(null, null); + } + + public RemoteDirectiveResponse(List rows, Schema outputSchema) { + this.rows = rows; + this.outputSchema = outputSchema; + } + + public List getRows() { + return rows; + } + + public Schema getOutputSchema() { + return outputSchema; + } +} diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/schema/TransientStoreKeys.java b/wrangler-core/src/main/java/io/cdap/wrangler/schema/TransientStoreKeys.java index e35ef803f..da89fdb3c 100644 --- a/wrangler-core/src/main/java/io/cdap/wrangler/schema/TransientStoreKeys.java +++ b/wrangler-core/src/main/java/io/cdap/wrangler/schema/TransientStoreKeys.java @@ -18,6 +18,7 @@ /** * TransientStoreKeys for storing Workspace schema in TransientStore + * NOTE: Please add any needed value in {@link io.cdap.wrangler.api.RemoteDirectiveResponse} */ public final class TransientStoreKeys { public static final String INPUT_SCHEMA = "ws_input_schema"; diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/utils/RowSerializer.java b/wrangler-core/src/main/java/io/cdap/wrangler/utils/KryoSerializer.java similarity index 76% rename from wrangler-core/src/main/java/io/cdap/wrangler/utils/RowSerializer.java rename to wrangler-core/src/main/java/io/cdap/wrangler/utils/KryoSerializer.java index 6520412f9..0d13a7b08 100644 --- a/wrangler-core/src/main/java/io/cdap/wrangler/utils/RowSerializer.java +++ b/wrangler-core/src/main/java/io/cdap/wrangler/utils/KryoSerializer.java @@ -19,12 +19,15 @@ import com.esotericsoftware.kryo.Serializer; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; +import com.esotericsoftware.kryo.serializers.JavaSerializer; import com.google.gson.Gson; import com.google.gson.JsonArray; import com.google.gson.JsonElement; import com.google.gson.JsonNull; import com.google.gson.JsonObject; import com.google.gson.JsonPrimitive; +import io.cdap.cdap.api.data.schema.Schema; +import io.cdap.wrangler.api.RemoteDirectiveResponse; import io.cdap.wrangler.api.Row; import java.sql.Time; import java.sql.Timestamp; @@ -33,20 +36,24 @@ import java.time.ZonedDateTime; import java.util.ArrayList; import java.util.Date; -import java.util.List; import java.util.Map; /** * A helper class with allows Serialization and Deserialization using Kryo * We should register all schema classes present in {@link SchemaConverter} + * and {@link RemoteDirectiveResponse} **/ -public class RowSerializer { +public class KryoSerializer { private final Kryo kryo; private static final Gson GSON = new Gson(); - public RowSerializer() { + public KryoSerializer() { kryo = new Kryo(); + // Register all classes from RemoteDirectiveResponse + kryo.register(RemoteDirectiveResponse.class); + // Schema does not have no-arg constructor but implements Serializable + kryo.register(Schema.class, new JavaSerializer()); // Register all classes from SchemaConverter kryo.register(Row.class); kryo.register(ArrayList.class); @@ -56,7 +63,7 @@ public RowSerializer() { kryo.register(Map.class); kryo.register(JsonNull.class); // JsonPrimitive does not have no-arg constructor hence we need a - // custom serializer + // custom serializer as it is not serializable by JavaSerializer kryo.register(JsonPrimitive.class, new JsonSerializer()); kryo.register(JsonArray.class); kryo.register(JsonObject.class); @@ -67,16 +74,15 @@ public RowSerializer() { kryo.register(Timestamp.class); } - public byte[] fromRows(List rows) { + public byte[] fromRemoteDirectiveResponse(RemoteDirectiveResponse response) { Output output = new Output(1024, -1); - kryo.writeClassAndObject(output, rows); + kryo.writeClassAndObject(output, response); return output.getBuffer(); } - public List toRows(byte[] bytes) { + public RemoteDirectiveResponse toRemoteDirectiveResponse(byte[] bytes) { Input input = new Input(bytes); - List result = (List) kryo.readClassAndObject(input); - return result; + return (RemoteDirectiveResponse) kryo.readClassAndObject(input); } static class JsonSerializer extends Serializer { diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/utils/SchemaConverter.java b/wrangler-core/src/main/java/io/cdap/wrangler/utils/SchemaConverter.java index 2dd00ef16..a658c8115 100644 --- a/wrangler-core/src/main/java/io/cdap/wrangler/utils/SchemaConverter.java +++ b/wrangler-core/src/main/java/io/cdap/wrangler/utils/SchemaConverter.java @@ -19,7 +19,6 @@ import com.google.gson.JsonArray; import com.google.gson.JsonElement; import com.google.gson.JsonObject; -import com.google.gson.JsonParser; import com.google.gson.JsonPrimitive; import io.cdap.cdap.api.data.schema.Schema; import io.cdap.cdap.api.data.schema.Schema.Field; @@ -100,7 +99,7 @@ public Schema getSchema(Object value, String name) throws RecordConvertorExcepti * @param name name of the field * @param recordPrefix prefix to append at the beginning of a custom record * @return the schema of this object - * NOTE: ANY NEWLY SUPPORTED DATATYPE SHOULD ALSO BE REGISTERED IN {@link RowSerializer} + * NOTE: ANY NEWLY SUPPORTED DATATYPE SHOULD ALSO BE REGISTERED IN {@link KryoSerializer} */ @Nullable public Schema getSchema(Object value, String name, @Nullable String recordPrefix) throws RecordConvertorException { diff --git a/wrangler-core/src/test/java/io/cdap/wrangler/utils/RowSerializerTest.java b/wrangler-core/src/test/java/io/cdap/wrangler/utils/KryoSerializerTest.java similarity index 64% rename from wrangler-core/src/test/java/io/cdap/wrangler/utils/RowSerializerTest.java rename to wrangler-core/src/test/java/io/cdap/wrangler/utils/KryoSerializerTest.java index 97b45b28c..10477465e 100644 --- a/wrangler-core/src/test/java/io/cdap/wrangler/utils/RowSerializerTest.java +++ b/wrangler-core/src/test/java/io/cdap/wrangler/utils/KryoSerializerTest.java @@ -17,8 +17,10 @@ import com.google.common.collect.Lists; import com.google.gson.JsonParser; +import io.cdap.cdap.api.data.schema.Schema; import io.cdap.wrangler.TestingRig; import io.cdap.wrangler.api.RecipePipeline; +import io.cdap.wrangler.api.RemoteDirectiveResponse; import io.cdap.wrangler.api.Row; import org.junit.Assert; import org.junit.Test; @@ -37,7 +39,7 @@ import java.util.Map; import java.util.Set; -public class RowSerializerTest { +public class KryoSerializerTest { private static final String[] TESTS = new String[]{ JsonTestData.BASIC, @@ -67,8 +69,9 @@ public void testJsonTypes() throws Exception { Row row = new Row("body", test); List expectedRows = executor.execute(Lists.newArrayList(row)); - byte[] serializedRows = new RowSerializer().fromRows(expectedRows); - List gotRows = new RowSerializer().toRows(serializedRows); + byte[] serializedRows = new KryoSerializer().fromRemoteDirectiveResponse( + new RemoteDirectiveResponse(expectedRows, null)); + List gotRows = new KryoSerializer().toRemoteDirectiveResponse(serializedRows).getRows(); Assert.assertArrayEquals(expectedRows.toArray(), gotRows.toArray()); } } @@ -84,8 +87,9 @@ public void testLogicalTypes() throws Exception { testRow.add("bigdecimal", new BigDecimal(new BigInteger("123456"), 5)); testRow.add("datetime", LocalDateTime.now()); List expectedRows = Collections.singletonList(testRow); - byte[] serializedRows = new RowSerializer().fromRows(expectedRows); - List gotRows = new RowSerializer().toRows(serializedRows); + byte[] serializedRows = new KryoSerializer().fromRemoteDirectiveResponse( + new RemoteDirectiveResponse(expectedRows, null)); + List gotRows = new KryoSerializer().toRemoteDirectiveResponse(serializedRows).getRows(); Assert.assertArrayEquals(expectedRows.toArray(), gotRows.toArray()); } @@ -110,8 +114,33 @@ public void testCollectionTypes() throws Exception { testRow.add("map", map); List expectedRows = Collections.singletonList(testRow); - byte[] serializedRows = new RowSerializer().fromRows(expectedRows); - List gotRows = new RowSerializer().toRows(serializedRows); + byte[] serializedRows = new KryoSerializer().fromRemoteDirectiveResponse( + new RemoteDirectiveResponse(expectedRows, null)); + List gotRows = new KryoSerializer().toRemoteDirectiveResponse(serializedRows).getRows(); Assert.assertArrayEquals(expectedRows.toArray(), gotRows.toArray()); } + + @Test + public void testWithSchema() throws Exception { + Row testRow = new Row(); + testRow.add("id", 1); + testRow.add("name", "abc"); + testRow.add("date", LocalDate.of(2018, 11, 11)); + testRow.add("time", LocalTime.of(11, 11, 11)); + testRow.add("timestamp", ZonedDateTime.of(2018, 11, 11, 11, 11, 11, 0, ZoneId.of("UTC"))); + testRow.add("bigdecimal", new BigDecimal(new BigInteger("123456"), 5)); + testRow.add("datetime", LocalDateTime.now()); + List expectedRows = Collections.singletonList(testRow); + + SchemaConverter converter = new SchemaConverter(); + Schema expectedSchema = converter.toSchema("myrecord", expectedRows.get(0)); + + byte[] serializedRows = new KryoSerializer().fromRemoteDirectiveResponse( + new RemoteDirectiveResponse(expectedRows, expectedSchema)); + RemoteDirectiveResponse response = new KryoSerializer().toRemoteDirectiveResponse( + serializedRows); + + Assert.assertArrayEquals(expectedRows.toArray(), response.getRows().toArray()); + Assert.assertEquals(expectedSchema, response.getOutputSchema()); + } } diff --git a/wrangler-core/src/test/java/io/cdap/wrangler/utils/ObjectSerDeTest.java b/wrangler-core/src/test/java/io/cdap/wrangler/utils/ObjectSerDeTest.java index 689668444..f297b23e8 100644 --- a/wrangler-core/src/test/java/io/cdap/wrangler/utils/ObjectSerDeTest.java +++ b/wrangler-core/src/test/java/io/cdap/wrangler/utils/ObjectSerDeTest.java @@ -17,6 +17,8 @@ package io.cdap.wrangler.utils; import com.google.common.base.Charsets; +import io.cdap.cdap.api.data.schema.Schema; +import io.cdap.wrangler.api.RemoteDirectiveResponse; import io.cdap.wrangler.api.Row; import org.junit.Assert; import org.junit.Test; @@ -85,4 +87,20 @@ public void testLogicalTypeSerDe() throws Exception { actualRows = objectSerDe.toObject(bytes); Assert.assertEquals(expectedRows.size(), actualRows.size()); } + @Test + public void testRemoteDirectiveResponseSerDe() throws Exception { + List expectedRows = new ArrayList<>(); + Row firstRow = new Row(); + firstRow.add("id", 1); + expectedRows.add(firstRow); + Schema expectedSchema = Schema.recordOf(Schema.Field.of("id", Schema.of(Schema.Type.INT))); + RemoteDirectiveResponse expectedResponse = new RemoteDirectiveResponse(expectedRows, expectedSchema); + ObjectSerDe objectSerDe = new ObjectSerDe<>(); + + byte[] bytes = objectSerDe.toByteArray(expectedResponse); + RemoteDirectiveResponse actualResponse = objectSerDe.toObject(bytes); + + Assert.assertEquals(expectedResponse.getRows().size(), actualResponse.getRows().size()); + Assert.assertEquals(expectedResponse.getOutputSchema(), actualResponse.getOutputSchema()); + } } diff --git a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteDirectiveRequest.java b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteDirectiveRequest.java index 9b77f23f3..d3e6d959f 100644 --- a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteDirectiveRequest.java +++ b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteDirectiveRequest.java @@ -15,6 +15,7 @@ */ package io.cdap.wrangler.service.directive; +import io.cdap.cdap.api.data.schema.Schema; import io.cdap.wrangler.parser.DirectiveClass; import java.util.HashMap; @@ -29,13 +30,15 @@ public class RemoteDirectiveRequest { private final Map systemDirectives; private final String pluginNameSpace; private final byte[] data; + private final Schema inputSchema; RemoteDirectiveRequest(String recipe, Map systemDirectives, - String pluginNameSpace, byte[] data) { + String pluginNameSpace, byte[] data, Schema inputSchema) { this.recipe = recipe; this.systemDirectives = new HashMap<>(systemDirectives); this.pluginNameSpace = pluginNameSpace; this.data = data; + this.inputSchema = inputSchema; } public String getRecipe() { @@ -53,4 +56,8 @@ public byte[] getData() { public String getPluginNameSpace() { return pluginNameSpace; } + + public Schema getInputSchema() { + return inputSchema; + } } diff --git a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteExecutionTask.java b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteExecutionTask.java index 97ecd8eba..a8f4e22f7 100644 --- a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteExecutionTask.java +++ b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteExecutionTask.java @@ -16,10 +16,13 @@ package io.cdap.wrangler.service.directive; import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import io.cdap.cdap.api.data.schema.Schema; import io.cdap.cdap.api.service.worker.RunnableTask; import io.cdap.cdap.api.service.worker.RunnableTaskContext; import io.cdap.cdap.api.service.worker.SystemAppTaskContext; import io.cdap.cdap.features.Feature; +import io.cdap.cdap.internal.io.SchemaTypeAdapter; import io.cdap.directives.aggregates.DefaultTransientStore; import io.cdap.wrangler.api.Arguments; import io.cdap.wrangler.api.CompileException; @@ -30,7 +33,10 @@ import io.cdap.wrangler.api.ErrorRecordBase; import io.cdap.wrangler.api.ExecutorContext; import io.cdap.wrangler.api.RecipeException; +import io.cdap.wrangler.api.RemoteDirectiveResponse; import io.cdap.wrangler.api.Row; +import io.cdap.wrangler.api.TransientStore; +import io.cdap.wrangler.api.TransientVariableScope; import io.cdap.wrangler.api.parser.UsageDefinition; import io.cdap.wrangler.executor.RecipePipelineExecutor; import io.cdap.wrangler.expression.EL; @@ -43,22 +49,25 @@ import io.cdap.wrangler.proto.ErrorRecordsException; import io.cdap.wrangler.registry.DirectiveInfo; import io.cdap.wrangler.registry.UserDirectiveRegistry; +import io.cdap.wrangler.utils.KryoSerializer; import io.cdap.wrangler.utils.ObjectSerDe; - -import io.cdap.wrangler.utils.RowSerializer; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; +import static io.cdap.wrangler.schema.TransientStoreKeys.INPUT_SCHEMA; +import static io.cdap.wrangler.schema.TransientStoreKeys.OUTPUT_SCHEMA; + /** * Task for remote execution of directives */ public class RemoteExecutionTask implements RunnableTask { - private static final Gson GSON = new Gson(); - + private static final Gson GSON = new GsonBuilder() + .registerTypeAdapter(Schema.class, new SchemaTypeAdapter()) + .create(); @Override public void run(RunnableTaskContext runnableTaskContext) throws Exception { @@ -105,12 +114,18 @@ public void run(RunnableTaskContext runnableTaskContext) throws Exception { ObjectSerDe> objectSerDe = new ObjectSerDe<>(); List rows = objectSerDe.toObject(directiveRequest.getData()); + Schema inputSchema = directiveRequest.getInputSchema(); + TransientStore transientStore = new DefaultTransientStore(); + if (inputSchema != null) { + transientStore.set(TransientVariableScope.GLOBAL, INPUT_SCHEMA, inputSchema); + } + try (RecipePipelineExecutor executor = new RecipePipelineExecutor(() -> directives, new ServicePipelineContext( namespace, ExecutorContext.Environment.SERVICE, systemAppContext, - new DefaultTransientStore()))) { + transientStore))) { rows = executor.execute(rows); List errors = executor.errors().stream() .filter(ErrorRecordBase::isShownInWrangler) @@ -123,12 +138,16 @@ public void run(RunnableTaskContext runnableTaskContext) throws Exception { throw new BadRequestException(e.getMessage(), e); } + Schema outputSchema = transientStore.get(OUTPUT_SCHEMA); + RemoteDirectiveResponse response = new RemoteDirectiveResponse(rows, outputSchema); + ObjectSerDe responseSerDe = new ObjectSerDe<>(); + runnableTaskContext.setTerminateOnComplete(hasUDD.get() || EL.isUsed()); if (Feature.WRANGLER_KRYO_SERIALIZATION.isEnabled(systemAppContext)) { - runnableTaskContext.writeResult(new RowSerializer().fromRows(rows)); + runnableTaskContext.writeResult(new KryoSerializer().fromRemoteDirectiveResponse(response)); } else { - runnableTaskContext.writeResult(objectSerDe.toByteArray(rows)); + runnableTaskContext.writeResult(responseSerDe.toByteArray(response)); } } catch (DirectiveParseException | ClassNotFoundException | CompileException e) { throw new BadRequestException(e.getMessage(), e); diff --git a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java index 6d2667df9..43ffc3664 100644 --- a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java +++ b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java @@ -49,6 +49,7 @@ import io.cdap.wrangler.api.DirectiveParseException; import io.cdap.wrangler.api.GrammarMigrator; import io.cdap.wrangler.api.RecipeException; +import io.cdap.wrangler.api.RemoteDirectiveResponse; import io.cdap.wrangler.api.Row; import io.cdap.wrangler.api.TransientVariableScope; import io.cdap.wrangler.parser.ConfigDirectiveContext; @@ -78,9 +79,9 @@ import io.cdap.wrangler.schema.TransientStoreKeys; import io.cdap.wrangler.store.recipe.RecipeStore; import io.cdap.wrangler.store.workspace.WorkspaceStore; +import io.cdap.wrangler.utils.KryoSerializer; import io.cdap.wrangler.utils.ObjectSerDe; import io.cdap.wrangler.utils.RowHelper; -import io.cdap.wrangler.utils.RowSerializer; import io.cdap.wrangler.utils.SchemaConverter; import io.cdap.wrangler.utils.StructuredToRowTransformer; import org.apache.commons.lang3.StringEscapeUtils; @@ -103,6 +104,9 @@ import javax.ws.rs.Path; import javax.ws.rs.PathParam; +import static io.cdap.wrangler.schema.TransientStoreKeys.INPUT_SCHEMA; +import static io.cdap.wrangler.schema.TransientStoreKeys.OUTPUT_SCHEMA; + /** * V2 endpoints for workspace */ @@ -620,17 +624,23 @@ private List executeRemotely(String namespace, List>().toObject(bytes); + response = new ObjectSerDe().toObject(bytes); + } + if (response.getOutputSchema() != null) { + TRANSIENT_STORE.set(TransientVariableScope.GLOBAL, OUTPUT_SCHEMA, response.getOutputSchema()); } + return response.getRows(); } private List getSample(SampleResponse sampleResponse) {