Skip to content

Commit

Permalink
Add schema handling to remote task execution
Browse files Browse the repository at this point in the history
  • Loading branch information
vanathi-g authored and samdgupi committed Apr 30, 2024
1 parent aadb82b commit 3898779
Show file tree
Hide file tree
Showing 9 changed files with 168 additions and 29 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright © 2024 Cask Data, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package io.cdap.wrangler.api;

import io.cdap.cdap.api.data.schema.Schema;

import java.io.Serializable;
import java.util.List;

/**
* Response after executing directives remotely
* Please make sure all fields are registered with {@link io.cdap.wrangler.utils.KryoSerializer}
*/
public class RemoteDirectiveResponse implements Serializable {
private final List<Row> rows;
private final Schema outputSchema;

/**
* Only used by {@link io.cdap.wrangler.utils.KryoSerializer}
**/
private RemoteDirectiveResponse() {
this(null, null);
}

public RemoteDirectiveResponse(List<Row> rows, Schema outputSchema) {
this.rows = rows;
this.outputSchema = outputSchema;
}

public List<Row> getRows() {
return rows;
}

public Schema getOutputSchema() {
return outputSchema;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

/**
* TransientStoreKeys for storing Workspace schema in TransientStore
* NOTE: Please add any needed value in {@link io.cdap.wrangler.api.RemoteDirectiveResponse}
*/
public final class TransientStoreKeys {
public static final String INPUT_SCHEMA = "ws_input_schema";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
import com.esotericsoftware.kryo.Serializer;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import com.esotericsoftware.kryo.serializers.JavaSerializer;
import com.google.gson.Gson;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonNull;
import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.wrangler.api.RemoteDirectiveResponse;
import io.cdap.wrangler.api.Row;
import java.sql.Time;
import java.sql.Timestamp;
Expand All @@ -33,20 +36,24 @@
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Map;

/**
* A helper class with allows Serialization and Deserialization using Kryo
* We should register all schema classes present in {@link SchemaConverter}
* and {@link RemoteDirectiveResponse}
**/
public class RowSerializer {
public class KryoSerializer {

private final Kryo kryo;
private static final Gson GSON = new Gson();

public RowSerializer() {
public KryoSerializer() {
kryo = new Kryo();
// Register all classes from RemoteDirectiveResponse
kryo.register(RemoteDirectiveResponse.class);
// Schema does not have no-arg constructor but implements Serializable
kryo.register(Schema.class, new JavaSerializer());
// Register all classes from SchemaConverter
kryo.register(Row.class);
kryo.register(ArrayList.class);
Expand All @@ -56,7 +63,7 @@ public RowSerializer() {
kryo.register(Map.class);
kryo.register(JsonNull.class);
// JsonPrimitive does not have no-arg constructor hence we need a
// custom serializer
// custom serializer as it is not serializable by JavaSerializer
kryo.register(JsonPrimitive.class, new JsonSerializer());
kryo.register(JsonArray.class);
kryo.register(JsonObject.class);
Expand All @@ -67,16 +74,15 @@ public RowSerializer() {
kryo.register(Timestamp.class);
}

public byte[] fromRows(List<Row> rows) {
public byte[] fromRemoteDirectiveResponse(RemoteDirectiveResponse response) {
Output output = new Output(1024, -1);
kryo.writeClassAndObject(output, rows);
kryo.writeClassAndObject(output, response);
return output.getBuffer();
}

public List<Row> toRows(byte[] bytes) {
public RemoteDirectiveResponse toRemoteDirectiveResponse(byte[] bytes) {
Input input = new Input(bytes);
List<Row> result = (List<Row>) kryo.readClassAndObject(input);
return result;
return (RemoteDirectiveResponse) kryo.readClassAndObject(input);
}

static class JsonSerializer extends Serializer<JsonElement> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.google.gson.JsonPrimitive;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.cdap.api.data.schema.Schema.Field;
Expand Down Expand Up @@ -100,7 +99,7 @@ public Schema getSchema(Object value, String name) throws RecordConvertorExcepti
* @param name name of the field
* @param recordPrefix prefix to append at the beginning of a custom record
* @return the schema of this object
* NOTE: ANY NEWLY SUPPORTED DATATYPE SHOULD ALSO BE REGISTERED IN {@link RowSerializer}
* NOTE: ANY NEWLY SUPPORTED DATATYPE SHOULD ALSO BE REGISTERED IN {@link KryoSerializer}
*/
@Nullable
public Schema getSchema(Object value, String name, @Nullable String recordPrefix) throws RecordConvertorException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import com.google.gson.JsonParser;
import io.cdap.wrangler.TestingRig;
import io.cdap.wrangler.api.RecipePipeline;
import io.cdap.wrangler.api.RemoteDirectiveResponse;
import io.cdap.wrangler.api.Row;
import io.cdap.cdap.api.data.schema.Schema;
import org.junit.Assert;
import org.junit.Test;
import java.math.BigDecimal;
Expand All @@ -37,7 +39,7 @@
import java.util.Map;
import java.util.Set;

public class RowSerializerTest {
public class KryoSerializerTest {

private static final String[] TESTS = new String[]{
JsonTestData.BASIC,
Expand Down Expand Up @@ -67,8 +69,9 @@ public void testJsonTypes() throws Exception {
Row row = new Row("body", test);

List<Row> expectedRows = executor.execute(Lists.newArrayList(row));
byte[] serializedRows = new RowSerializer().fromRows(expectedRows);
List<Row> gotRows = new RowSerializer().toRows(serializedRows);
byte[] serializedRows = new KryoSerializer().fromRemoteDirectiveResponse(
new RemoteDirectiveResponse(expectedRows, null));
List<Row> gotRows = new KryoSerializer().toRemoteDirectiveResponse(serializedRows).getRows();
Assert.assertArrayEquals(expectedRows.toArray(), gotRows.toArray());
}
}
Expand All @@ -84,8 +87,9 @@ public void testLogicalTypes() throws Exception {
testRow.add("bigdecimal", new BigDecimal(new BigInteger("123456"), 5));
testRow.add("datetime", LocalDateTime.now());
List<Row> expectedRows = Collections.singletonList(testRow);
byte[] serializedRows = new RowSerializer().fromRows(expectedRows);
List<Row> gotRows = new RowSerializer().toRows(serializedRows);
byte[] serializedRows = new KryoSerializer().fromRemoteDirectiveResponse(
new RemoteDirectiveResponse(expectedRows, null));
List<Row> gotRows = new KryoSerializer().toRemoteDirectiveResponse(serializedRows).getRows();
Assert.assertArrayEquals(expectedRows.toArray(), gotRows.toArray());
}

Expand All @@ -110,8 +114,32 @@ public void testCollectionTypes() throws Exception {
testRow.add("map", map);

List<Row> expectedRows = Collections.singletonList(testRow);
byte[] serializedRows = new RowSerializer().fromRows(expectedRows);
List<Row> gotRows = new RowSerializer().toRows(serializedRows);
byte[] serializedRows = new KryoSerializer().fromRemoteDirectiveResponse(
new RemoteDirectiveResponse(expectedRows, null));
List<Row> gotRows = new KryoSerializer().toRemoteDirectiveResponse(serializedRows).getRows();
Assert.assertArrayEquals(expectedRows.toArray(), gotRows.toArray());
}

@Test
public void testWithSchema() throws Exception {
Row testRow = new Row();
testRow.add("id", 1);
testRow.add("name", "abc");
testRow.add("date", LocalDate.of(2018, 11, 11));
testRow.add("time", LocalTime.of(11, 11, 11));
testRow.add("timestamp", ZonedDateTime.of(2018, 11, 11, 11, 11, 11, 0, ZoneId.of("UTC")));
testRow.add("bigdecimal", new BigDecimal(new BigInteger("123456"), 5));
testRow.add("datetime", LocalDateTime.now());
List<Row> expectedRows = Collections.singletonList(testRow);

SchemaConverter converter = new SchemaConverter();
Schema expectedSchema = converter.toSchema("myrecord", expectedRows.get(0));

byte[] serializedRows = new KryoSerializer().fromRemoteDirectiveResponse(
new RemoteDirectiveResponse(expectedRows, expectedSchema));
RemoteDirectiveResponse response = new KryoSerializer().toRemoteDirectiveResponse(serializedRows);

Assert.assertArrayEquals(expectedRows.toArray(), response.getRows().toArray());
Assert.assertEquals (expectedSchema, response.getOutputSchema());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package io.cdap.wrangler.utils;

import com.google.common.base.Charsets;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.wrangler.api.RemoteDirectiveResponse;
import io.cdap.wrangler.api.Row;
import org.junit.Assert;
import org.junit.Test;
Expand Down Expand Up @@ -85,4 +87,20 @@ public void testLogicalTypeSerDe() throws Exception {
actualRows = objectSerDe.toObject(bytes);
Assert.assertEquals(expectedRows.size(), actualRows.size());
}
@Test
public void testRemoteDirectiveResponseSerDe() throws Exception {
List<Row> expectedRows = new ArrayList<>();
Row firstRow = new Row();
firstRow.add("id", 1);
expectedRows.add(firstRow);
Schema expectedSchema = Schema.recordOf(Schema.Field.of("id", Schema.of(Schema.Type.INT)));
RemoteDirectiveResponse expectedResponse = new RemoteDirectiveResponse(expectedRows, expectedSchema);
ObjectSerDe<RemoteDirectiveResponse> objectSerDe = new ObjectSerDe<>();

byte[] bytes = objectSerDe.toByteArray(expectedResponse);
RemoteDirectiveResponse actualResponse = objectSerDe.toObject(bytes);

Assert.assertEquals(expectedResponse.getRows().size(), actualResponse.getRows().size());
Assert.assertEquals(expectedResponse.getOutputSchema(), actualResponse.getOutputSchema());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package io.cdap.wrangler.service.directive;

import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.wrangler.parser.DirectiveClass;

import java.util.HashMap;
Expand All @@ -29,13 +30,15 @@ public class RemoteDirectiveRequest {
private final Map<String, DirectiveClass> systemDirectives;
private final String pluginNameSpace;
private final byte[] data;
private final Schema inputSchema;

RemoteDirectiveRequest(String recipe, Map<String, DirectiveClass> systemDirectives,
String pluginNameSpace, byte[] data) {
String pluginNameSpace, byte[] data, Schema inputSchema) {
this.recipe = recipe;
this.systemDirectives = new HashMap<>(systemDirectives);
this.pluginNameSpace = pluginNameSpace;
this.data = data;
this.inputSchema = inputSchema;
}

public String getRecipe() {
Expand All @@ -53,4 +56,8 @@ public byte[] getData() {
public String getPluginNameSpace() {
return pluginNameSpace;
}

public Schema getInputSchema() {
return inputSchema;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
package io.cdap.wrangler.service.directive;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.cdap.api.service.worker.RunnableTask;
import io.cdap.cdap.api.service.worker.RunnableTaskContext;
import io.cdap.cdap.api.service.worker.SystemAppTaskContext;
import io.cdap.cdap.features.Feature;
import io.cdap.cdap.internal.io.SchemaTypeAdapter;
import io.cdap.directives.aggregates.DefaultTransientStore;
import io.cdap.wrangler.api.Arguments;
import io.cdap.wrangler.api.CompileException;
Expand All @@ -30,7 +33,10 @@
import io.cdap.wrangler.api.ErrorRecordBase;
import io.cdap.wrangler.api.ExecutorContext;
import io.cdap.wrangler.api.RecipeException;
import io.cdap.wrangler.api.RemoteDirectiveResponse;
import io.cdap.wrangler.api.Row;
import io.cdap.wrangler.api.TransientStore;
import io.cdap.wrangler.api.TransientVariableScope;
import io.cdap.wrangler.api.parser.UsageDefinition;
import io.cdap.wrangler.executor.RecipePipelineExecutor;
import io.cdap.wrangler.expression.EL;
Expand All @@ -45,20 +51,24 @@
import io.cdap.wrangler.registry.UserDirectiveRegistry;
import io.cdap.wrangler.utils.ObjectSerDe;

import io.cdap.wrangler.utils.RowSerializer;
import io.cdap.wrangler.utils.KryoSerializer;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;

import static io.cdap.wrangler.schema.TransientStoreKeys.INPUT_SCHEMA;
import static io.cdap.wrangler.schema.TransientStoreKeys.OUTPUT_SCHEMA;

/**
* Task for remote execution of directives
*/
public class RemoteExecutionTask implements RunnableTask {

private static final Gson GSON = new Gson();

private static final Gson GSON = new GsonBuilder()
.registerTypeAdapter(Schema.class, new SchemaTypeAdapter())
.create();

@Override
public void run(RunnableTaskContext runnableTaskContext) throws Exception {
Expand Down Expand Up @@ -105,12 +115,18 @@ public void run(RunnableTaskContext runnableTaskContext) throws Exception {
ObjectSerDe<List<Row>> objectSerDe = new ObjectSerDe<>();
List<Row> rows = objectSerDe.toObject(directiveRequest.getData());

Schema inputSchema = directiveRequest.getInputSchema();
TransientStore transientStore = new DefaultTransientStore();
if (inputSchema != null) {
transientStore.set(TransientVariableScope.GLOBAL, INPUT_SCHEMA, inputSchema);
}

try (RecipePipelineExecutor executor = new RecipePipelineExecutor(() -> directives,
new ServicePipelineContext(
namespace,
ExecutorContext.Environment.SERVICE,
systemAppContext,
new DefaultTransientStore()))) {
transientStore))) {
rows = executor.execute(rows);
List<ErrorRecordBase> errors = executor.errors().stream()
.filter(ErrorRecordBase::isShownInWrangler)
Expand All @@ -123,12 +139,16 @@ public void run(RunnableTaskContext runnableTaskContext) throws Exception {
throw new BadRequestException(e.getMessage(), e);
}

Schema outputSchema = transientStore.get(OUTPUT_SCHEMA);
RemoteDirectiveResponse response = new RemoteDirectiveResponse(rows, outputSchema);
ObjectSerDe<RemoteDirectiveResponse> responseSerDe = new ObjectSerDe<>();

runnableTaskContext.setTerminateOnComplete(hasUDD.get() || EL.isUsed());

if (Feature.WRANGLER_KRYO_SERIALIZATION.isEnabled(systemAppContext)) {
runnableTaskContext.writeResult(new RowSerializer().fromRows(rows));
runnableTaskContext.writeResult(new KryoSerializer().fromRemoteDirectiveResponse(response));
} else {
runnableTaskContext.writeResult(objectSerDe.toByteArray(rows));
runnableTaskContext.writeResult(responseSerDe.toByteArray(response));
}
} catch (DirectiveParseException | ClassNotFoundException | CompileException e) {
throw new BadRequestException(e.getMessage(), e);
Expand Down
Loading

0 comments on commit 3898779

Please sign in to comment.