diff --git a/wrangler-api/src/main/java/io/cdap/wrangler/api/RemoteDirectiveResponse.java b/wrangler-api/src/main/java/io/cdap/wrangler/api/RemoteDirectiveResponse.java new file mode 100644 index 000000000..57670ac4c --- /dev/null +++ b/wrangler-api/src/main/java/io/cdap/wrangler/api/RemoteDirectiveResponse.java @@ -0,0 +1,42 @@ +/* + * Copyright © 2024 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package io.cdap.wrangler.api; + +import io.cdap.cdap.api.data.schema.Schema; + +import java.io.Serializable; +import java.util.List; + +/** + * Response after executing directives remotely: the resulting rows plus the output schema. + */ +public class RemoteDirectiveResponse implements Serializable { + private final List rows; + private final Schema outputSchema; + + public RemoteDirectiveResponse(List rows, Schema outputSchema) { /* both arguments are stored as given: no copy, no null check */ + this.rows = rows; + this.outputSchema = outputSchema; + } + + public List getRows() { /* returns the same list instance that was passed to the constructor */ + return rows; + } + + public Schema getOutputSchema() { /* may be null, since the constructor accepts null */ + return outputSchema; + } +} diff --git a/wrangler-core/src/test/java/io/cdap/wrangler/utils/ObjectSerDeTest.java b/wrangler-core/src/test/java/io/cdap/wrangler/utils/ObjectSerDeTest.java index 689668444..f297b23e8 100644 --- a/wrangler-core/src/test/java/io/cdap/wrangler/utils/ObjectSerDeTest.java +++ b/wrangler-core/src/test/java/io/cdap/wrangler/utils/ObjectSerDeTest.java @@ -17,6 +17,8 @@ package io.cdap.wrangler.utils; import com.google.common.base.Charsets; +import io.cdap.cdap.api.data.schema.Schema; 
+import io.cdap.wrangler.api.RemoteDirectiveResponse; import io.cdap.wrangler.api.Row; import org.junit.Assert; import org.junit.Test; @@ -85,4 +87,20 @@ public void testLogicalTypeSerDe() throws Exception { actualRows = objectSerDe.toObject(bytes); Assert.assertEquals(expectedRows.size(), actualRows.size()); } + @Test + public void testRemoteDirectiveResponseSerDe() throws Exception { /* round-trips a one-row response with a single INT field through ObjectSerDe */ + List expectedRows = new ArrayList<>(); + Row firstRow = new Row(); + firstRow.add("id", 1); + expectedRows.add(firstRow); + Schema expectedSchema = Schema.recordOf(Schema.Field.of("id", Schema.of(Schema.Type.INT))); + RemoteDirectiveResponse expectedResponse = new RemoteDirectiveResponse(expectedRows, expectedSchema); + ObjectSerDe objectSerDe = new ObjectSerDe<>(); + + byte[] bytes = objectSerDe.toByteArray(expectedResponse); /* serialize, then read the same bytes back */ + RemoteDirectiveResponse actualResponse = objectSerDe.toObject(bytes); + + Assert.assertEquals(expectedResponse.getRows().size(), actualResponse.getRows().size()); /* only the row count is compared, not the row contents */ + Assert.assertEquals(expectedResponse.getOutputSchema(), actualResponse.getOutputSchema()); + } } diff --git a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteDirectiveRequest.java b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteDirectiveRequest.java index 9b77f23f3..d3e6d959f 100644 --- a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteDirectiveRequest.java +++ b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteDirectiveRequest.java @@ -15,6 +15,7 @@ */ package io.cdap.wrangler.service.directive; +import io.cdap.cdap.api.data.schema.Schema; import io.cdap.wrangler.parser.DirectiveClass; import java.util.HashMap; @@ -29,13 +30,15 @@ public class RemoteDirectiveRequest { private final Map systemDirectives; private final String pluginNameSpace; private final byte[] data; + private final Schema inputSchema; /* set from the constructor argument and exposed through getInputSchema() */ RemoteDirectiveRequest(String recipe, Map systemDirectives, - String pluginNameSpace, byte[] data) 
{ + String pluginNameSpace, byte[] data, Schema inputSchema) { this.recipe = recipe; this.systemDirectives = new HashMap<>(systemDirectives); this.pluginNameSpace = pluginNameSpace; this.data = data; + this.inputSchema = inputSchema; /* stored without a null check; RemoteExecutionTask tests it for null before use */ } public String getRecipe() { @@ -53,4 +56,8 @@ public byte[] getData() { public String getPluginNameSpace() { return pluginNameSpace; } + + public Schema getInputSchema() { /* returns the constructor argument unchanged, so it may be null */ + return inputSchema; + } } diff --git a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteExecutionTask.java b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteExecutionTask.java index 97ecd8eba..27c6dac60 100644 --- a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteExecutionTask.java +++ b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteExecutionTask.java @@ -16,10 +16,13 @@ package io.cdap.wrangler.service.directive; import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import io.cdap.cdap.api.data.schema.Schema; import io.cdap.cdap.api.service.worker.RunnableTask; import io.cdap.cdap.api.service.worker.RunnableTaskContext; import io.cdap.cdap.api.service.worker.SystemAppTaskContext; import io.cdap.cdap.features.Feature; +import io.cdap.cdap.internal.io.SchemaTypeAdapter; import io.cdap.directives.aggregates.DefaultTransientStore; import io.cdap.wrangler.api.Arguments; import io.cdap.wrangler.api.CompileException; @@ -30,7 +33,10 @@ import io.cdap.wrangler.api.ErrorRecordBase; import io.cdap.wrangler.api.ExecutorContext; import io.cdap.wrangler.api.RecipeException; +import io.cdap.wrangler.api.RemoteDirectiveResponse; import io.cdap.wrangler.api.Row; +import io.cdap.wrangler.api.TransientStore; +import io.cdap.wrangler.api.TransientVariableScope; import io.cdap.wrangler.api.parser.UsageDefinition; import io.cdap.wrangler.executor.RecipePipelineExecutor; import io.cdap.wrangler.expression.EL; @@ -52,13 +58,17 @@ import 
java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; +import static io.cdap.wrangler.schema.TransientStoreKeys.INPUT_SCHEMA; +import static io.cdap.wrangler.schema.TransientStoreKeys.OUTPUT_SCHEMA; + /** * Task for remote execution of directives */ public class RemoteExecutionTask implements RunnableTask { - private static final Gson GSON = new Gson(); - + private static final Gson GSON = new GsonBuilder() /* presumably needed because the request now carries a Schema field (inputSchema) - TODO confirm */ + .registerTypeAdapter(Schema.class, new SchemaTypeAdapter()) + .create(); @Override public void run(RunnableTaskContext runnableTaskContext) throws Exception { @@ -105,12 +115,18 @@ public void run(RunnableTaskContext runnableTaskContext) throws Exception { ObjectSerDe> objectSerDe = new ObjectSerDe<>(); List rows = objectSerDe.toObject(directiveRequest.getData()); + Schema inputSchema = directiveRequest.getInputSchema(); + TransientStore transientStore = new DefaultTransientStore(); /* held in a local so OUTPUT_SCHEMA can be read back from it after execution */ + if (inputSchema != null) { /* the store is only seeded when the request actually carried a schema */ + transientStore.set(TransientVariableScope.GLOBAL, INPUT_SCHEMA, inputSchema); + } + try (RecipePipelineExecutor executor = new RecipePipelineExecutor(() -> directives, new ServicePipelineContext( namespace, ExecutorContext.Environment.SERVICE, systemAppContext, - new DefaultTransientStore()))) { + transientStore))) { rows = executor.execute(rows); List errors = executor.errors().stream() .filter(ErrorRecordBase::isShownInWrangler) @@ -123,12 +139,16 @@ public void run(RunnableTaskContext runnableTaskContext) throws Exception { throw new BadRequestException(e.getMessage(), e); } + Schema outputSchema = transientStore.get(OUTPUT_SCHEMA); /* read from the same store instance that was handed to the executor context */ + RemoteDirectiveResponse response = new RemoteDirectiveResponse(rows, outputSchema); + ObjectSerDe responseSerDe = new ObjectSerDe<>(); /* NOTE(review): response and responseSerDe are used only by the non-Kryo branch; the Kryo branch writes rows alone, so outputSchema is not part of that result - confirm this is intended */ + runnableTaskContext.setTerminateOnComplete(hasUDD.get() || EL.isUsed()); if (Feature.WRANGLER_KRYO_SERIALIZATION.isEnabled(systemAppContext)) { runnableTaskContext.writeResult(new RowSerializer().fromRows(rows)); } else { - 
runnableTaskContext.writeResult(objectSerDe.toByteArray(rows)); + runnableTaskContext.writeResult(responseSerDe.toByteArray(response)); /* non-Kryo result is now the whole response (rows plus output schema), not the bare row list */ } } catch (DirectiveParseException | ClassNotFoundException | CompileException e) { throw new BadRequestException(e.getMessage(), e); diff --git a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java index 6d2667df9..c8c427bc8 100644 --- a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java +++ b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java @@ -49,6 +49,7 @@ import io.cdap.wrangler.api.DirectiveParseException; import io.cdap.wrangler.api.GrammarMigrator; import io.cdap.wrangler.api.RecipeException; +import io.cdap.wrangler.api.RemoteDirectiveResponse; import io.cdap.wrangler.api.Row; import io.cdap.wrangler.api.TransientVariableScope; import io.cdap.wrangler.parser.ConfigDirectiveContext; @@ -103,6 +104,9 @@ import javax.ws.rs.Path; import javax.ws.rs.PathParam; +import static io.cdap.wrangler.schema.TransientStoreKeys.INPUT_SCHEMA; +import static io.cdap.wrangler.schema.TransientStoreKeys.OUTPUT_SCHEMA; + /** * V2 endpoints for workspace */ @@ -620,7 +624,8 @@ private List executeRemotely(String namespace, List List executeRemotely(String namespace, List>().toObject(bytes); + RemoteDirectiveResponse response = new ObjectSerDe().toObject(bytes); /* NOTE(review): assumes bytes hold an ObjectSerDe-encoded RemoteDirectiveResponse; the remote task writes RowSerializer output when Kryo is enabled - confirm that case is handled in the part of this method not visible here */ + if (response.getOutputSchema() != null) { /* publish the remotely computed schema to TRANSIENT_STORE; skipped when the task returned none */ + TRANSIENT_STORE.set(TransientVariableScope.GLOBAL, OUTPUT_SCHEMA, response.getOutputSchema()); + } + return response.getRows(); } }