Skip to content

Commit

Permalink
Merge pull request #657 from data-integrations/schema-dir-changes
Browse files Browse the repository at this point in the history
Add custom schema handling logic to directives that rename columns
  • Loading branch information
vanathi-g authored Sep 1, 2023
2 parents af5203d + bc562b5 commit 8564e2d
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
import io.cdap.cdap.api.annotation.Description;
import io.cdap.cdap.api.annotation.Name;
import io.cdap.cdap.api.annotation.Plugin;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.wrangler.api.Arguments;
import io.cdap.wrangler.api.Directive;
import io.cdap.wrangler.api.DirectiveExecutionException;
import io.cdap.wrangler.api.DirectiveParseException;
import io.cdap.wrangler.api.ExecutorContext;
import io.cdap.wrangler.api.Optional;
import io.cdap.wrangler.api.Row;
import io.cdap.wrangler.api.SchemaResolutionContext;
import io.cdap.wrangler.api.annotations.Categories;
import io.cdap.wrangler.api.lineage.Lineage;
import io.cdap.wrangler.api.lineage.Many;
Expand All @@ -35,6 +37,7 @@
import io.cdap.wrangler.api.parser.UsageDefinition;

import java.util.List;
import java.util.stream.Collectors;

/**
* This class <code>ChangeColCaseNames</code> converts the case of the columns
Expand Down Expand Up @@ -94,5 +97,20 @@ public Mutation lineage() {
.all(Many.of())
.build();
}
}

/**
 * Derives the output schema by renaming every input field to the configured case.
 *
 * <p>Field order and each field's schema are carried over unchanged; only the names are
 * lower-cased (when {@code toLower} is set) or upper-cased (otherwise).
 *
 * @param context resolution context supplying the input schema
 * @return a record schema named "outputSchema" holding the renamed fields
 */
@Override
public Schema getOutputSchema(SchemaResolutionContext context) {
  List<Schema.Field> renamedFields = context.getInputSchema().getFields().stream()
    .map(original -> {
      String originalName = original.getName();
      String renamed;
      if (toLower) {
        renamed = originalName.toLowerCase();
      } else {
        renamed = originalName.toUpperCase();
      }
      // Only the name changes; the field's schema is reused as-is.
      return Schema.Field.of(renamed, original.getSchema());
    })
    .collect(Collectors.toList());
  return Schema.recordOf("outputSchema", renamedFields);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
import io.cdap.cdap.api.annotation.Description;
import io.cdap.cdap.api.annotation.Name;
import io.cdap.cdap.api.annotation.Plugin;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.wrangler.api.Arguments;
import io.cdap.wrangler.api.Directive;
import io.cdap.wrangler.api.DirectiveExecutionException;
import io.cdap.wrangler.api.DirectiveParseException;
import io.cdap.wrangler.api.ExecutorContext;
import io.cdap.wrangler.api.Row;
import io.cdap.wrangler.api.SchemaResolutionContext;
import io.cdap.wrangler.api.annotations.Categories;
import io.cdap.wrangler.api.lineage.Lineage;
import io.cdap.wrangler.api.lineage.Many;
Expand All @@ -36,6 +38,7 @@
import org.unix4j.builder.Unix4jCommandBuilder;

import java.util.List;
import java.util.stream.Collectors;

/**
* Applies a sed expression on the column names.
Expand Down Expand Up @@ -73,8 +76,7 @@ public List<Row> execute(List<Row> rows, ExecutorContext context) throws Directi
for (int i = 0; i < row.width(); ++i) {
String name = row.getColumn(i);
try {
Unix4jCommandBuilder builder = Unix4j.echo(name).sed(sed);
row.setColumn(i, builder.toStringResult());
row.setColumn(i, getSedReplacedColumnName(name));
} catch (IllegalArgumentException e) {
throw new DirectiveExecutionException(NAME, e.getMessage(), e);
}
Expand All @@ -90,5 +92,22 @@ public Mutation lineage() {
.all(Many.of())
.build();
}
}

/**
 * Derives the output schema by running every input field name through the sed expression.
 *
 * <p>Field order and each field's schema are carried over unchanged; only the names are
 * rewritten.
 *
 * @param context resolution context supplying the input schema
 * @return a record schema named "outputSchema" holding the renamed fields
 */
@Override
public Schema getOutputSchema(SchemaResolutionContext context) {
  List<Schema.Field> renamedFields = context.getInputSchema().getFields().stream()
    .map(original -> {
      String renamed = getSedReplacedColumnName(original.getName());
      return Schema.Field.of(renamed, original.getSchema());
    })
    .collect(Collectors.toList());
  return Schema.recordOf("outputSchema", renamedFields);
}

/**
 * Applies this directive's sed expression to a single column name.
 *
 * @param colName the column name to transform
 * @return the string produced by echoing {@code colName} through the sed expression
 */
private String getSedReplacedColumnName(String colName) {
  // Echo the name into a Unix4j pipeline, run sed over it, and collect the output as a string.
  return Unix4j.echo(colName).sed(sed).toStringResult();
}
}
24 changes: 23 additions & 1 deletion wrangler-core/src/main/java/io/cdap/directives/column/Copy.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,23 @@
import io.cdap.cdap.api.annotation.Description;
import io.cdap.cdap.api.annotation.Name;
import io.cdap.cdap.api.annotation.Plugin;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.wrangler.api.Arguments;
import io.cdap.wrangler.api.Directive;
import io.cdap.wrangler.api.DirectiveExecutionException;
import io.cdap.wrangler.api.DirectiveParseException;
import io.cdap.wrangler.api.ExecutorContext;
import io.cdap.wrangler.api.Optional;
import io.cdap.wrangler.api.Row;
import io.cdap.wrangler.api.SchemaResolutionContext;
import io.cdap.wrangler.api.annotations.Categories;
import io.cdap.wrangler.api.lineage.Lineage;
import io.cdap.wrangler.api.lineage.Many;
import io.cdap.wrangler.api.lineage.Mutation;
import io.cdap.wrangler.api.parser.ColumnName;
import io.cdap.wrangler.api.parser.TokenType;
import io.cdap.wrangler.api.parser.UsageDefinition;

import java.util.ArrayList;
import java.util.List;

/**
Expand Down Expand Up @@ -110,4 +112,24 @@ public Mutation lineage() {
.conditional(source.value(), destination.value())
.build();
}

/**
 * Computes the schema produced by copying the source column into the destination column.
 *
 * <p>Every input field keeps its original position. If the destination column is already part
 * of the input schema, its field is replaced in place by one carrying the source column's
 * schema; otherwise a new destination field with the source column's schema is appended.
 *
 * @param context resolution context supplying the input schema
 * @return a record schema named "outputSchema" containing the destination column exactly once
 */
@Override
public Schema getOutputSchema(SchemaResolutionContext context) {
  Schema inputSchema = context.getInputSchema();
  // NOTE(review): presumably getField yields null for a source column missing from the input
  // schema, which would fail here - confirm whether callers guarantee the column is present.
  Schema sourceSchema = inputSchema.getField(source.value()).getSchema();
  String destinationName = destination.value();

  List<Schema.Field> outputFields = new ArrayList<>();
  boolean destinationReplaced = false;
  for (Schema.Field field : inputSchema.getFields()) {
    if (field.getName().equals(destinationName)) {
      outputFields.add(Schema.Field.of(destinationName, sourceSchema));
      destinationReplaced = true;
    } else {
      outputFields.add(field);
    }
  }

  // Append based on whether the destination already exists, not on the force flag: keying on
  // the flag emitted the destination twice when it existed without force, and left the copied
  // column out of the schema when force was set for a destination that did not exist yet.
  if (!destinationReplaced) {
    outputFields.add(Schema.Field.of(destinationName, sourceSchema));
  }

  return Schema.recordOf("outputSchema", outputFields);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@

package io.cdap.directives.column;

import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.wrangler.TestingRig;
import io.cdap.wrangler.api.Row;
import org.junit.Assert;
import org.junit.Test;

import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
Expand All @@ -44,4 +47,35 @@ public void testColumnCaseChanges() throws Exception {
Assert.assertTrue(rows.size() == 1);
Assert.assertEquals("url", rows.get(0).getColumn(0));
}

/**
 * Verifies that lower-casing column names yields an output schema whose fields carry the
 * lower-cased names and keep the types of the corresponding input fields.
 */
@Test
public void testGetOutputSchemaForCaseChangedCols() throws Exception {
  String[] directives = new String[] {
    "change-column-case lower",
  };
  List<Row> rows = Collections.singletonList(
    new Row("ALL_CAPS", 1).add("MiXeD_CAse", "random").add("all_lower", new BigDecimal("143235.016"))
  );
  Schema inputSchema = Schema.recordOf(
    "inputSchema",
    Schema.Field.of("ALL_CAPS", Schema.of(Schema.Type.INT)),
    Schema.Field.of("MiXeD_CAse", Schema.of(Schema.Type.STRING)),
    Schema.Field.of("all_lower", Schema.decimalOf(10, 3))
  );
  Schema expectedSchema = Schema.recordOf(
    "expectedSchema",
    Schema.Field.of("all_caps", Schema.of(Schema.Type.INT)),
    Schema.Field.of("mixed_case", Schema.of(Schema.Type.STRING)),
    Schema.Field.of("all_lower", Schema.decimalOf(10, 3))
  );

  Schema outputSchema = TestingRig.executeAndGetSchema(directives, rows, inputSchema);

  // JUnit's contract is assertEquals(expected, actual); keeping that order makes failure
  // messages report the right side as "expected".
  Assert.assertEquals(expectedSchema.getFields().size(), outputSchema.getFields().size());
  for (Schema.Field expectedField : expectedSchema.getFields()) {
    Schema.Field actualField = outputSchema.getField(expectedField.getName());
    // Fail with a readable message instead of a NullPointerException when a field is missing.
    Assert.assertNotNull("Missing output field: " + expectedField.getName(), actualField);
    Assert.assertEquals(expectedField.getSchema().getType(), actualField.getSchema().getType());
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@

package io.cdap.directives.column;

import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.wrangler.TestingRig;
import io.cdap.wrangler.api.RecipeException;
import io.cdap.wrangler.api.Row;
import org.junit.Assert;
import org.junit.Test;

import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
Expand Down Expand Up @@ -65,4 +68,36 @@ public void testIncorrectSedExpression() throws Exception {

TestingRig.execute(directives, rows);
}
/**
 * Verifies that a sed-based column rename yields an output schema whose fields carry the
 * rewritten names and keep the types of the corresponding input fields.
 */
@Test
public void testGetOutputSchemaForReplacedColumnNames() throws Exception {
  String[] directives = new String[] {
    "columns-replace s/^data_//g",
  };
  List<Row> rows = Collections.singletonList(
    new Row("data_a", 1).add("data_data_confuse", "ABC").add("no_data", null).add("random", new BigDecimal("12.44"))
  );
  Schema inputSchema = Schema.recordOf(
    "inputSchema",
    Schema.Field.of("data_a", Schema.of(Schema.Type.INT)),
    Schema.Field.of("data_data_confuse", Schema.of(Schema.Type.STRING)),
    Schema.Field.of("no_data", Schema.of(Schema.Type.DOUBLE)),
    Schema.Field.of("random", Schema.decimalOf(10, 3))
  );
  Schema expectedSchema = Schema.recordOf(
    "expectedSchema",
    Schema.Field.of("a", Schema.of(Schema.Type.INT)),
    Schema.Field.of("data_confuse", Schema.of(Schema.Type.STRING)),
    Schema.Field.of("no_data", Schema.of(Schema.Type.DOUBLE)),
    Schema.Field.of("random", Schema.decimalOf(10, 3))
  );

  Schema outputSchema = TestingRig.executeAndGetSchema(directives, rows, inputSchema);

  // JUnit's contract is assertEquals(expected, actual); keeping that order makes failure
  // messages report the right side as "expected".
  Assert.assertEquals(expectedSchema.getFields().size(), outputSchema.getFields().size());
  for (Schema.Field expectedField : expectedSchema.getFields()) {
    Schema.Field actualField = outputSchema.getField(expectedField.getName());
    // Fail with a readable message instead of a NullPointerException when a field is missing.
    Assert.assertNotNull("Missing output field: " + expectedField.getName(), actualField);
    Assert.assertEquals(expectedField.getSchema().getType(), actualField.getSchema().getType());
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@

package io.cdap.directives.column;

import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.wrangler.TestingRig;
import io.cdap.wrangler.api.RecipeException;
import io.cdap.wrangler.api.Row;
import org.junit.Assert;
import org.junit.Test;

import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
Expand Down Expand Up @@ -96,4 +99,60 @@ public void testForceCopy() throws Exception {
Assert.assertEquals(rows.get(2).getValue("body_2"), rows.get(2).getValue("body_1"));
}

/**
 * Verifies that a forced copy onto an existing column yields an output schema in which the
 * destination column takes on the source column's type and no extra field is added.
 */
@Test
public void testGetOutputSchemaForForceCopiedColumn() throws Exception {
  String[] directives = new String[] {
    "copy :col_B :col_A true",
  };
  List<Row> rows = Collections.singletonList(
    new Row("col_A", 1).add("col_B", new BigDecimal("143235.016"))
  );
  Schema inputSchema = Schema.recordOf(
    "inputSchema",
    Schema.Field.of("col_A", Schema.of(Schema.Type.INT)),
    Schema.Field.of("col_B", Schema.decimalOf(10, 3))
  );
  Schema expectedSchema = Schema.recordOf(
    "expectedSchema",
    Schema.Field.of("col_A", Schema.decimalOf(10, 3)),
    Schema.Field.of("col_B", Schema.decimalOf(10, 3))
  );

  Schema outputSchema = TestingRig.executeAndGetSchema(directives, rows, inputSchema);

  // JUnit's contract is assertEquals(expected, actual); keeping that order makes failure
  // messages report the right side as "expected".
  Assert.assertEquals(expectedSchema.getFields().size(), outputSchema.getFields().size());
  for (Schema.Field expectedField : expectedSchema.getFields()) {
    Schema.Field actualField = outputSchema.getField(expectedField.getName());
    // Fail with a readable message instead of a NullPointerException when a field is missing.
    Assert.assertNotNull("Missing output field: " + expectedField.getName(), actualField);
    Assert.assertEquals(expectedField.getSchema().getType(), actualField.getSchema().getType());
  }
}

/**
 * Verifies that copying into a column that does not exist yet yields an output schema with the
 * new column added, carrying the source column's type.
 */
@Test
public void testGetOutputSchemaForCopiedColumn() throws Exception {
  String[] directives = new String[] {
    "copy :col_A :col_B",
  };
  List<Row> rows = Collections.singletonList(
    new Row("col_A", new BigDecimal("143235.016"))
  );
  Schema inputSchema = Schema.recordOf(
    "inputSchema",
    Schema.Field.of("col_A", Schema.decimalOf(10, 3))
  );
  Schema expectedSchema = Schema.recordOf(
    "expectedSchema",
    Schema.Field.of("col_A", Schema.decimalOf(10, 3)),
    Schema.Field.of("col_B", Schema.decimalOf(10, 3))
  );

  Schema outputSchema = TestingRig.executeAndGetSchema(directives, rows, inputSchema);

  // JUnit's contract is assertEquals(expected, actual); keeping that order makes failure
  // messages report the right side as "expected".
  Assert.assertEquals(expectedSchema.getFields().size(), outputSchema.getFields().size());
  for (Schema.Field expectedField : expectedSchema.getFields()) {
    Schema.Field actualField = outputSchema.getField(expectedField.getName());
    // Fail with a readable message instead of a NullPointerException when a field is missing.
    Assert.assertNotNull("Missing output field: " + expectedField.getName(), actualField);
    Assert.assertEquals(expectedField.getSchema().getType(), actualField.getSchema().getType());
  }
}
}

0 comments on commit 8564e2d

Please sign in to comment.