Skip to content

Commit

Permalink
Refactor output schema generation
Browse files Browse the repository at this point in the history
  • Loading branch information
vanathi-g committed Jul 24, 2023
1 parent f596f7e commit 9c45337
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 153 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@
import io.cdap.wrangler.api.ErrorRowException;
import io.cdap.wrangler.api.Executor;
import io.cdap.wrangler.api.ExecutorContext;
import io.cdap.wrangler.api.Pair;
import io.cdap.wrangler.api.RecipeException;
import io.cdap.wrangler.api.RecipeParser;
import io.cdap.wrangler.api.RecipePipeline;
import io.cdap.wrangler.api.ReportErrorAndProceed;
import io.cdap.wrangler.api.Row;
import io.cdap.wrangler.api.TransientVariableScope;
import io.cdap.wrangler.utils.OutputSchemaGenerator;
import io.cdap.wrangler.utils.DirectiveOutputSchemaGenerator;
import io.cdap.wrangler.utils.RecordConvertor;
import io.cdap.wrangler.utils.RecordConvertorException;
import io.cdap.wrangler.utils.SchemaConverter;
import io.cdap.wrangler.utils.TransientStoreKeys;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -51,6 +51,7 @@ public final class RecipePipelineExecutor implements RecipePipeline<Row, Structu

private final ErrorRecordCollector collector = new ErrorRecordCollector();
private final RecordConvertor convertor = new RecordConvertor();
private final SchemaConverter generator = new SchemaConverter();
private final RecipeParser recipeParser;
private final ExecutorContext context;
private List<Directive> directives;
Expand Down Expand Up @@ -112,8 +113,12 @@ public List<Row> execute(List<Row> rows) throws RecipeException {
context.getEnvironment().equals(ExecutorContext.Environment.TESTING));
Schema inputSchema = designTime ? context.getTransientStore().get(TransientStoreKeys.INPUT_SCHEMA) : null;

OutputSchemaGenerator outputSchemaGenerator = designTime && inputSchema != null ?
new OutputSchemaGenerator(inputSchema, directives) : null;
List<DirectiveOutputSchemaGenerator> outputSchemaGenerators = new ArrayList<>();
if (designTime && inputSchema != null) {
for (Directive directive : directives) {
outputSchemaGenerators.add(new DirectiveOutputSchemaGenerator(directive, generator));
}
}

try {
collector.reset();
Expand All @@ -135,9 +140,7 @@ public List<Row> execute(List<Row> rows) throws RecipeException {
break;
}
if (designTime && inputSchema != null) {
for (Pair<String, Object> field : getRowUnion(cumulativeRows).getFields()) {
outputSchemaGenerator.addDirectiveField(directiveIndex - 1, field.getFirst(), field.getSecond());
}
outputSchemaGenerators.get(directiveIndex - 1).addNewOutputFields(cumulativeRows);
}
} catch (ReportErrorAndProceed e) {
messages.add(String.format("%s (ecode: %d)", e.getMessage(), e.getCode()));
Expand All @@ -161,12 +164,8 @@ public List<Row> execute(List<Row> rows) throws RecipeException {
}
// Schema generation
if (designTime && inputSchema != null) {
try {
context.getTransientStore().set(TransientVariableScope.GLOBAL, TransientStoreKeys.OUTPUT_SCHEMA,
outputSchemaGenerator.generateOutputSchema());
} catch (RecordConvertorException e) {
throw new RuntimeException(e);
}
context.getTransientStore().set(TransientVariableScope.GLOBAL, TransientStoreKeys.OUTPUT_SCHEMA,
getOutputSchema(inputSchema, outputSchemaGenerators));
}
return results;
}
Expand All @@ -188,15 +187,16 @@ private List<Directive> getDirectives() throws RecipeException {
return directives;
}

public static Row getRowUnion(List<Row> rows) {
Row union = new Row();
for (Row row : rows) {
for (int i = 0; i < row.width(); ++i) {
if (union.find(row.getColumn(i)) == -1) {
union.add(row.getColumn(i), row.getValue(i));
}
private Schema getOutputSchema(Schema inputSchema, List<DirectiveOutputSchemaGenerator> outputSchemaGenerators)
throws RecipeException {
Schema schema = inputSchema;
for (DirectiveOutputSchemaGenerator outputSchemaGenerator : outputSchemaGenerators) {
try {
schema = outputSchemaGenerator.getDirectiveOutputSchema(schema);
} catch (RecordConvertorException e) {
throw new RecipeException("Error while generating output schema for a directive: " + e, e);
}
}
return union;
return schema;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright © 2023 Cask Data, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/

package io.cdap.wrangler.utils;

import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.wrangler.api.Directive;
import io.cdap.wrangler.api.Pair;
import io.cdap.wrangler.api.Row;

import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;

/**
 * Generates the output schema for the output data of a directive. It maintains a map of
 * output fields present across all output rows after applying a directive. This map is used
 * to generate the schema if the directive does not return a custom output schema.
 */
public class DirectiveOutputSchemaGenerator {
  private final SchemaConverter schemaGenerator;
  // Field name -> representative (preferably non-null) value, in first-seen order.
  // LinkedHashMap preserves column order for the generated record schema.
  private final Map<String, Object> outputFieldMap;
  private final Directive directive;

  public DirectiveOutputSchemaGenerator(Directive directive, SchemaConverter schemaGenerator) {
    this.directive = directive;
    this.schemaGenerator = schemaGenerator;
    this.outputFieldMap = new LinkedHashMap<>();
  }

  /**
   * Adds new fields from the given output rows to the map of fieldName --> value maintained for
   * schema generation. A value is stored when the field is not yet present, or when the stored
   * value is null and the new value is non-null (a usable sample value wins over null).
   *
   * @param output list of output {@link Row}s after applying directive.
   */
  public void addNewOutputFields(List<Row> output) {
    for (Row row : output) {
      for (Pair<String, Object> field : row.getFields()) {
        String fieldName = field.getFirst();
        Object fieldValue = field.getSecond();
        // Single put covers both cases: field unseen, or seen so far only with a null value.
        if (!outputFieldMap.containsKey(fieldName)
            || (outputFieldMap.get(fieldName) == null && fieldValue != null)) {
          outputFieldMap.put(fieldName, fieldValue);
        }
      }
    }
  }

  /**
   * Method to get the output schema of the directive. Returns a generated schema based on the
   * maintained map of fields only if the directive does not return a custom output schema.
   *
   * @param inputSchema input {@link Schema} of the data before applying the directive
   * @return {@link Schema} corresponding to the output data
   * @throws RecordConvertorException if a schema cannot be derived from an observed field value
   */
  public Schema getDirectiveOutputSchema(Schema inputSchema) throws RecordConvertorException {
    Schema directiveOutputSchema = directive.getOutputSchema(inputSchema);
    return directiveOutputSchema != null ? directiveOutputSchema : generateDirectiveOutputSchema(inputSchema);
  }

  // Given the schema from the previous step, generates this directive's output schema from the
  // collected field map. Keeps the existing field schema when it is still valid for the observed
  // value; otherwise derives a new schema from the value itself.
  private Schema generateDirectiveOutputSchema(Schema inputSchema)
    throws RecordConvertorException {
    List<Schema.Field> outputFields = new LinkedList<>();
    for (Map.Entry<String, Object> entry : outputFieldMap.entrySet()) {
      String fieldName = entry.getKey();
      Object fieldValue = entry.getValue();

      Schema existing = inputSchema.getField(fieldName) != null ? inputSchema.getField(fieldName).getSchema() : null;
      // Only derive a schema from the value when the existing schema does not fit it.
      Schema generated = fieldValue != null && !isValidSchemaForValue(existing, fieldValue) ?
        schemaGenerator.getSchema(fieldValue, fieldName) : null;

      if (generated != null) {
        outputFields.add(Schema.Field.of(fieldName, generated));
      } else if (existing != null) {
        outputFields.add(Schema.Field.of(fieldName, existing));
      } else {
        // No usable sample value and no prior schema: fall back to a NULL-typed field.
        outputFields.add(Schema.Field.of(fieldName, Schema.of(Schema.Type.NULL)));
      }
    }
    return Schema.recordOf("output", outputFields);
  }

  // Checks whether the provided schema (possibly nullable) matches the schema derived from the
  // given value, comparing both the logical type and the physical type.
  private boolean isValidSchemaForValue(@Nullable Schema schema, Object value) throws RecordConvertorException {
    if (schema == null) {
      return false;
    }
    Schema generated = schemaGenerator.getSchema(value, "temp_field_name");
    generated = generated.isNullable() ? generated.getNonNullable() : generated;
    schema = schema.isNullable() ? schema.getNonNullable() : schema;
    return generated.getLogicalType() == schema.getLogicalType() && generated.getType() == schema.getType();
  }
}

This file was deleted.

0 comments on commit 9c45337

Please sign in to comment.