Skip to content

Commit

Permalink
Add precision as an optional parameter in set-type
Browse files Browse the repository at this point in the history
  • Loading branch information
minurajeeve committed Nov 20, 2023
1 parent 297413a commit 2798dc4
Show file tree
Hide file tree
Showing 8 changed files with 164 additions and 44 deletions.
74 changes: 65 additions & 9 deletions wrangler-core/src/main/java/io/cdap/directives/column/SetType.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
import io.cdap.cdap.api.annotation.Name;
import io.cdap.cdap.api.annotation.Plugin;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.cdap.api.data.schema.Schema.LogicalType;
import io.cdap.wrangler.api.Arguments;
import io.cdap.wrangler.api.Directive;
import io.cdap.wrangler.api.DirectiveExecutionException;
import io.cdap.wrangler.api.DirectiveParseException;
import io.cdap.wrangler.api.ExecutorContext;
import io.cdap.wrangler.api.Optional;
import io.cdap.wrangler.api.Pair;
import io.cdap.wrangler.api.Row;
import io.cdap.wrangler.api.SchemaResolutionContext;
import io.cdap.wrangler.api.annotations.Categories;
Expand All @@ -40,26 +42,28 @@
import io.cdap.wrangler.utils.ColumnConverter;

import java.math.RoundingMode;
import java.util.HashMap;
import java.util.List;
import java.util.stream.Collectors;

/**
* A Wrangler step for converting data type of column
* Accepted types are: int, short, long, double, float, string, boolean and bytes
* When decimal type is selected, can also specify the scale and rounding mode
* When decimal type is selected, can also specify the scale, precision and rounding mode
*/
@Plugin(type = "directives")
@Name(SetType.NAME)
@Categories(categories = {"column"})
@Description("Converting data type of a column. Optional arguments scale and rounding-mode " +
"are used only when type is decimal.")
@Description("Converting data type of a column. Optional arguments scale, precision and "
+ "rounding-mode are used only when type is decimal.")
public final class SetType implements Directive, Lineage {
public static final String NAME = "set-type";

private String col;
private String type;
private Integer scale;
private RoundingMode roundingMode;
private Integer precision;

@Override
public UsageDefinition define() {
Expand All @@ -68,6 +72,7 @@ public UsageDefinition define() {
builder.define("type", TokenType.IDENTIFIER);
builder.define("scale", TokenType.NUMERIC, Optional.TRUE);
builder.define("rounding-mode", TokenType.TEXT, Optional.TRUE);
builder.define("precision", TokenType.PROPERTIES, "prop:{precision=<precision>}", Optional.TRUE);
return builder.build();
}

Expand All @@ -76,14 +81,19 @@ public void initialize(Arguments args) throws DirectiveParseException {
col = ((ColumnName) args.value("column")).value();
type = ((Identifier) args.value("type")).value();
if (type.equalsIgnoreCase("decimal")) {
precision = args.contains("precision") ? (Integer) ((HashMap<String, Numeric>) args.
value("precision").value()).get("precision").value().intValue() : null;
if (precision != null && precision < 1) {
throw new DirectiveParseException("precision cannot be less than 1");
}
scale = args.contains("scale") ? ((Numeric) args.value("scale")).value().intValue() : null;
if (scale == null && args.contains("rounding-mode")) {
throw new DirectiveParseException("'rounding-mode' can only be specified when a 'scale' is set");
if (scale == null && precision == null && args.contains("rounding-mode")) {
throw new DirectiveParseException("'rounding-mode' can only be specified when a 'scale' or 'precision' is set");
}
try {
roundingMode = args.contains("rounding-mode") ?
RoundingMode.valueOf(((Text) args.value("rounding-mode")).value()) :
(scale == null ? RoundingMode.UNNECESSARY : RoundingMode.HALF_EVEN);
(scale == null && precision == null ? RoundingMode.UNNECESSARY : RoundingMode.HALF_EVEN);
} catch (IllegalArgumentException e) {
throw new DirectiveParseException(String.format(
"Specified rounding-mode '%s' is not a valid Java rounding mode", args.value("rounding-mode").value()), e);
Expand All @@ -99,7 +109,7 @@ public void destroy() {
@Override
public List<Row> execute(List<Row> rows, ExecutorContext context) throws DirectiveExecutionException {
for (Row row : rows) {
ColumnConverter.convertType(NAME, row, col, type, scale, roundingMode);
ColumnConverter.convertType(NAME, row, col, type, scale, precision, roundingMode);
}
return rows;
}
Expand All @@ -121,8 +131,41 @@ public Schema getOutputSchema(SchemaResolutionContext context) {
.map(
field -> {
try {
return field.getName().equals(col) ?
Schema.Field.of(col, ColumnConverter.getSchemaForType(type, scale)) : field;
if (field.getName().equals(col)) {
Integer outputScale = scale;
Integer outputPrecision = precision;
Schema fieldSchema = field.getSchema().getNonNullable();
Pair<Integer, Integer> scaleAndPrecision = getPrecisionAndScale(fieldSchema);
Integer inputSchemaScale = scaleAndPrecision.getSecond();
Integer inputSchemaPrecision = scaleAndPrecision.getFirst();

if (scale == null && precision == null) {
outputScale = inputSchemaScale;
outputPrecision = inputSchemaPrecision;
} else if (scale == null && inputSchemaScale != null) {
if (precision - inputSchemaScale < 1) {
throw new DirectiveParseException(String.format(
"Cannot set scale as '%s' and precision as '%s' when "
+ "given precision - scale is less than 1 ", inputSchemaScale,
precision));
}
outputScale = inputSchemaScale;
outputPrecision = precision;

} else if (precision == null && inputSchemaPrecision != null) {
if (inputSchemaPrecision - scale < 1) {
throw new DirectiveParseException(String.format(
"Cannot set scale as '%s' and precision as '%s' when "
+ "given precision - scale is less than 1 ", scale,
inputSchemaPrecision));
}
outputScale = scale;
outputPrecision = inputSchemaPrecision;
}
return Schema.Field.of(col, ColumnConverter.getSchemaForType(type,
outputScale, outputPrecision));
}
return field;
} catch (DirectiveParseException e) {
throw new RuntimeException(e);
}
Expand All @@ -131,4 +174,17 @@ public Schema getOutputSchema(SchemaResolutionContext context) {
.collect(Collectors.toList())
);
}

/**
* extracts precision and scale from schema string
*/
public static Pair<Integer, Integer> getPrecisionAndScale(Schema fieldSchema) {
Integer precision = null;
Integer scale = null;
if (fieldSchema.getLogicalType() == LogicalType.DECIMAL) {
precision = fieldSchema.getPrecision();
scale = fieldSchema.getScale();
}
return new Pair<Integer, Integer>(precision, scale);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ public void initialize(Arguments args) throws DirectiveParseException {
@Override
public List<Row> execute(List<Row> rows, ExecutorContext context) throws DirectiveExecutionException {
for (Row row : rows) {
ColumnConverter.convertType(NAME, row, column, targetFieldTypeName, null, RoundingMode.UNNECESSARY);
ColumnConverter.convertType(NAME, row, column, targetFieldTypeName, null, null, RoundingMode.UNNECESSARY);
ColumnConverter.rename(NAME, row, column, targetFieldName);
}
return rows;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,14 @@ public String migrate() throws DirectiveParseException {
}
break;

//set-type <column> <type> [<scale> <rounding-mode>]
//set-type <column> <type> [<scale> <rounding-mode> prop:{precision=<precision>}]
case "set-type": {
String col = getNextToken(tokenizer, command, "col", lineno);
String type = getNextToken(tokenizer, command, "type", lineno);
String scale = getNextToken(tokenizer, null, command, "scale", lineno, true);
String roundingMode = getNextToken(tokenizer, null, command, "rounding-mode", lineno, true);
transformed.add(String.format("set-type %s %s %s %s;", col(col), type, scale, roundingMode));
String precision = getNextToken(tokenizer, null, command, "precision", lineno, true);
transformed.add(String.format("set-type %s %s %s %s %s;", col(col), type, scale, roundingMode, precision));
}
break;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@

import io.cdap.cdap.api.common.Bytes;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.cdap.api.data.schema.Schema.LogicalType;
import io.cdap.wrangler.api.DirectiveExecutionException;
import io.cdap.wrangler.api.DirectiveParseException;
import io.cdap.wrangler.api.Row;

import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;
import java.util.Collections;
import java.util.HashMap;
Expand All @@ -45,7 +47,7 @@ private ColumnConverter() {
* @throws DirectiveExecutionException when a column matching the target name already exists
*/
public static void rename(String directiveName, Row row, String column, String toName)
throws DirectiveExecutionException {
throws DirectiveExecutionException {
int idx = row.find(column);
int existingColumn = row.find(toName);
if (idx == -1) {
Expand All @@ -57,9 +59,9 @@ public static void rename(String directiveName, Row row, String column, String t
row.setColumn(idx, toName);
} else {
throw new DirectiveExecutionException(
directiveName, String.format("Column '%s' already exists. Apply the 'drop %s' directive before " +
"renaming '%s' to '%s'.",
toName, toName, column, toName));
directiveName, String.format("Column '%s' already exists. Apply the 'drop %s' directive before " +
"renaming '%s' to '%s'.",
toName, toName, column, toName));
}
}

Expand All @@ -73,8 +75,8 @@ public static void rename(String directiveName, Row row, String column, String t
* @throws DirectiveExecutionException when an unsupported type is specified or the column can not be converted.
*/
public static void convertType(String directiveName, Row row, String column, String toType,
Integer scale, RoundingMode roundingMode)
throws DirectiveExecutionException {
Integer scale, Integer precision, RoundingMode roundingMode)
throws DirectiveExecutionException {
int idx = row.find(column);
if (idx != -1) {
Object object = row.getValue(idx);
Expand All @@ -84,21 +86,22 @@ public static void convertType(String directiveName, Row row, String column, Str
try {
Object converted = ColumnConverter.convertType(column, toType, object);
if (toType.equalsIgnoreCase(ColumnTypeNames.DECIMAL)) {
row.setValue(idx, setDecimalScale((BigDecimal) converted, scale, roundingMode));
row.setValue(idx, setDecimalScaleAndPrecision((BigDecimal) converted, scale,
precision, roundingMode));
} else {
row.setValue(idx, converted);
}
} catch (DirectiveExecutionException e) {
throw e;
} catch (Exception e) {
throw new DirectiveExecutionException(
directiveName, String.format("Column '%s' cannot be converted to a '%s'.", column, toType), e);
directiveName, String.format("Column '%s' cannot be converted to a '%s'.", column, toType), e);
}
}
}

private static Object convertType(String col, String toType, Object object)
throws Exception {
throws Exception {
toType = toType.toUpperCase();
switch (toType) {
case ColumnTypeNames.INTEGER:
Expand Down Expand Up @@ -291,38 +294,62 @@ private static Object convertType(String col, String toType, Object object)

default:
throw new DirectiveExecutionException(String.format(
"Column '%s' is of unsupported type '%s'. Supported types are: " +
"int, short, long, double, decimal, boolean, string, bytes", col, toType));
"Column '%s' is of unsupported type '%s'. Supported types are: " +
"int, short, long, double, decimal, boolean, string, bytes", col, toType));
}
throw new DirectiveExecutionException(
String.format("Column '%s' has value of type '%s' and cannot be converted to a '%s'.", col,
object.getClass().getSimpleName(), toType));
}

private static BigDecimal setDecimalScale(BigDecimal decimal, Integer scale, RoundingMode roundingMode)
throws DirectiveExecutionException {
if (scale == null) {
private static BigDecimal setDecimalScaleAndPrecision(BigDecimal decimal, Integer scale,
Integer precision, RoundingMode roundingMode)
throws DirectiveExecutionException {
if (scale == null && precision == null) {
return decimal;
}
try {
return decimal.setScale(scale, roundingMode);
if (precision == null) {
return decimal.setScale(scale, roundingMode);
} else if (scale == null) {
return decimal.round(new MathContext(precision, roundingMode));
} else {
BigDecimal result;
if (validateScaleAndPrecision(scale, precision, decimal)) {
result = decimal.setScale(scale, roundingMode);
result = result.round(new MathContext(precision, roundingMode));
} else {
throw new DirectiveExecutionException(String.format(
"Cannot set scale as '%s' and precision as '%s' for value '%s' when"
+ "given precision - scale is less than number of digits"
+ " before decimal point ", scale, precision, decimal));
}
return result;
}
} catch (ArithmeticException e) {
throw new DirectiveExecutionException(String.format(
"Cannot set scale as '%s' for value '%s' when rounding-mode is '%s'", scale, decimal, roundingMode), e);
"Cannot set scale as '%s' and precision '%s' for value '%s' when rounding-mode "
+ "is '%s'", scale, precision, decimal, roundingMode), e);
}
}

public static Schema getSchemaForType(String type, Integer scale) throws DirectiveParseException {
private static Boolean validateScaleAndPrecision(Integer scale, Integer precision, BigDecimal decimal) {
int digitsBeforeDecimalPoint = decimal.signum() == 0 ? 1 : decimal.precision() - decimal.scale();
return precision - scale >= digitsBeforeDecimalPoint;
}

public static Schema getSchemaForType(String type, Integer scale, Integer precision) throws DirectiveParseException {
Schema typeSchema;
type = type.toUpperCase();
if (type.equals(ColumnTypeNames.DECIMAL)) {
// TODO make set-type support setting decimal precision
precision = precision != null ? precision : 77;
scale = scale != null ? scale : 38;
typeSchema = Schema.nullableOf(Schema.decimalOf(77, scale));
typeSchema = Schema.nullableOf(Schema.decimalOf(precision, scale));
} else {
if (!SCHEMA_TYPE_MAP.containsKey(type)) {
throw new DirectiveParseException(String.format("'%s' is an unsupported type. " +
"Supported types are: int, short, long, double, decimal, boolean, string, bytes", type));
"Supported types are: int, short, long, double, decimal, boolean, string, bytes", type));
}
typeSchema = Schema.nullableOf(Schema.of(SCHEMA_TYPE_MAP.get(type)));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,48 @@ public void testToDecimalInvalidRoundingMode() throws Exception {
TestingRig.execute(directives, rows);
}

@Test
public void testToDecimalWithScalePrecisionAndRoundingMode() throws Exception {
List<Row> rows = Collections.singletonList(new Row("scale_1_precision_4", "122.5")
.add("scale_3_precision_6", "456.789"));
String[] directives = new String[] {"set-type :scale_1_precision_4 decimal 0 'FLOOR' prop:{precision=3}",
"set-type :scale_3_precision_6 decimal 0 prop:{precision=5}"};
List<Row> results = TestingRig.execute(directives, rows);
Row row = results.get(0);

Assert.assertTrue(row.getValue(0) instanceof BigDecimal);
Assert.assertEquals(row.getValue(0), new BigDecimal("122"));

Assert.assertTrue(row.getValue(1) instanceof BigDecimal);
Assert.assertEquals(row.getValue(1), new BigDecimal("457"));
}

@Test
public void testToDecimalWithPrecision() throws Exception {
List<Row> rows = Collections.singletonList(new Row("scale_1_precision_4", "122.5"));
String[] directives = new String[] {"set-type :scale_1_precision_4 decimal 'FLOOR' prop:{precision=3}"};
List<Row> results = TestingRig.execute(directives, rows);
Row row = results.get(0);

Assert.assertTrue(row.getValue(0) instanceof BigDecimal);
Assert.assertEquals(row.getValue(0), new BigDecimal("122"));

}

@Test(expected = RecipeException.class)
public void testToDecimalWithInvalidPrecision() throws Exception {
List<Row> rows = Collections.singletonList(new Row("scale_1_precision_4", "122.5"));
String[] directives = new String[] {"set-type :scale_1_precision_4 decimal 0 'FLOOR' prop:{precision=-1}"};
TestingRig.execute(directives, rows);
}

@Test
public void testToDecimalScaleIsNull() throws Exception {
List<Row> rows = Collections.singletonList(new Row("scale_2", "125.45"));
String[] directives = new String[] {"set-type scale_2 decimal"};
Schema inputSchema = Schema.recordOf(
"inputSchema",
Schema.Field.of("scale_2", Schema.of(Schema.Type.DOUBLE))
Schema.Field.of("scale_2", Schema.nullableOf(Schema.of(Schema.Type.DOUBLE)))
);

Schema expectedSchema = Schema.recordOf(
Expand Down Expand Up @@ -377,14 +412,14 @@ public void testGetOutputSchemaForTypeChangedColumn() throws Exception {
.add("D", "random").add("E", 123).add("F", "true").add("G", 12L)
);
Schema inputSchema = Schema.recordOf(
"inputSchema",
Schema.Field.of("A", Schema.of(Schema.Type.STRING)),
Schema.Field.of("B", Schema.of(Schema.Type.STRING)),
Schema.Field.of("C", Schema.of(Schema.Type.STRING)),
Schema.Field.of("D", Schema.of(Schema.Type.STRING)),
Schema.Field.of("E", Schema.of(Schema.Type.INT)),
Schema.Field.of("F", Schema.of(Schema.Type.STRING)),
Schema.Field.of("G", Schema.of(Schema.Type.LONG))
"inputSchema",
Schema.Field.of("A", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
Schema.Field.of("B", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
Schema.Field.of("C", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
Schema.Field.of("D", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
Schema.Field.of("E", Schema.nullableOf(Schema.of(Schema.Type.INT))),
Schema.Field.of("F", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
Schema.Field.of("G", Schema.nullableOf(Schema.of(Schema.Type.LONG)))
);
Schema expectedSchema = Schema.recordOf(
"expectedSchema",
Expand Down
Loading

0 comments on commit 2798dc4

Please sign in to comment.