diff --git a/wrangler-api/src/main/java/io/cdap/wrangler/api/Executor.java b/wrangler-api/src/main/java/io/cdap/wrangler/api/Executor.java index abc726295..8c85319e4 100644 --- a/wrangler-api/src/main/java/io/cdap/wrangler/api/Executor.java +++ b/wrangler-api/src/main/java/io/cdap/wrangler/api/Executor.java @@ -16,9 +16,11 @@ package io.cdap.wrangler.api; +import io.cdap.cdap.api.data.schema.Schema; import io.cdap.wrangler.api.annotations.PublicEvolving; import java.io.Serializable; +import javax.annotation.Nullable; /** * An interface defining the wrangle Executor in the wrangling {@link RecipePipeline}. @@ -80,5 +82,19 @@ O execute(I rows, ExecutorContext context) * correct at this phase of invocation. */ void destroy(); -} + /** + * Returns the updated schema of the data after the directive's transformation has been applied. + * + * @param schemaResolutionContext context containing necessary information for getting output schema + * @return output {@link Schema} of the transformed data + * @implNote By default, returns null and the schema is inferred from the data when necessary. + * <p>For consistent handling, override for directives that perform column renames, + * column data type changes or column additions with specific schemas.</p>
+ */ + @Nullable + default Schema getOutputSchema(SchemaResolutionContext schemaResolutionContext) { + // no op + return null; + } +} diff --git a/wrangler-api/src/main/java/io/cdap/wrangler/api/SchemaResolutionContext.java b/wrangler-api/src/main/java/io/cdap/wrangler/api/SchemaResolutionContext.java new file mode 100644 index 000000000..015f8bdc6 --- /dev/null +++ b/wrangler-api/src/main/java/io/cdap/wrangler/api/SchemaResolutionContext.java @@ -0,0 +1,29 @@ +/* + * Copyright © 2023 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package io.cdap.wrangler.api; + +import io.cdap.cdap.api.data.schema.Schema; + +/** + * Interface to pass contextual information related to getting or generating the output schema of an {@link Executor} + */ +public interface SchemaResolutionContext { + /** + * @return {@link Schema} of the input data before transformation + */ + Schema getInputSchema(); +} diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java b/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java index 7202ec90a..a41206e31 100644 --- a/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java +++ b/wrangler-core/src/main/java/io/cdap/wrangler/executor/RecipePipelineExecutor.java @@ -30,8 +30,12 @@ import io.cdap.wrangler.api.ReportErrorAndProceed; import io.cdap.wrangler.api.Row; import io.cdap.wrangler.api.TransientVariableScope; +import io.cdap.wrangler.schema.DirectiveOutputSchemaGenerator; +import io.cdap.wrangler.schema.DirectiveSchemaResolutionContext; +import io.cdap.wrangler.schema.TransientStoreKeys; import io.cdap.wrangler.utils.RecordConvertor; import io.cdap.wrangler.utils.RecordConvertorException; +import io.cdap.wrangler.utils.SchemaConverter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -48,6 +52,7 @@ public final class RecipePipelineExecutor implements RecipePipeline<Row, StructuredRecord> { + private final SchemaConverter generator = new SchemaConverter(); private List<Directive> directives; @@ -103,6 +108,19 @@ public List<Row> execute(List<Row> rows) throws RecipeException { List<Row> results = new ArrayList<>(); int i = 0; int directiveIndex = 0; + // Initialize schema with input schema from TransientStore if running in service env (design-time) / testing env + boolean designTime = context != null && context.getEnvironment() != null && + (context.getEnvironment().equals(ExecutorContext.Environment.SERVICE) || + context.getEnvironment().equals(ExecutorContext.Environment.TESTING)); + Schema inputSchema = designTime ? 
context.getTransientStore().get(TransientStoreKeys.INPUT_SCHEMA) : null; + + List<DirectiveOutputSchemaGenerator> outputSchemaGenerators = new ArrayList<>(); + if (designTime && inputSchema != null) { + for (Directive directive : directives) { + outputSchemaGenerators.add(new DirectiveOutputSchemaGenerator(directive, generator)); + } + } + try { collector.reset(); while (i < rows.size()) { @@ -122,6 +140,9 @@ public List<Row> execute(List<Row> rows) throws RecipeException { if (cumulativeRows.size() < 1) { break; } + if (designTime && inputSchema != null) { + outputSchemaGenerators.get(directiveIndex - 1).addNewOutputFields(cumulativeRows); + } } catch (ReportErrorAndProceed e) { messages.add(String.format("%s (ecode: %d)", e.getMessage(), e.getCode())); collector @@ -142,6 +163,11 @@ public List<Row> execute(List<Row> rows) throws RecipeException { } catch (DirectiveExecutionException e) { throw new RecipeException(e.getMessage(), e, i, directiveIndex); } + // Schema generation + if (designTime && inputSchema != null) { + context.getTransientStore().set(TransientVariableScope.GLOBAL, TransientStoreKeys.OUTPUT_SCHEMA, + getOutputSchema(inputSchema, outputSchemaGenerators)); + } return results; } @@ -161,4 +187,17 @@ private List<Directive> getDirectives() throws RecipeException { } return directives; } + + private Schema getOutputSchema(Schema inputSchema, List<DirectiveOutputSchemaGenerator> outputSchemaGenerators) + throws RecipeException { + Schema schema = inputSchema; + for (DirectiveOutputSchemaGenerator outputSchemaGenerator : outputSchemaGenerators) { + try { + schema = outputSchemaGenerator.getDirectiveOutputSchema(new DirectiveSchemaResolutionContext(schema)); + } catch (RecordConvertorException e) { + throw new RecipeException("Error while generating output schema for a directive: " + e, e); + } + } + return schema; + } } diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/schema/DirectiveOutputSchemaGenerator.java b/wrangler-core/src/main/java/io/cdap/wrangler/schema/DirectiveOutputSchemaGenerator.java new file mode 100644 index 000000000..980780f04 --- /dev/null +++ b/wrangler-core/src/main/java/io/cdap/wrangler/schema/DirectiveOutputSchemaGenerator.java @@ -0,0 +1,119 @@ +/* + * Copyright © 2023 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package io.cdap.wrangler.schema; + +import io.cdap.cdap.api.data.schema.Schema; +import io.cdap.wrangler.api.Directive; +import io.cdap.wrangler.api.Pair; +import io.cdap.wrangler.api.Row; +import io.cdap.wrangler.api.SchemaResolutionContext; +import io.cdap.wrangler.utils.RecordConvertorException; +import io.cdap.wrangler.utils.SchemaConverter; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; + +/** + * This class can be used to generate the output schema for the output data of a directive. It maintains a map of + * output fields present across all output rows after applying a directive. This map is used to generate the schema + * if the directive does not return a custom output schema.
+ */ +public class DirectiveOutputSchemaGenerator { + private final SchemaConverter schemaGenerator; + private final Map<String, Object> outputFieldMap; + private final Directive directive; + + public DirectiveOutputSchemaGenerator(Directive directive, SchemaConverter schemaGenerator) { + this.directive = directive; + this.schemaGenerator = schemaGenerator; + outputFieldMap = new LinkedHashMap<>(); + } + + /** + * Method to add new fields from the given output to the map of fieldName --> value maintained for schema generation. + * A value is added to the map only if it is absent, or if the existing value is null and the given value is non-null. + * @param output list of output {@link Row}s after applying the directive. + */ + public void addNewOutputFields(List<Row> output) { + for (Row row : output) { + for (Pair<String, Object> field : row.getFields()) { + String fieldName = field.getFirst(); + Object fieldValue = field.getSecond(); + if (outputFieldMap.containsKey(fieldName)) { + // If existing value is null, override with this non-null value + if (fieldValue != null && outputFieldMap.get(fieldName) == null) { + outputFieldMap.put(fieldName, fieldValue); + } + } else { + outputFieldMap.put(fieldName, fieldValue); + } + } + } + } + + /** + * Method to get the output schema of the directive. Returns the directive's custom output schema if present; + * otherwise returns a schema generated from the maintained map of fields. + * @param context {@link SchemaResolutionContext} carrying the input {@link Schema} of the data before applying the directive + * @return {@link Schema} corresponding to the output data + */ + public Schema getDirectiveOutputSchema(SchemaResolutionContext context) throws RecordConvertorException { + Schema directiveOutputSchema = directive.getOutputSchema(context); + return directiveOutputSchema != null ? directiveOutputSchema : + generateDirectiveOutputSchema(context.getInputSchema()); + } + + // Given the schema from the previous step and the output of the current directive, generates the directive output schema. + private Schema generateDirectiveOutputSchema(Schema inputSchema) + throws RecordConvertorException { + List<Schema.Field> outputFields = new ArrayList<>(); + for (Map.Entry<String, Object> field : outputFieldMap.entrySet()) { + String fieldName = field.getKey(); + Object fieldValue = field.getValue(); + + Schema existing = inputSchema.getField(fieldName) != null ? inputSchema.getField(fieldName).getSchema() : null; + Schema generated = fieldValue != null && !isValidSchemaForValue(existing, fieldValue) ? + schemaGenerator.getSchema(fieldValue, fieldName) : null; + + if (generated != null) { + outputFields.add(Schema.Field.of(fieldName, generated)); + } else if (existing != null) { + if (!existing.isNullable()) { + existing = Schema.nullableOf(existing); + } + outputFields.add(Schema.Field.of(fieldName, existing)); + } else { + outputFields.add(Schema.Field.of(fieldName, Schema.of(Schema.Type.NULL))); + } + } + return Schema.recordOf("output", outputFields); + } + + // Checks whether the provided input schema is of valid type for the given object + private boolean isValidSchemaForValue(@Nullable Schema schema, Object value) throws RecordConvertorException { + if (schema == null) { + return false; + } + Schema generated = schemaGenerator.getSchema(value, "temp_field_name"); + generated = generated.isNullable() ? generated.getNonNullable() : generated; + schema = schema.isNullable() ? 
schema.getNonNullable() : schema; + return generated.getLogicalType() == schema.getLogicalType() && generated.getType() == schema.getType(); + } +} diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/schema/DirectiveSchemaResolutionContext.java b/wrangler-core/src/main/java/io/cdap/wrangler/schema/DirectiveSchemaResolutionContext.java new file mode 100644 index 000000000..9c4c702fb --- /dev/null +++ b/wrangler-core/src/main/java/io/cdap/wrangler/schema/DirectiveSchemaResolutionContext.java @@ -0,0 +1,36 @@ +/* + * Copyright © 2023 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package io.cdap.wrangler.schema; + +import io.cdap.cdap.api.data.schema.Schema; +import io.cdap.wrangler.api.Directive; +import io.cdap.wrangler.api.SchemaResolutionContext; + +/** + * Context to pass information related to getting or generating the output schema of a {@link Directive} + */ +public class DirectiveSchemaResolutionContext implements SchemaResolutionContext { + private final Schema inputSchema; + public DirectiveSchemaResolutionContext(Schema inputSchema) { + this.inputSchema = inputSchema; + } + + @Override + public Schema getInputSchema() { + return inputSchema; + } +} diff --git a/wrangler-core/src/main/java/io/cdap/wrangler/schema/TransientStoreKeys.java b/wrangler-core/src/main/java/io/cdap/wrangler/schema/TransientStoreKeys.java new file mode 100644 index 000000000..e35ef803f --- /dev/null +++ b/wrangler-core/src/main/java/io/cdap/wrangler/schema/TransientStoreKeys.java @@ -0,0 +1,29 @@ +/* + * Copyright © 2023 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package io.cdap.wrangler.schema; + +/** + * TransientStoreKeys for storing Workspace schema in TransientStore + */ +public final class TransientStoreKeys { + public static final String INPUT_SCHEMA = "ws_input_schema"; + public static final String OUTPUT_SCHEMA = "ws_output_schema"; + + private TransientStoreKeys() { + throw new AssertionError("Cannot instantiate a static utility class."); + } +} diff --git a/wrangler-core/src/test/java/io/cdap/wrangler/TestingPipelineContext.java b/wrangler-core/src/test/java/io/cdap/wrangler/TestingPipelineContext.java index 53fd7c465..1bbfc4c33 100644 --- a/wrangler-core/src/test/java/io/cdap/wrangler/TestingPipelineContext.java +++ b/wrangler-core/src/test/java/io/cdap/wrangler/TestingPipelineContext.java @@ -24,46 +24,27 @@ import io.cdap.wrangler.api.ExecutorContext; import io.cdap.wrangler.api.TransientStore; import io.cdap.wrangler.proto.Contexts; -import org.apache.commons.collections.map.HashedMap; import java.net.URL; import java.util.Collections; +import java.util.HashMap; import java.util.Map; /** * This class {@link TestingPipelineContext} is a runtime context that is provided for each * {@link Executor} execution. */ -class TestingPipelineContext implements ExecutorContext { - private StageMetrics metrics; - private String name; - private TransientStore store; - private Map properties; - - TestingPipelineContext() { - properties = new HashedMap(); +public class TestingPipelineContext implements ExecutorContext { + private final StageMetrics metrics; + private final String name; + private final TransientStore store; + private final Map properties; + + public TestingPipelineContext() { + name = "testing"; + properties = new HashMap<>(); store = new DefaultTransientStore(); - } - - /** - * @return Environment this context is prepared for. - */ - @Override - public Environment getEnvironment() { - return Environment.TESTING; - } - - @Override - public String getNamespace() { - return Contexts.SYSTEM; - } - - /** - * @return Measurements context. - */ - @Override - public StageMetrics getMetrics() { - return new StageMetrics() { + metrics = new StageMetrics() { @Override public void count(String s, int i) { @@ -96,12 +77,33 @@ public Map getTags() { }; } + /** + * @return Environment this context is prepared for. + */ + @Override + public Environment getEnvironment() { + return Environment.TESTING; + } + + @Override + public String getNamespace() { + return Contexts.SYSTEM; + } + + /** + * @return Measurements context. + */ + @Override + public StageMetrics getMetrics() { + return metrics; + } + /** * @return Context name. 
*/ @Override public String getContextName() { - return "testing"; + return name; } /** diff --git a/wrangler-core/src/test/java/io/cdap/wrangler/TestingRig.java b/wrangler-core/src/test/java/io/cdap/wrangler/TestingRig.java index 9a5712b5d..10a6da4e2 100644 --- a/wrangler-core/src/test/java/io/cdap/wrangler/TestingRig.java +++ b/wrangler-core/src/test/java/io/cdap/wrangler/TestingRig.java @@ -60,7 +60,7 @@ private TestingRig() { */ public static List<Row> execute(String[] recipe, List<Row> rows) throws RecipeException, DirectiveParseException, DirectiveLoadException { - return execute(recipe, rows, null); + return execute(recipe, rows, new TestingPipelineContext()); } public static List<Row> execute(String[] recipe, List<Row> rows, ExecutorContext context) @@ -83,7 +83,7 @@ public static List<Row> execute(String[] recipe, List<Row> rows, ExecutorContext */ public static Pair<List<Row>, List<Row>> executeWithErrors(String[] recipe, List<Row> rows) throws RecipeException, DirectiveParseException, DirectiveLoadException, DirectiveNotFoundException { - return executeWithErrors(recipe, rows, null); + return executeWithErrors(recipe, rows, new TestingPipelineContext()); } public static Pair<List<Row>, List<Row>> executeWithErrors(String[] recipe, List<Row> rows, ExecutorContext context) diff --git a/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java b/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java index 53f1b9e78..8b858d50c 100644 --- a/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java +++ b/wrangler-core/src/test/java/io/cdap/wrangler/executor/RecipePipelineExecutorTest.java @@ -18,13 +18,21 @@ import io.cdap.cdap.api.data.format.StructuredRecord; import io.cdap.cdap.api.data.schema.Schema; +import io.cdap.wrangler.TestingPipelineContext; import io.cdap.wrangler.TestingRig; +import io.cdap.wrangler.api.ExecutorContext; import io.cdap.wrangler.api.RecipePipeline; import io.cdap.wrangler.api.Row; +import io.cdap.wrangler.api.TransientVariableScope; +import io.cdap.wrangler.schema.TransientStoreKeys; import org.junit.Assert; import org.junit.Test; +import java.math.BigDecimal; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.List; /** * Tests {@link RecipePipelineExecutor}.
@@ -54,8 +62,8 @@ public void testPipeline() throws Exception { RecipePipeline pipeline = TestingRig.execute(commands); - Row row = new Row("__col", new String("a,b,c,d,e,f,1.0")); - StructuredRecord record = (StructuredRecord) pipeline.execute(Arrays.asList(row), schema).get(0); + Row row = new Row("__col", "a,b,c,d,e,f,1.0"); + StructuredRecord record = (StructuredRecord) pipeline.execute(Collections.singletonList(row), schema).get(0); // Validate the {@link StructuredRecord} Assert.assertEquals("a", record.get("first")); @@ -86,8 +94,8 @@ public void testPipelineWithMoreSimpleTypes() throws Exception { ); RecipePipeline pipeline = TestingRig.execute(commands); - Row row = new Row("__col", new String("Larry,Perez,lperezqt@umn.edu,1481666448,186.66")); - StructuredRecord record = (StructuredRecord) pipeline.execute(Arrays.asList(row), schema).get(0); + Row row = new Row("__col", "Larry,Perez,lperezqt@umn.edu,1481666448,186.66"); + StructuredRecord record = (StructuredRecord) pipeline.execute(Collections.singletonList(row), schema).get(0); // Validate the {@link StructuredRecord} Assert.assertEquals("Larry", record.get("first")); @@ -96,4 +104,99 @@ public void testPipelineWithMoreSimpleTypes() throws Exception { Assert.assertEquals(1481666448L, record.get("timestamp").longValue()); Assert.assertEquals(186.66f, record.get("weight"), 0.0001f); } + + @Test + public void testOutputSchemaGeneration() throws Exception { + String[] commands = new String[]{ + "parse-as-csv :body ,", + "drop :body", + "set-headers :decimal_col,:name,:timestamp,:weight,:date", + "set-type :timestamp double", + }; + Schema inputSchema = Schema.recordOf( + "input", + Schema.Field.of("body", Schema.of(Schema.Type.STRING)), + Schema.Field.of("decimal_col", Schema.decimalOf(10, 2)) + ); + Schema expectedSchema = Schema.recordOf( + "expected", + Schema.Field.of("decimal_col", Schema.nullableOf(Schema.decimalOf(10, 2))), + Schema.Field.of("name", Schema.nullableOf(Schema.of(Schema.Type.STRING))), + Schema.Field.of("timestamp", Schema.nullableOf(Schema.of(Schema.Type.DOUBLE))), + Schema.Field.of("weight", Schema.nullableOf(Schema.of(Schema.Type.STRING))), + Schema.Field.of("date", Schema.nullableOf(Schema.of(Schema.Type.STRING))) + ); + List<Row> inputRows = new ArrayList<>(); + inputRows.add(new Row("body", "Larry,1481666448,01/01/2000").add("decimal_col", new BigDecimal("123.45"))); + inputRows.add(new Row("body", "Barry,,172.3,05/01/2000").add("decimal_col", new BigDecimal("234235456.0000"))); + ExecutorContext context = new TestingPipelineContext(); + context.getTransientStore().set( + TransientVariableScope.GLOBAL, TransientStoreKeys.INPUT_SCHEMA, inputSchema); + + TestingRig.execute(commands, inputRows, context); + Schema outputSchema = context.getTransientStore().get(TransientStoreKeys.OUTPUT_SCHEMA); + + for (Schema.Field field : expectedSchema.getFields()) { + Assert.assertEquals(field.getName(), outputSchema.getField(field.getName()).getName()); + Assert.assertEquals(field.getSchema(), outputSchema.getField(field.getName()).getSchema()); + } + } + + @Test + public void testOutputSchemaGeneration_doesNotDropNullColumn() throws Exception { + Schema inputSchema = Schema.recordOf( + "input", + Schema.Field.of("id", Schema.of(Schema.Type.STRING)), + Schema.Field.of("null_col", Schema.of(Schema.Type.STRING)) + ); + String[] commands = new String[]{"set-type :id int"}; + Schema expectedSchema = Schema.recordOf( + "expected", + Schema.Field.of("id", Schema.nullableOf(Schema.of(Schema.Type.INT))), + Schema.Field.of("null_col", Schema.nullableOf(Schema.of(Schema.Type.STRING))) + ); + Row row = new Row(); + row.add("id", "123"); + row.add("null_col", null); + ExecutorContext context = new TestingPipelineContext(); + context.getTransientStore().set(TransientVariableScope.GLOBAL, TransientStoreKeys.INPUT_SCHEMA, inputSchema); + + TestingRig.execute(commands, Collections.singletonList(row), context); + Schema outputSchema = context.getTransientStore().get(TransientStoreKeys.OUTPUT_SCHEMA); + + Assert.assertEquals(expectedSchema.getField("null_col").getSchema(), outputSchema.getField("null_col").getSchema()); + } + + @Test + public void testOutputSchemaGeneration_columnOrdering() throws Exception { + Schema inputSchema = Schema.recordOf( + "input", + Schema.Field.of("body", Schema.of(Schema.Type.STRING)), + Schema.Field.of("value", Schema.of(Schema.Type.INT)) + ); + String[] commands = new String[] { + "parse-as-json :body 1", + "set-type :value long" + }; + List<Schema.Field> expectedFields = Arrays.asList( + Schema.Field.of("value", Schema.nullableOf(Schema.of(Schema.Type.LONG))), + Schema.Field.of("body_A", Schema.nullableOf(Schema.of(Schema.Type.LONG))), + Schema.Field.of("body_B", Schema.nullableOf(Schema.of(Schema.Type.STRING))), + Schema.Field.of("body_C", Schema.nullableOf(Schema.of(Schema.Type.DOUBLE))) + ); + Row row1 = new Row().add("body", "{\"A\":1, \"B\":\"hello\"}").add("value", 10L); + Row row2 = new Row().add("body", "{\"C\":1.23, \"A\":1, \"B\":\"world\"}").add("value", 20L); + ExecutorContext context = new TestingPipelineContext(); + context.getTransientStore().set(TransientVariableScope.GLOBAL, TransientStoreKeys.INPUT_SCHEMA, inputSchema); + + TestingRig.execute(commands, Arrays.asList(row1, row2), context); + Schema outputSchema = context.getTransientStore().get(TransientStoreKeys.OUTPUT_SCHEMA); + List<Schema.Field> outputFields = outputSchema.getFields(); + + Assert.assertEquals(expectedFields.size(), outputFields.size()); + for (int i = 0; i < expectedFields.size(); i++) { + Assert.assertEquals(expectedFields.get(i).getName(), outputFields.get(i).getName()); + Assert.assertEquals(expectedFields.get(i).getSchema(), outputFields.get(i).getSchema()); + } + } } diff --git a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/AbstractDirectiveHandler.java b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/AbstractDirectiveHandler.java index 9ae32bc04..9080fbed5 100644 --- a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/AbstractDirectiveHandler.java +++ b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/AbstractDirectiveHandler.java @@ -32,6 +32,7 @@ import io.cdap.wrangler.api.RecipeException; import io.cdap.wrangler.api.RecipeParser; import io.cdap.wrangler.api.Row; +import io.cdap.wrangler.api.TransientStore; import io.cdap.wrangler.executor.RecipePipelineExecutor; import io.cdap.wrangler.parser.ConfigDirectiveContext; import io.cdap.wrangler.parser.GrammarBasedParser; @@ -48,10 +49,10 @@ import io.cdap.wrangler.registry.DirectiveRegistry; import io.cdap.wrangler.registry.SystemDirectiveRegistry; import io.cdap.wrangler.registry.UserDirectiveRegistry; +import io.cdap.wrangler.schema.TransientStoreKeys; import io.cdap.wrangler.service.common.AbstractWranglerHandler; import io.cdap.wrangler.statistics.BasicStatistics; import io.cdap.wrangler.statistics.Statistics; -import io.cdap.wrangler.utils.SchemaConverter; import io.cdap.wrangler.validator.ColumnNameValidator; import io.cdap.wrangler.validator.Validator; import 
io.cdap.wrangler.validator.ValidatorException; @@ -62,7 +63,7 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; -import java.util.LinkedHashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -81,6 +82,7 @@ public class AbstractDirectiveHandler extends AbstractWranglerHandler { protected static final String COLUMN_NAME = "body"; protected static final String RECORD_DELIMITER_HEADER = "recorddelimiter"; protected static final String DELIMITER_HEADER = "delimiter"; + protected static final TransientStore TRANSIENT_STORE = new DefaultTransientStore(); protected DirectiveRegistry composite; @@ -133,7 +135,7 @@ protected List<Row> executeDirectives( try (RecipePipelineExecutor executor = new RecipePipelineExecutor(parser, new ServicePipelineContext( namespace, ExecutorContext.Environment.SERVICE, - getContext(), new DefaultTransientStore()))) { + getContext(), TRANSIENT_STORE))) { List<Row> result = executor.execute(sample); List errors = executor.errors() @@ -154,10 +156,19 @@ protected DirectiveExecutionResponse generateExecutionResponse( List<Row> rows, int limit) throws Exception { List<Map<String, Object>> values = new ArrayList<>(rows.size()); - Map<String, String> types = new HashMap<>(); - Set<String> headers = new LinkedHashSet<>(); - SchemaConverter convertor = new SchemaConverter(); - + Map<String, String> types = new LinkedHashMap<>(); + + Schema outputSchema = TRANSIENT_STORE.get(TransientStoreKeys.OUTPUT_SCHEMA) != null ? + TRANSIENT_STORE.get(TransientStoreKeys.OUTPUT_SCHEMA) : TRANSIENT_STORE.get(TransientStoreKeys.INPUT_SCHEMA); + + for (Schema.Field field : outputSchema.getFields()) { + Schema schema = field.getSchema(); + schema = schema.isNullable() ? schema.getNonNullable() : schema; + String type = schema.getLogicalType() == null ? schema.getType().name() : schema.getLogicalType().name(); + // for backward compatibility, make the characters except the first one to lower case + type = type.substring(0, 1).toUpperCase() + type.substring(1).toLowerCase(); + types.put(field.getName(), type); + } // Iterate through all the new rows. for (Row row : rows) { // If output array has more than return result values, we terminate. @@ -170,20 +181,9 @@ protected DirectiveExecutionResponse generateExecutionResponse( // Iterate through all the fields of the row. for (Pair<String, Object> field : row.getFields()) { String fieldName = field.getFirst(); - headers.add(fieldName); Object object = field.getSecond(); if (object != null) { - Schema schema = convertor.getSchema(object, fieldName); - String type = object.getClass().getSimpleName(); - if (schema != null) { - schema = schema.isNullable() ? schema.getNonNullable() : schema; - type = schema.getLogicalType() == null ? schema.getType().name() : schema.getLogicalType().name(); - // for backward compatibility, make the characters except the first one to lower case - type = type.substring(0, 1).toUpperCase() + type.substring(1).toLowerCase(); - } - types.put(fieldName, type); - if ((object instanceof Iterable) || (object instanceof Row)) { value.put(fieldName, GSON.toJson(object)); @@ -201,7 +201,7 @@ } values.add(value); } - return new DirectiveExecutionResponse(values, headers, types, getWorkspaceSummary(rows)); + return new DirectiveExecutionResponse(values, types.keySet(), types, getWorkspaceSummary(rows)); } /** @@ -285,5 +285,4 @@ public static Row createUberRecord(List<Row> rows) { } return uber; } - } diff --git a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java index 728ffeef6..f450ecd42 100644 --- a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java +++ b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java @@ -49,6 +49,7 @@ import io.cdap.wrangler.api.GrammarMigrator; import io.cdap.wrangler.api.RecipeException; import io.cdap.wrangler.api.Row; +import io.cdap.wrangler.api.TransientVariableScope; import io.cdap.wrangler.parser.ConfigDirectiveContext; import io.cdap.wrangler.parser.DirectiveClass; import io.cdap.wrangler.parser.GrammarWalker; @@ -73,10 +74,10 @@ import io.cdap.wrangler.proto.workspace.v2.WorkspaceUpdateRequest; import io.cdap.wrangler.registry.DirectiveInfo; import io.cdap.wrangler.registry.SystemDirectiveRegistry; +import io.cdap.wrangler.schema.TransientStoreKeys; import io.cdap.wrangler.store.recipe.RecipeStore; import io.cdap.wrangler.store.workspace.WorkspaceStore; import io.cdap.wrangler.utils.ObjectSerDe; -import io.cdap.wrangler.utils.SchemaConverter; import io.cdap.wrangler.utils.StructuredToRowTransformer; import org.apache.commons.lang3.StringEscapeUtils; @@ -400,14 +401,14 @@ public void specification(HttpServiceRequest request, HttpServiceResponder respo WorkspaceDetail detail = wsStore.getWorkspaceDetail(wsId); List<String> directives = new ArrayList<>(detail.getWorkspace().getDirectives()); UserDirectivesCollector userDirectivesCollector = new UserDirectivesCollector(); - List<Row> result = executeDirectives(ns.getName(), directives, detail, - userDirectivesCollector); + executeDirectives(ns.getName(), directives, detail, userDirectivesCollector); userDirectivesCollector.addLoadDirectivesPragma(directives); - SchemaConverter schemaConvertor = new SchemaConverter(); + Schema outputSchema = TRANSIENT_STORE.get(TransientStoreKeys.OUTPUT_SCHEMA) != null ? + TRANSIENT_STORE.get(TransientStoreKeys.OUTPUT_SCHEMA) : TRANSIENT_STORE.get(TransientStoreKeys.INPUT_SCHEMA); + // check if the rows are empty before going to create a record schema, it will result in a 400 if empty fields // are passed to a record type schema - Schema schema = result.isEmpty() ? 
null : schemaConvertor.toSchema("record", createUberRecord(result)); Map<String, String> properties = ImmutableMap.of("directives", String.join("\n", directives), "field", "*", "precondition", "false", @@ -417,8 +418,7 @@ public void specification(HttpServiceRequest request, HttpServiceResponder respo ArtifactSummary wrangler = composite.getLatestWranglerArtifact(); responder.sendString(GSON.toJson(new WorkspaceSpec( - srcSpecs, new StageSpec( - schema, new Plugin("Wrangler", "transform", properties, + srcSpecs, new StageSpec(outputSchema, new Plugin("Wrangler", "transform", properties, wrangler == null ? null : new Artifact(wrangler.getName(), wrangler.getVersion(), wrangler.getScope().name().toLowerCase())))))); @@ -540,6 +540,9 @@ private List<Row> executeDirectives(String namespace, GrammarWalker.Visitor grammarVisitor) throws Exception { // Remove all the #pragma from the existing directives. New ones will be generated. directives.removeIf(d -> PRAGMA_PATTERN.matcher(d).find()); + Schema inputSchema = detail.getWorkspace().getSampleSpec().getRelatedPlugins().iterator().next().getSchema(); + TRANSIENT_STORE.reset(TransientVariableScope.GLOBAL); + TRANSIENT_STORE.set(TransientVariableScope.GLOBAL, TransientStoreKeys.INPUT_SCHEMA, inputSchema); return getContext().isRemoteTaskEnabled() ? executeRemotely(namespace, directives, detail, grammarVisitor) :
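Usage note (illustrative, not part of the diff above): a directive that knows exactly how it changes the schema can override the new Executor#getOutputSchema default instead of relying on inference from sampled rows. A minimal sketch, assuming a hypothetical RenameColumn directive whose oldName/newName fields were parsed from its arguments (the rest of the Directive implementation is omitted):

@Nullable
@Override
public Schema getOutputSchema(SchemaResolutionContext context) {
  // Pass every input field through unchanged, swapping only the renamed column.
  // 'oldName' and 'newName' are hypothetical directive arguments, not part of this change.
  List<Schema.Field> fields = new ArrayList<>();
  for (Schema.Field field : context.getInputSchema().getFields()) {
    String name = field.getName().equals(oldName) ? newName : field.getName();
    fields.add(Schema.Field.of(name, field.getSchema()));
  }
  return Schema.recordOf("output", fields);
}

When the override returns null (the default), DirectiveOutputSchemaGenerator falls back to generating the schema from the values observed across the output rows: it keeps the input schema's field type (made nullable) whenever it is still valid for the observed value, generates a new schema when it is not, and types never-populated fields as NULL.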