diff --git a/wrangler-api/src/main/java/io/cdap/wrangler/api/Executor.java b/wrangler-api/src/main/java/io/cdap/wrangler/api/Executor.java
index abc726295..8c85319e4 100644
--- a/wrangler-api/src/main/java/io/cdap/wrangler/api/Executor.java
+++ b/wrangler-api/src/main/java/io/cdap/wrangler/api/Executor.java
@@ -16,9 +16,11 @@
package io.cdap.wrangler.api;
+import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.wrangler.api.annotations.PublicEvolving;
import java.io.Serializable;
+import javax.annotation.Nullable;
/**
 * An interface defining the wrangle Executor in the wrangling {@link RecipePipeline}.
@@ -80,5 +82,19 @@ O execute(I rows, ExecutorContext context)
* correct at this phase of invocation.
*/
void destroy();
-}
+ /**
+ * This method is used to get the updated schema of the data after the directive's transformation has been applied.
+ *
+ * @param schemaResolutionContext context containing necessary information for getting output schema
+ * @return output {@link Schema} of the transformed data
+   * @implNote By default, returns null and the schema is inferred from the data when necessary.
+   *
+   * For consistent handling, override for directives that perform column renames,
+   * column data type changes, or column additions with specific schemas.
+ */
+ @Nullable
+ default Schema getOutputSchema(SchemaResolutionContext schemaResolutionContext) {
+ // no op
+ return null;
+ }
+}
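
For reference, a directive that knows how it reshapes columns can override the new hook rather than relying on inference. A minimal sketch of such an override (not part of this change; the directive behavior and column names are hypothetical):

    // Hypothetical override in a rename-style directive: renames column "a" to "b"
    // and keeps every other field's schema unchanged.
    @Override
    @Nullable
    public Schema getOutputSchema(SchemaResolutionContext context) {
      Schema inputSchema = context.getInputSchema();
      List<Schema.Field> fields = new ArrayList<>();
      for (Schema.Field field : inputSchema.getFields()) {
        String name = "a".equals(field.getName()) ? "b" : field.getName();
        fields.add(Schema.Field.of(name, field.getSchema()));
      }
      return Schema.recordOf("output", fields);
    }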
diff --git a/wrangler-api/src/main/java/io/cdap/wrangler/api/SchemaResolutionContext.java b/wrangler-api/src/main/java/io/cdap/wrangler/api/SchemaResolutionContext.java
new file mode 100644
index 000000000..015f8bdc6
--- /dev/null
+++ b/wrangler-api/src/main/java/io/cdap/wrangler/api/SchemaResolutionContext.java
@@ -0,0 +1,29 @@
+/*
+ * Copyright © 2023 Cask Data, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package io.cdap.wrangler.api;
+
+import io.cdap.cdap.api.data.schema.Schema;
+
+/**
+ * Interface to pass contextual information related to getting or generating the output schema of an {@link Executor}.
+ */
+public interface SchemaResolutionContext {
+ /**
+ * @return {@link Schema} of the input data before transformation
+ */
+ Schema getInputSchema();
+}
diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java b/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java
index 7202ec90a..a41206e31 100644
--- a/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java
+++ b/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java
@@ -30,8 +30,12 @@
import io.cdap.wrangler.api.ReportErrorAndProceed;
import io.cdap.wrangler.api.Row;
import io.cdap.wrangler.api.TransientVariableScope;
+import io.cdap.wrangler.schema.DirectiveOutputSchemaGenerator;
+import io.cdap.wrangler.schema.DirectiveSchemaResolutionContext;
+import io.cdap.wrangler.schema.TransientStoreKeys;
import io.cdap.wrangler.utils.RecordConvertor;
import io.cdap.wrangler.utils.RecordConvertorException;
+import io.cdap.wrangler.utils.SchemaConverter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -48,6 +52,7 @@ public final class RecipePipelineExecutor implements RecipePipeline
+  private final SchemaConverter generator = new SchemaConverter();
   private List<Directive> directives;
@@ -103,6 +108,19 @@ public List<Row> execute(List<Row> rows) throws RecipeException {
List<Row> results = new ArrayList<>();
int i = 0;
int directiveIndex = 0;
+ // Initialize schema with input schema from TransientStore if running in service env (design-time) / testing env
+ boolean designTime = context != null && context.getEnvironment() != null &&
+ (context.getEnvironment().equals(ExecutorContext.Environment.SERVICE) ||
+ context.getEnvironment().equals(ExecutorContext.Environment.TESTING));
+ Schema inputSchema = designTime ? context.getTransientStore().get(TransientStoreKeys.INPUT_SCHEMA) : null;
+
+ List<DirectiveOutputSchemaGenerator> outputSchemaGenerators = new ArrayList<>();
+ if (designTime && inputSchema != null) {
+ for (Directive directive : directives) {
+ outputSchemaGenerators.add(new DirectiveOutputSchemaGenerator(directive, generator));
+ }
+ }
+
try {
collector.reset();
while (i < rows.size()) {
@@ -122,6 +140,9 @@ public List<Row> execute(List<Row> rows) throws RecipeException {
if (cumulativeRows.size() < 1) {
break;
}
+ if (designTime && inputSchema != null) {
+ outputSchemaGenerators.get(directiveIndex - 1).addNewOutputFields(cumulativeRows);
+ }
} catch (ReportErrorAndProceed e) {
messages.add(String.format("%s (ecode: %d)", e.getMessage(), e.getCode()));
collector
@@ -142,6 +163,11 @@ public List<Row> execute(List<Row> rows) throws RecipeException {
} catch (DirectiveExecutionException e) {
throw new RecipeException(e.getMessage(), e, i, directiveIndex);
}
+ // Schema generation
+ if (designTime && inputSchema != null) {
+ context.getTransientStore().set(TransientVariableScope.GLOBAL, TransientStoreKeys.OUTPUT_SCHEMA,
+ getOutputSchema(inputSchema, outputSchemaGenerators));
+ }
return results;
}
@@ -161,4 +187,17 @@ private List<Directive> getDirectives() throws RecipeException {
}
return directives;
}
+
+ private Schema getOutputSchema(Schema inputSchema, List<DirectiveOutputSchemaGenerator> outputSchemaGenerators)
+ throws RecipeException {
+ Schema schema = inputSchema;
+ for (DirectiveOutputSchemaGenerator outputSchemaGenerator : outputSchemaGenerators) {
+ try {
+ schema = outputSchemaGenerator.getDirectiveOutputSchema(new DirectiveSchemaResolutionContext(schema));
+ } catch (RecordConvertorException e) {
+ throw new RecipeException("Error while generating output schema for a directive: " + e, e);
+ }
+ }
+ return schema;
+ }
}
diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/schema/DirectiveOutputSchemaGenerator.java b/wrangler-core/src/main/java/io/cdap/wrangler/schema/DirectiveOutputSchemaGenerator.java
new file mode 100644
index 000000000..980780f04
--- /dev/null
+++ b/wrangler-core/src/main/java/io/cdap/wrangler/schema/DirectiveOutputSchemaGenerator.java
@@ -0,0 +1,119 @@
+/*
+ * Copyright © 2023 Cask Data, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package io.cdap.wrangler.schema;
+
+import io.cdap.cdap.api.data.schema.Schema;
+import io.cdap.wrangler.api.Directive;
+import io.cdap.wrangler.api.Pair;
+import io.cdap.wrangler.api.Row;
+import io.cdap.wrangler.api.SchemaResolutionContext;
+import io.cdap.wrangler.utils.RecordConvertorException;
+import io.cdap.wrangler.utils.SchemaConverter;
+
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import javax.annotation.Nullable;
+
+/**
+ * This class can be used to generate the output schema for the output data of a directive. It maintains a map of
+ * output fields present across all output rows after applying a directive. This map is used to generate the schema
+ * if the directive does not return a custom output schema.
+ */
+public class DirectiveOutputSchemaGenerator {
+ private final SchemaConverter schemaGenerator;
+ private final Map<String, Object> outputFieldMap;
+ private final Directive directive;
+
+ public DirectiveOutputSchemaGenerator(Directive directive, SchemaConverter schemaGenerator) {
+ this.directive = directive;
+ this.schemaGenerator = schemaGenerator;
+ outputFieldMap = new LinkedHashMap<>();
+ }
+
+ /**
+ * Method to add new fields from the given output to the map of fieldName --> value maintained for schema generation.
+ * A value is added to the map only if it is absent, or if the existing value is null and the given value is non-null.
+ * @param output list of output {@link Row}s after applying the directive.
+ */
+ public void addNewOutputFields(List<Row> output) {
+ for (Row row : output) {
+ for (Pair<String, Object> field : row.getFields()) {
+ String fieldName = field.getFirst();
+ Object fieldValue = field.getSecond();
+ if (outputFieldMap.containsKey(fieldName)) {
+ // If existing value is null, override with this non-null value
+ if (fieldValue != null && outputFieldMap.get(fieldName) == null) {
+ outputFieldMap.put(fieldName, fieldValue);
+ }
+ } else {
+ outputFieldMap.put(fieldName, fieldValue);
+ }
+ }
+ }
+ }
+
+ /**
+ * Method to get the output schema of the directive. Returns a schema generated from the maintained map of fields
+ * only if the directive does not return a custom output schema.
+ * @param context {@link SchemaResolutionContext} carrying the input {@link Schema} of the data before applying
+ * the directive
+ * @return {@link Schema} corresponding to the output data
+ */
+ public Schema getDirectiveOutputSchema(SchemaResolutionContext context) throws RecordConvertorException {
+ Schema directiveOutputSchema = directive.getOutputSchema(context);
+ return directiveOutputSchema != null ? directiveOutputSchema :
+ generateDirectiveOutputSchema(context.getInputSchema());
+ }
+
+ // Given the schema from the previous step and the output of the current directive, generates the directive's
+ // output schema.
+ private Schema generateDirectiveOutputSchema(Schema inputSchema)
+ throws RecordConvertorException {
+ List<Schema.Field> outputFields = new ArrayList<>();
+ for (Map.Entry<String, Object> field : outputFieldMap.entrySet()) {
+ String fieldName = field.getKey();
+ Object fieldValue = field.getValue();
+
+ Schema existing = inputSchema.getField(fieldName) != null ? inputSchema.getField(fieldName).getSchema() : null;
+ Schema generated = fieldValue != null && !isValidSchemaForValue(existing, fieldValue) ?
+ schemaGenerator.getSchema(fieldValue, fieldName) : null;
+
+ if (generated != null) {
+ outputFields.add(Schema.Field.of(fieldName, generated));
+ } else if (existing != null) {
+ if (!existing.isNullable()) {
+ existing = Schema.nullableOf(existing);
+ }
+ outputFields.add(Schema.Field.of(fieldName, existing));
+ } else {
+ outputFields.add(Schema.Field.of(fieldName, Schema.of(Schema.Type.NULL)));
+ }
+ }
+ return Schema.recordOf("output", outputFields);
+ }
+
+ // Checks whether the provided input schema is of a valid type for the given object value
+ private boolean isValidSchemaForValue(@Nullable Schema schema, Object value) throws RecordConvertorException {
+ if (schema == null) {
+ return false;
+ }
+ Schema generated = schemaGenerator.getSchema(value, "temp_field_name");
+ generated = generated.isNullable() ? generated.getNonNullable() : generated;
+ schema = schema.isNullable() ? schema.getNonNullable() : schema;
+ return generated.getLogicalType() == schema.getLogicalType() && generated.getType() == schema.getType();
+ }
+}
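
Taken together, the executor drives each generator in two phases: it first accumulates the fields observed across the directive's output rows, then resolves the schema against the previous step's schema. A condensed sketch of that flow (variable names illustrative):

    DirectiveOutputSchemaGenerator generator =
        new DirectiveOutputSchemaGenerator(directive, new SchemaConverter());
    generator.addNewOutputFields(outputRows); // rows produced by this directive
    Schema outputSchema = generator.getDirectiveOutputSchema(
        new DirectiveSchemaResolutionContext(previousStepSchema));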
diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/schema/DirectiveSchemaResolutionContext.java b/wrangler-core/src/main/java/io/cdap/wrangler/schema/DirectiveSchemaResolutionContext.java
new file mode 100644
index 000000000..9c4c702fb
--- /dev/null
+++ b/wrangler-core/src/main/java/io/cdap/wrangler/schema/DirectiveSchemaResolutionContext.java
@@ -0,0 +1,36 @@
+/*
+ * Copyright © 2023 Cask Data, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package io.cdap.wrangler.schema;
+
+import io.cdap.cdap.api.data.schema.Schema;
+import io.cdap.wrangler.api.Directive;
+import io.cdap.wrangler.api.SchemaResolutionContext;
+
+/**
+ * Context to pass information related to getting or generating the output schema of a {@link Directive}
+ */
+public class DirectiveSchemaResolutionContext implements SchemaResolutionContext {
+ private final Schema inputSchema;
+ public DirectiveSchemaResolutionContext(Schema inputSchema) {
+ this.inputSchema = inputSchema;
+ }
+
+ @Override
+ public Schema getInputSchema() {
+ return inputSchema;
+ }
+}
diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/schema/TransientStoreKeys.java b/wrangler-core/src/main/java/io/cdap/wrangler/schema/TransientStoreKeys.java
new file mode 100644
index 000000000..e35ef803f
--- /dev/null
+++ b/wrangler-core/src/main/java/io/cdap/wrangler/schema/TransientStoreKeys.java
@@ -0,0 +1,29 @@
+/*
+ * Copyright © 2023 Cask Data, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package io.cdap.wrangler.schema;
+
+/**
+ * Keys for storing workspace input and output schemas in the TransientStore.
+ */
+public final class TransientStoreKeys {
+ public static final String INPUT_SCHEMA = "ws_input_schema";
+ public static final String OUTPUT_SCHEMA = "ws_output_schema";
+
+ private TransientStoreKeys() {
+ throw new AssertionError("Cannot instantiate a static utility class.");
+ }
+}
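
As the tests below exercise, a caller seeds the input schema under INPUT_SCHEMA before running a recipe and reads the generated schema back under OUTPUT_SCHEMA afterwards. A minimal sketch against a TransientStore (the context and schema variables are illustrative):

    TransientStore store = context.getTransientStore();
    store.set(TransientVariableScope.GLOBAL, TransientStoreKeys.INPUT_SCHEMA, inputSchema);
    // ... execute the recipe via RecipePipelineExecutor ...
    Schema outputSchema = store.get(TransientStoreKeys.OUTPUT_SCHEMA);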
diff --git a/wrangler-core/src/test/java/io/cdap/wrangler/TestingPipelineContext.java b/wrangler-core/src/test/java/io/cdap/wrangler/TestingPipelineContext.java
index 53fd7c465..1bbfc4c33 100644
--- a/wrangler-core/src/test/java/io/cdap/wrangler/TestingPipelineContext.java
+++ b/wrangler-core/src/test/java/io/cdap/wrangler/TestingPipelineContext.java
@@ -24,46 +24,27 @@
import io.cdap.wrangler.api.ExecutorContext;
import io.cdap.wrangler.api.TransientStore;
import io.cdap.wrangler.proto.Contexts;
-import org.apache.commons.collections.map.HashedMap;
import java.net.URL;
import java.util.Collections;
+import java.util.HashMap;
import java.util.Map;
/**
* This class {@link TestingPipelineContext} is a runtime context that is provided for each
* {@link Executor} execution.
*/
-class TestingPipelineContext implements ExecutorContext {
- private StageMetrics metrics;
- private String name;
- private TransientStore store;
- private Map properties;
-
- TestingPipelineContext() {
- properties = new HashedMap();
+public class TestingPipelineContext implements ExecutorContext {
+ private final StageMetrics metrics;
+ private final String name;
+ private final TransientStore store;
+ private final Map<String, String> properties;
+
+ public TestingPipelineContext() {
+ name = "testing";
+ properties = new HashMap<>();
store = new DefaultTransientStore();
- }
-
- /**
- * @return Environment this context is prepared for.
- */
- @Override
- public Environment getEnvironment() {
- return Environment.TESTING;
- }
-
- @Override
- public String getNamespace() {
- return Contexts.SYSTEM;
- }
-
- /**
- * @return Measurements context.
- */
- @Override
- public StageMetrics getMetrics() {
- return new StageMetrics() {
+ metrics = new StageMetrics() {
@Override
public void count(String s, int i) {
@@ -96,12 +77,33 @@ public Map<String, String> getTags() {
};
}
+ /**
+ * @return Environment this context is prepared for.
+ */
+ @Override
+ public Environment getEnvironment() {
+ return Environment.TESTING;
+ }
+
+ @Override
+ public String getNamespace() {
+ return Contexts.SYSTEM;
+ }
+
+ /**
+ * @return Measurements context.
+ */
+ @Override
+ public StageMetrics getMetrics() {
+ return metrics;
+ }
+
/**
* @return Context name.
*/
@Override
public String getContextName() {
- return "testing";
+ return name;
}
/**
diff --git a/wrangler-core/src/test/java/io/cdap/wrangler/TestingRig.java b/wrangler-core/src/test/java/io/cdap/wrangler/TestingRig.java
index 9a5712b5d..10a6da4e2 100644
--- a/wrangler-core/src/test/java/io/cdap/wrangler/TestingRig.java
+++ b/wrangler-core/src/test/java/io/cdap/wrangler/TestingRig.java
@@ -60,7 +60,7 @@ private TestingRig() {
*/
public static List<Row> execute(String[] recipe, List<Row> rows)
throws RecipeException, DirectiveParseException, DirectiveLoadException {
- return execute(recipe, rows, null);
+ return execute(recipe, rows, new TestingPipelineContext());
}
public static List<Row> execute(String[] recipe, List<Row> rows, ExecutorContext context)
@@ -83,7 +83,7 @@ public static List<Row> execute(String[] recipe, List<Row> rows, ExecutorContext
*/
public static Pair<List<Row>, List<Row>> executeWithErrors(String[] recipe, List<Row> rows)
throws RecipeException, DirectiveParseException, DirectiveLoadException, DirectiveNotFoundException {
- return executeWithErrors(recipe, rows, null);
+ return executeWithErrors(recipe, rows, new TestingPipelineContext());
}
public static Pair<List<Row>, List<Row>> executeWithErrors(String[] recipe, List<Row> rows, ExecutorContext context)
diff --git a/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java b/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java
index 53f1b9e78..8b858d50c 100644
--- a/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java
+++ b/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java
@@ -18,13 +18,21 @@
import io.cdap.cdap.api.data.format.StructuredRecord;
import io.cdap.cdap.api.data.schema.Schema;
+import io.cdap.wrangler.TestingPipelineContext;
import io.cdap.wrangler.TestingRig;
+import io.cdap.wrangler.api.ExecutorContext;
import io.cdap.wrangler.api.RecipePipeline;
import io.cdap.wrangler.api.Row;
+import io.cdap.wrangler.api.TransientVariableScope;
+import io.cdap.wrangler.schema.TransientStoreKeys;
import org.junit.Assert;
import org.junit.Test;
+import java.math.BigDecimal;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
/**
* Tests {@link RecipePipelineExecutor}.
@@ -54,8 +62,8 @@ public void testPipeline() throws Exception {
RecipePipeline pipeline = TestingRig.execute(commands);
- Row row = new Row("__col", new String("a,b,c,d,e,f,1.0"));
- StructuredRecord record = (StructuredRecord) pipeline.execute(Arrays.asList(row), schema).get(0);
+ Row row = new Row("__col", "a,b,c,d,e,f,1.0");
+ StructuredRecord record = (StructuredRecord) pipeline.execute(Collections.singletonList(row), schema).get(0);
// Validate the {@link StructuredRecord}
Assert.assertEquals("a", record.get("first"));
@@ -86,8 +94,8 @@ public void testPipelineWithMoreSimpleTypes() throws Exception {
);
RecipePipeline pipeline = TestingRig.execute(commands);
- Row row = new Row("__col", new String("Larry,Perez,lperezqt@umn.edu,1481666448,186.66"));
- StructuredRecord record = (StructuredRecord) pipeline.execute(Arrays.asList(row), schema).get(0);
+ Row row = new Row("__col", "Larry,Perez,lperezqt@umn.edu,1481666448,186.66");
+ StructuredRecord record = (StructuredRecord) pipeline.execute(Collections.singletonList(row), schema).get(0);
// Validate the {@link StructuredRecord}
Assert.assertEquals("Larry", record.get("first"));
@@ -96,4 +104,99 @@ public void testPipelineWithMoreSimpleTypes() throws Exception {
Assert.assertEquals(1481666448L, record.get("timestamp").longValue());
Assert.assertEquals(186.66f, record.get("weight"), 0.0001f);
}
+
+ @Test
+ public void testOutputSchemaGeneration() throws Exception {
+ String[] commands = new String[]{
+ "parse-as-csv :body ,",
+ "drop :body",
+ "set-headers :decimal_col,:name,:timestamp,:weight,:date",
+ "set-type :timestamp double",
+ };
+ Schema inputSchema = Schema.recordOf(
+ "input",
+ Schema.Field.of("body", Schema.of(Schema.Type.STRING)),
+ Schema.Field.of("decimal_col", Schema.decimalOf(10, 2))
+ );
+ Schema expectedSchema = Schema.recordOf(
+ "expected",
+ Schema.Field.of("decimal_col", Schema.nullableOf(Schema.decimalOf(10, 2))),
+ Schema.Field.of("name", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
+ Schema.Field.of("timestamp", Schema.nullableOf(Schema.of(Schema.Type.DOUBLE))),
+ Schema.Field.of("weight", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
+ Schema.Field.of("date", Schema.nullableOf(Schema.of(Schema.Type.STRING)))
+ );
+ List<Row> inputRows = new ArrayList<>();
+ inputRows.add(new Row("body", "Larry,1481666448,01/01/2000").add("decimal_col", new BigDecimal("123.45")));
+ inputRows.add(new Row("body", "Barry,,172.3,05/01/2000").add("decimal_col", new BigDecimal("234235456.0000")));
+ ExecutorContext context = new TestingPipelineContext();
+ context.getTransientStore().set(
+ TransientVariableScope.GLOBAL, TransientStoreKeys.INPUT_SCHEMA, inputSchema);
+
+ TestingRig.execute(commands, inputRows, context);
+ Schema outputSchema = context.getTransientStore().get(TransientStoreKeys.OUTPUT_SCHEMA);
+
+ for (Schema.Field field : expectedSchema.getFields()) {
+ Assert.assertEquals(field.getName(), outputSchema.getField(field.getName()).getName());
+ Assert.assertEquals(field.getSchema(), outputSchema.getField(field.getName()).getSchema());
+ }
+ }
+
+ @Test
+ public void testOutputSchemaGeneration_doesNotDropNullColumn() throws Exception {
+ Schema inputSchema = Schema.recordOf(
+ "input",
+ Schema.Field.of("id", Schema.of(Schema.Type.STRING)),
+ Schema.Field.of("null_col", Schema.of(Schema.Type.STRING))
+ );
+ String[] commands = new String[]{"set-type :id int"};
+ Schema expectedSchema = Schema.recordOf(
+ "expected",
+ Schema.Field.of("id", Schema.nullableOf(Schema.of(Schema.Type.INT))),
+ Schema.Field.of("null_col", Schema.nullableOf(Schema.of(Schema.Type.STRING)))
+ );
+ Row row = new Row();
+ row.add("id", "123");
+ row.add("null_col", null);
+ ExecutorContext context = new TestingPipelineContext();
+ context.getTransientStore().set(TransientVariableScope.GLOBAL, TransientStoreKeys.INPUT_SCHEMA, inputSchema);
+
+ TestingRig.execute(commands, Collections.singletonList(row), context);
+ Schema outputSchema = context.getTransientStore().get(TransientStoreKeys.OUTPUT_SCHEMA);
+
+ Assert.assertEquals(expectedSchema.getField("null_col").getSchema(), outputSchema.getField("null_col").getSchema());
+ }
+
+ @Test
+ public void testOutputSchemaGeneration_columnOrdering() throws Exception {
+ Schema inputSchema = Schema.recordOf(
+ "input",
+ Schema.Field.of("body", Schema.of(Schema.Type.STRING)),
+ Schema.Field.of("value", Schema.of(Schema.Type.INT))
+ );
+ String[] commands = new String[] {
+ "parse-as-json :body 1",
+ "set-type :value long"
+ };
+ List<Schema.Field> expectedFields = Arrays.asList(
+ Schema.Field.of("value", Schema.nullableOf(Schema.of(Schema.Type.LONG))),
+ Schema.Field.of("body_A", Schema.nullableOf(Schema.of(Schema.Type.LONG))),
+ Schema.Field.of("body_B", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
+ Schema.Field.of("body_C", Schema.nullableOf(Schema.of(Schema.Type.DOUBLE)))
+ );
+ Row row1 = new Row().add("body", "{\"A\":1, \"B\":\"hello\"}").add("value", 10L);
+ Row row2 = new Row().add("body", "{\"C\":1.23, \"A\":1, \"B\":\"world\"}").add("value", 20L);
+ ExecutorContext context = new TestingPipelineContext();
+ context.getTransientStore().set(TransientVariableScope.GLOBAL, TransientStoreKeys.INPUT_SCHEMA, inputSchema);
+
+ TestingRig.execute(commands, Arrays.asList(row1, row2), context);
+ Schema outputSchema = context.getTransientStore().get(TransientStoreKeys.OUTPUT_SCHEMA);
+ List<Schema.Field> outputFields = outputSchema.getFields();
+
+ Assert.assertEquals(expectedFields.size(), outputFields.size());
+ for (int i = 0; i < expectedFields.size(); i++) {
+ Assert.assertEquals(expectedFields.get(i).getName(), outputFields.get(i).getName());
+ Assert.assertEquals(expectedFields.get(i).getSchema(), outputFields.get(i).getSchema());
+ }
+ }
}
diff --git a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/AbstractDirectiveHandler.java b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/AbstractDirectiveHandler.java
index 9ae32bc04..9080fbed5 100644
--- a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/AbstractDirectiveHandler.java
+++ b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/AbstractDirectiveHandler.java
@@ -32,6 +32,7 @@
import io.cdap.wrangler.api.RecipeException;
import io.cdap.wrangler.api.RecipeParser;
import io.cdap.wrangler.api.Row;
+import io.cdap.wrangler.api.TransientStore;
import io.cdap.wrangler.executor.RecipePipelineExecutor;
import io.cdap.wrangler.parser.ConfigDirectiveContext;
import io.cdap.wrangler.parser.GrammarBasedParser;
@@ -48,10 +49,10 @@
import io.cdap.wrangler.registry.DirectiveRegistry;
import io.cdap.wrangler.registry.SystemDirectiveRegistry;
import io.cdap.wrangler.registry.UserDirectiveRegistry;
+import io.cdap.wrangler.schema.TransientStoreKeys;
import io.cdap.wrangler.service.common.AbstractWranglerHandler;
import io.cdap.wrangler.statistics.BasicStatistics;
import io.cdap.wrangler.statistics.Statistics;
-import io.cdap.wrangler.utils.SchemaConverter;
import io.cdap.wrangler.validator.ColumnNameValidator;
import io.cdap.wrangler.validator.Validator;
import io.cdap.wrangler.validator.ValidatorException;
@@ -62,7 +63,7 @@
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
-import java.util.LinkedHashSet;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -81,6 +82,7 @@ public class AbstractDirectiveHandler extends AbstractWranglerHandler {
protected static final String COLUMN_NAME = "body";
protected static final String RECORD_DELIMITER_HEADER = "recorddelimiter";
protected static final String DELIMITER_HEADER = "delimiter";
+ protected static final TransientStore TRANSIENT_STORE = new DefaultTransientStore();
protected DirectiveRegistry composite;
@@ -133,7 +135,7 @@ protected List<Row> executeDirectives(
try (RecipePipelineExecutor executor = new RecipePipelineExecutor(parser,
new ServicePipelineContext(
namespace, ExecutorContext.Environment.SERVICE,
- getContext(), new DefaultTransientStore()))) {
+ getContext(), TRANSIENT_STORE))) {
List<Row> result = executor.execute(sample);
List<ErrorRecord> errors = executor.errors()
@@ -154,10 +156,19 @@ protected List executeDirectives(
protected DirectiveExecutionResponse generateExecutionResponse(
List<Row> rows, int limit) throws Exception {
List