diff --git a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/util/DDBPredicateUtils.java b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/util/DDBPredicateUtils.java index 3ca4612c9d..a23373ce55 100644 --- a/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/util/DDBPredicateUtils.java +++ b/athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/util/DDBPredicateUtils.java @@ -20,6 +20,7 @@ package com.amazonaws.athena.connectors.dynamodb.util; import com.amazonaws.athena.connector.lambda.domain.predicate.EquatableValueSet; +import com.amazonaws.athena.connector.lambda.domain.predicate.Marker; import com.amazonaws.athena.connector.lambda.domain.predicate.Range; import com.amazonaws.athena.connector.lambda.domain.predicate.SortedRangeSet; import com.amazonaws.athena.connector.lambda.domain.predicate.ValueSet; @@ -181,6 +182,34 @@ private static String toPredicate(String columnName, String operator, Object val return aliasColumn(columnName) + " " + operator + " " + valueName; } + private static void validateColumnRange(Range range) + { + if (!range.getLow().isLowerUnbounded()) { + switch (range.getLow().getBound()) { + case ABOVE: + break; + case EXACTLY: + break; + case BELOW: + throw new IllegalArgumentException("Low marker should never use BELOW bound"); + default: + throw new AssertionError("Unhandled lower bound: " + range.getLow().getBound()); + } + } + if (!range.getHigh().isUpperUnbounded()) { + switch (range.getHigh().getBound()) { + case ABOVE: + throw new IllegalArgumentException("High marker should never use ABOVE bound"); + case EXACTLY: + break; + case BELOW: + break; + default: + throw new AssertionError("Unhandled upper bound: " + range.getHigh().getBound()); + } + } + } + /** * Generates a filter expression for a single column given a {@link ValueSet} predicate for that column. * @@ -212,41 +241,38 @@ public static String generateSingleColumnFilter(String originalColumnName, Value checkState(!range.isAll()); // Already checked if (range.isSingleValue()) { singleValues.add(range.getLow().getValue()); + continue; + } + validateColumnRange(range); + List rangeConjuncts = new ArrayList<>(); + if (range.getLow().getBound().equals(Marker.Bound.EXACTLY) && range.getHigh().getBound().equals(Marker.Bound.EXACTLY)) { + String startBetweenPredicate = toPredicate(originalColumnName, "BETWEEN", range.getLow().getValue(), accumulator, valueNameProducer.getNext(), recordMetadata); + String endBetweenPredicate = valueNameProducer.getNext(); + bindValue(originalColumnName, range.getHigh().getValue(), accumulator, recordMetadata); + rangeConjuncts.add(startBetweenPredicate); + rangeConjuncts.add(endBetweenPredicate); } else { - List rangeConjuncts = new ArrayList<>(); - if (!range.getLow().isLowerUnbounded()) { - switch (range.getLow().getBound()) { - case ABOVE: - rangeConjuncts.add(toPredicate(originalColumnName, ">", range.getLow().getValue(), accumulator, valueNameProducer.getNext(), recordMetadata)); - break; - case EXACTLY: - rangeConjuncts.add(toPredicate(originalColumnName, ">=", range.getLow().getValue(), accumulator, valueNameProducer.getNext(), recordMetadata)); - break; - case BELOW: - throw new IllegalArgumentException("Low marker should never use BELOW bound"); - default: - throw new AssertionError("Unhandled lower bound: " + range.getLow().getBound()); - } + switch (range.getLow().getBound()) { + case ABOVE: + rangeConjuncts.add(toPredicate(originalColumnName, ">", range.getLow().getValue(), accumulator, valueNameProducer.getNext(), recordMetadata)); + break; + case EXACTLY: + rangeConjuncts.add(toPredicate(originalColumnName, ">=", range.getLow().getValue(), accumulator, valueNameProducer.getNext(), recordMetadata)); + break; } - if (!range.getHigh().isUpperUnbounded()) { - switch (range.getHigh().getBound()) { - case ABOVE: - throw new IllegalArgumentException("High marker should never use ABOVE bound"); - case EXACTLY: - rangeConjuncts.add(toPredicate(originalColumnName, "<=", range.getHigh().getValue(), accumulator, valueNameProducer.getNext(), recordMetadata)); - break; - case BELOW: - rangeConjuncts.add(toPredicate(originalColumnName, "<", range.getHigh().getValue(), accumulator, valueNameProducer.getNext(), recordMetadata)); - break; - default: - throw new AssertionError("Unhandled upper bound: " + range.getHigh().getBound()); - } + switch (range.getHigh().getBound()) { + case EXACTLY: + rangeConjuncts.add(toPredicate(originalColumnName, "<=", range.getHigh().getValue(), accumulator, valueNameProducer.getNext(), recordMetadata)); + break; + case BELOW: + rangeConjuncts.add(toPredicate(originalColumnName, "<", range.getHigh().getValue(), accumulator, valueNameProducer.getNext(), recordMetadata)); + break; } - // If rangeConjuncts is null, then the range was ALL, which should already have been checked for - checkState(!rangeConjuncts.isEmpty()); - disjuncts.add("(" + AND_JOINER.join(rangeConjuncts) + ")"); } + // If rangeConjuncts is null, then the range was ALL, which should already have been checked for + checkState(!rangeConjuncts.isEmpty()); + disjuncts.add("(" + AND_JOINER.join(rangeConjuncts) + ")"); } } else { diff --git a/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBMetadataHandlerTest.java b/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBMetadataHandlerTest.java index e742177bd3..3726571310 100644 --- a/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBMetadataHandlerTest.java +++ b/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBMetadataHandlerTest.java @@ -355,13 +355,50 @@ public void doGetTableLayoutQueryIndex() assertThat(res.getPartitions().getSchema().getCustomMetadata().get(HASH_KEY_NAME_METADATA), equalTo("col_4")); assertThat(res.getPartitions().getRowCount(), equalTo(2)); assertThat(res.getPartitions().getSchema().getCustomMetadata().get(RANGE_KEY_NAME_METADATA), equalTo("col_5")); - assertThat(res.getPartitions().getSchema().getCustomMetadata().get(RANGE_KEY_FILTER_METADATA), equalTo("(#col_5 >= :v0 AND #col_5 <= :v1)")); + assertThat(res.getPartitions().getSchema().getCustomMetadata().get(RANGE_KEY_FILTER_METADATA), equalTo("(#col_5 BETWEEN :v0 AND :v1)")); ImmutableMap expressionNames = ImmutableMap.of("#col_4", "col_4", "#col_5", "col_5"); assertThat(res.getPartitions().getSchema().getCustomMetadata().get(EXPRESSION_NAMES_METADATA), equalTo(Jackson.toJsonString(expressionNames))); ImmutableMap expressionValues = ImmutableMap.of(":v0", ItemUtils.toAttributeValue(startTime), ":v1", ItemUtils.toAttributeValue(endTime)); assertThat(res.getPartitions().getSchema().getCustomMetadata().get(EXPRESSION_VALUES_METADATA), equalTo(Jackson.toJsonString(expressionValues))); + + // Note that while we were able to fix the inclusive upper and lower bound cases, we cannot fix mixed + // inclusion bounds for now. + // So this key condition is expected to fail when used against a real DDB instance with: + // "KeyConditionExpressions must only contain one condition per key" + // However, we still test the mixed cases below to make sure that we don't accidentally generate the BETWEEN version even though + // this will cause customer queries with mixed inclusion to fail. + { + SortedRangeSet.Builder timeValueSet2 = SortedRangeSet.newBuilder(Types.MinorType.DATEMILLI.getType(), false); + timeValueSet2.add(Range.range(allocator, Types.MinorType.DATEMILLI.getType(), startTime, + true /* inclusive lowerbound */, endTime, false /* exclusive upperbound */)); + constraintsMap.put("col_5", timeValueSet2.build()); + GetTableLayoutResponse res2 = handler.doGetTableLayout(allocator, new GetTableLayoutRequest(TEST_IDENTITY, + TEST_QUERY_ID, + TEST_CATALOG_NAME, + TEST_TABLE_NAME, + new Constraints(constraintsMap), + SchemaBuilder.newBuilder().build(), + Collections.EMPTY_SET)); + assertThat(res2.getPartitions().getSchema().getCustomMetadata().get(RANGE_KEY_FILTER_METADATA), equalTo("(#col_5 >= :v0 AND #col_5 < :v1)")); + } + + { + SortedRangeSet.Builder timeValueSet2 = SortedRangeSet.newBuilder(Types.MinorType.DATEMILLI.getType(), false); + timeValueSet2.add(Range.range(allocator, Types.MinorType.DATEMILLI.getType(), startTime, + false /* exclusive lowerbound */, endTime, true /* inclusive upperbound*/)); + constraintsMap.put("col_5", timeValueSet2.build()); + GetTableLayoutResponse res2 = handler.doGetTableLayout(allocator, new GetTableLayoutRequest(TEST_IDENTITY, + TEST_QUERY_ID, + TEST_CATALOG_NAME, + TEST_TABLE_NAME, + new Constraints(constraintsMap), + SchemaBuilder.newBuilder().build(), + Collections.EMPTY_SET)); + assertThat(res2.getPartitions().getSchema().getCustomMetadata().get(RANGE_KEY_FILTER_METADATA), equalTo("(#col_5 > :v0 AND #col_5 <= :v1)")); + } + // ------------------------------------------------------------------------- } @Test diff --git a/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandlerTest.java b/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandlerTest.java index 49cce43353..f921d88c8f 100644 --- a/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandlerTest.java +++ b/athena-dynamodb/src/test/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandlerTest.java @@ -520,7 +520,7 @@ public void testStructWithSchemaFromGlueTable() throws Exception List columns = new ArrayList<>(); columns.add(new Column().withName("col0").withType("string")); columns.add(new Column().withName("outermap").withType("struct>")); - columns.add(new Column().withName("structcol").withType("struct>")); + columns.add(new Column().withName("structcol").withType("struct>")); Map param = ImmutableMap.of( SOURCE_TABLE_PROPERTY, TEST_TABLE6,