Skip to content

Commit

Permalink
Add schema handling to remote task execution
Browse files Browse the repository at this point in the history
  • Loading branch information
vanathi-g committed Apr 30, 2024
1 parent aadb82b commit fc79157
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright © 2024 Cask Data, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package io.cdap.wrangler.api;

import io.cdap.cdap.api.data.schema.Schema;

import java.io.Serializable;
import java.util.List;

/**
* Response after executing directives remotely
*/
public class RemoteDirectiveResponse implements Serializable {
private final List<Row> rows;
private final Schema outputSchema;

public RemoteDirectiveResponse(List<Row> rows, Schema outputSchema) {
this.rows = rows;
this.outputSchema = outputSchema;
}

public List<Row> getRows() {
return rows;
}

public Schema getOutputSchema() {
return outputSchema;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package io.cdap.wrangler.utils;

import com.google.common.base.Charsets;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.wrangler.api.RemoteDirectiveResponse;
import io.cdap.wrangler.api.Row;
import org.junit.Assert;
import org.junit.Test;
Expand Down Expand Up @@ -85,4 +87,20 @@ public void testLogicalTypeSerDe() throws Exception {
actualRows = objectSerDe.toObject(bytes);
Assert.assertEquals(expectedRows.size(), actualRows.size());
}
@Test
public void testRemoteDirectiveResponseSerDe() throws Exception {
List<Row> expectedRows = new ArrayList<>();
Row firstRow = new Row();
firstRow.add("id", 1);
expectedRows.add(firstRow);
Schema expectedSchema = Schema.recordOf(Schema.Field.of("id", Schema.of(Schema.Type.INT)));
RemoteDirectiveResponse expectedResponse = new RemoteDirectiveResponse(expectedRows, expectedSchema);
ObjectSerDe<RemoteDirectiveResponse> objectSerDe = new ObjectSerDe<>();

byte[] bytes = objectSerDe.toByteArray(expectedResponse);
RemoteDirectiveResponse actualResponse = objectSerDe.toObject(bytes);

Assert.assertEquals(expectedResponse.getRows().size(), actualResponse.getRows().size());
Assert.assertEquals(expectedResponse.getOutputSchema(), actualResponse.getOutputSchema());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package io.cdap.wrangler.service.directive;

import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.wrangler.parser.DirectiveClass;

import java.util.HashMap;
Expand All @@ -29,13 +30,15 @@ public class RemoteDirectiveRequest {
private final Map<String, DirectiveClass> systemDirectives;
private final String pluginNameSpace;
private final byte[] data;
private final Schema inputSchema;

RemoteDirectiveRequest(String recipe, Map<String, DirectiveClass> systemDirectives,
String pluginNameSpace, byte[] data) {
String pluginNameSpace, byte[] data, Schema inputSchema) {
this.recipe = recipe;
this.systemDirectives = new HashMap<>(systemDirectives);
this.pluginNameSpace = pluginNameSpace;
this.data = data;
this.inputSchema = inputSchema;
}

public String getRecipe() {
Expand All @@ -53,4 +56,8 @@ public byte[] getData() {
public String getPluginNameSpace() {
return pluginNameSpace;
}

public Schema getInputSchema() {
return inputSchema;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
package io.cdap.wrangler.service.directive;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.cdap.api.service.worker.RunnableTask;
import io.cdap.cdap.api.service.worker.RunnableTaskContext;
import io.cdap.cdap.api.service.worker.SystemAppTaskContext;
import io.cdap.cdap.features.Feature;
import io.cdap.cdap.internal.io.SchemaTypeAdapter;
import io.cdap.directives.aggregates.DefaultTransientStore;
import io.cdap.wrangler.api.Arguments;
import io.cdap.wrangler.api.CompileException;
Expand All @@ -30,7 +33,10 @@
import io.cdap.wrangler.api.ErrorRecordBase;
import io.cdap.wrangler.api.ExecutorContext;
import io.cdap.wrangler.api.RecipeException;
import io.cdap.wrangler.api.RemoteDirectiveResponse;
import io.cdap.wrangler.api.Row;
import io.cdap.wrangler.api.TransientStore;
import io.cdap.wrangler.api.TransientVariableScope;
import io.cdap.wrangler.api.parser.UsageDefinition;
import io.cdap.wrangler.executor.RecipePipelineExecutor;
import io.cdap.wrangler.expression.EL;
Expand All @@ -52,13 +58,17 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;

import static io.cdap.wrangler.schema.TransientStoreKeys.INPUT_SCHEMA;
import static io.cdap.wrangler.schema.TransientStoreKeys.OUTPUT_SCHEMA;

/**
* Task for remote execution of directives
*/
public class RemoteExecutionTask implements RunnableTask {

private static final Gson GSON = new Gson();

private static final Gson GSON = new GsonBuilder()
.registerTypeAdapter(Schema.class, new SchemaTypeAdapter())
.create();

@Override
public void run(RunnableTaskContext runnableTaskContext) throws Exception {
Expand Down Expand Up @@ -105,12 +115,18 @@ public void run(RunnableTaskContext runnableTaskContext) throws Exception {
ObjectSerDe<List<Row>> objectSerDe = new ObjectSerDe<>();
List<Row> rows = objectSerDe.toObject(directiveRequest.getData());

Schema inputSchema = directiveRequest.getInputSchema();
TransientStore transientStore = new DefaultTransientStore();
if (inputSchema != null) {
transientStore.set(TransientVariableScope.GLOBAL, INPUT_SCHEMA, inputSchema);
}

try (RecipePipelineExecutor executor = new RecipePipelineExecutor(() -> directives,
new ServicePipelineContext(
namespace,
ExecutorContext.Environment.SERVICE,
systemAppContext,
new DefaultTransientStore()))) {
transientStore))) {
rows = executor.execute(rows);
List<ErrorRecordBase> errors = executor.errors().stream()
.filter(ErrorRecordBase::isShownInWrangler)
Expand All @@ -123,12 +139,16 @@ public void run(RunnableTaskContext runnableTaskContext) throws Exception {
throw new BadRequestException(e.getMessage(), e);
}

Schema outputSchema = transientStore.get(OUTPUT_SCHEMA);
RemoteDirectiveResponse response = new RemoteDirectiveResponse(rows, outputSchema);
ObjectSerDe<RemoteDirectiveResponse> responseSerDe = new ObjectSerDe<>();

runnableTaskContext.setTerminateOnComplete(hasUDD.get() || EL.isUsed());

if (Feature.WRANGLER_KRYO_SERIALIZATION.isEnabled(systemAppContext)) {
runnableTaskContext.writeResult(new RowSerializer().fromRows(rows));
} else {
runnableTaskContext.writeResult(objectSerDe.toByteArray(rows));
runnableTaskContext.writeResult(responseSerDe.toByteArray(response));
}
} catch (DirectiveParseException | ClassNotFoundException | CompileException e) {
throw new BadRequestException(e.getMessage(), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import io.cdap.wrangler.api.DirectiveParseException;
import io.cdap.wrangler.api.GrammarMigrator;
import io.cdap.wrangler.api.RecipeException;
import io.cdap.wrangler.api.RemoteDirectiveResponse;
import io.cdap.wrangler.api.Row;
import io.cdap.wrangler.api.TransientVariableScope;
import io.cdap.wrangler.parser.ConfigDirectiveContext;
Expand Down Expand Up @@ -103,6 +104,9 @@
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;

import static io.cdap.wrangler.schema.TransientStoreKeys.INPUT_SCHEMA;
import static io.cdap.wrangler.schema.TransientStoreKeys.OUTPUT_SCHEMA;

/**
* V2 endpoints for workspace
*/
Expand Down Expand Up @@ -620,7 +624,8 @@ private <E extends Exception> List<Row> executeRemotely(String namespace, List<S
}

RemoteDirectiveRequest directiveRequest = new RemoteDirectiveRequest(recipe, systemDirectives,
namespace, detail.getSampleAsBytes());
namespace, detail.getSampleAsBytes(),
TRANSIENT_STORE.get(INPUT_SCHEMA));
RunnableTaskRequest runnableTaskRequest = RunnableTaskRequest.getBuilder(RemoteExecutionTask.class.getName())
.withParam(GSON.toJson(directiveRequest))
.withNamespace(namespace)
Expand All @@ -629,7 +634,11 @@ private <E extends Exception> List<Row> executeRemotely(String namespace, List<S
if (Feature.WRANGLER_KRYO_SERIALIZATION.isEnabled(getContext())) {
return new RowSerializer().toRows(bytes);
} else {
return new ObjectSerDe<List<Row>>().toObject(bytes);
RemoteDirectiveResponse response = new ObjectSerDe<RemoteDirectiveResponse>().toObject(bytes);
if (response.getOutputSchema() != null) {
TRANSIENT_STORE.set(TransientVariableScope.GLOBAL, OUTPUT_SCHEMA, response.getOutputSchema());
}
return response.getRows();
}
}

Expand Down

0 comments on commit fc79157

Please sign in to comment.