Skip to content

Commit

Permalink
[CALCITE-5918] Add MAP function (enabled in Spark library)
Browse files Browse the repository at this point in the history
  • Loading branch information
chucheng92 committed Oct 17, 2023
1 parent 5151168 commit 9d77112
Show file tree
Hide file tree
Showing 9 changed files with 242 additions and 7 deletions.
21 changes: 14 additions & 7 deletions core/src/main/codegen/templates/Parser.jj
Original file line number Diff line number Diff line change
Expand Up @@ -4886,14 +4886,21 @@ SqlNode MapConstructor() :
{
<MAP> { s = span(); }
(
LOOKAHEAD(1)
<LPAREN>
// by sub query "MAP (SELECT empno, deptno FROM emp)"
e = LeafQueryOrExpr(ExprContext.ACCEPT_QUERY)
<RPAREN>
(
// empty map function call: "map()"
LOOKAHEAD(2)
<LPAREN> <RPAREN> { args = SqlNodeList.EMPTY; }
|
args = ParenthesizedQueryOrCommaList(ExprContext.ACCEPT_ALL)
)
{
return SqlStdOperatorTable.MAP_QUERY.createCall(
s.end(this), e);
if (args.size() == 1 && args.get(0).isA(SqlKind.QUERY)) {
// MAP query constructor e.g. "MAP (SELECT empno, deptno FROM emps)"
return SqlStdOperatorTable.MAP_QUERY.createCall(s.end(this), args.get(0));
} else {
// MAP function e.g. "MAP(1, 2)" equivalent to standard "MAP[1, 2]"
return SqlLibraryOperators.MAP.createCall(s.end(this), args.getList());
}
}
|
// by enumeration "MAP[k0, v0, ..., kN, vN]"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@
import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOGICAL_AND;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOGICAL_OR;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.LPAD;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.MAP;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.MAP_CONCAT;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.MAP_ENTRIES;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.MAP_FROM_ARRAYS;
Expand Down Expand Up @@ -880,6 +881,7 @@ Builder populate2() {
map.put(MAP_VALUE_CONSTRUCTOR, value);
map.put(ARRAY_VALUE_CONSTRUCTOR, value);
defineMethod(ARRAY, BuiltInMethod.ARRAYS_AS_LIST.method, NullPolicy.NONE);
defineMethod(MAP, BuiltInMethod.MAP_FUNCTION.method, NullPolicy.NONE);

// ITEM operator
map.put(ITEM, new ItemImplementor());
Expand Down
11 changes: 11 additions & 0 deletions core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
Original file line number Diff line number Diff line change
Expand Up @@ -5323,6 +5323,17 @@ public static Map mapFromArrays(List keysArray, List valuesArray) {
return map;
}

/** Support the MAP function. */
public static Map mapFunction(Object... args) {
final Map map = new LinkedHashMap<>();
for (int i = 0; i < args.length; i++) {
Object key = args[i++];
Object value = args[i];
map.put(key, value);
}
return map;
}

/** Support the STR_TO_MAP function. */
public static Map strToMap(String string, String stringDelimiter, String keyValueDelimiter) {
final Map map = new LinkedHashMap();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.calcite.util.Litmus;
import org.apache.calcite.util.Optionality;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Static;
import org.apache.calcite.util.Util;

import com.google.common.collect.ImmutableList;

Expand Down Expand Up @@ -1082,6 +1084,38 @@ private static RelDataType arrayReturnType(SqlOperatorBinding opBinding) {
SqlLibraryOperators::arrayReturnType,
OperandTypes.SAME_VARIADIC);

private static RelDataType mapReturnType(SqlOperatorBinding opBinding) {
Pair<@Nullable RelDataType, @Nullable RelDataType> type =
getComponentTypes(
opBinding.getTypeFactory(), opBinding.collectOperandTypes());
return SqlTypeUtil.createMapType(
opBinding.getTypeFactory(),
requireNonNull(type.left, "inferred key type"),
requireNonNull(type.right, "inferred value type"),
false);
}

private static Pair<@Nullable RelDataType, @Nullable RelDataType> getComponentTypes(
RelDataTypeFactory typeFactory,
List<RelDataType> argTypes) {
// special case, allows empty map
if (argTypes.size() == 0) {
return Pair.of(typeFactory.createUnknownType(), typeFactory.createUnknownType());
}
return Pair.of(
typeFactory.leastRestrictive(Util.quotientList(argTypes, 2, 0)),
typeFactory.leastRestrictive(Util.quotientList(argTypes, 2, 1)));
}

/** The "MAP(key, value, ...)" function (Spark);
* compare with the standard map value constructor, "MAP[key, value, ...]". */
@LibraryOperator(libraries = {SPARK})
public static final SqlFunction MAP =
SqlBasicFunction.create("MAP",
SqlLibraryOperators::mapReturnType,
OperandTypes.MAP_FUNCTION,
SqlFunctionCategory.SYSTEM);

@SuppressWarnings("argument.type.incompatible")
private static RelDataType arrayAppendPrependReturnType(SqlOperatorBinding opBinding) {
final RelDataType arrayType = opBinding.collectOperandTypes().get(0);
Expand Down
51 changes: 51 additions & 0 deletions core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.calcite.sql.SqlUtil;
import org.apache.calcite.sql.validate.SqlValidatorScope;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;

import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -568,6 +569,9 @@ public static SqlOperandTypeChecker variadic(
public static final SqlSingleOperandTypeChecker MAP_FROM_ENTRIES =
new MapFromEntriesOperandTypeChecker();

public static final SqlSingleOperandTypeChecker MAP_FUNCTION =
new MapFunctionOperandTypeChecker();

/**
* Operand type-checking strategy where type must be a literal or NULL.
*/
Expand Down Expand Up @@ -1221,6 +1225,53 @@ private static class MapFromEntriesOperandTypeChecker
}
}

/**
* Operand type-checking strategy for a MAP function, it allows empty map.
*/
private static class MapFunctionOperandTypeChecker
extends SameOperandTypeChecker {

MapFunctionOperandTypeChecker() {
super(-1);
}

@Override public boolean checkOperandTypes(final SqlCallBinding callBinding,
final boolean throwOnFailure) {
final List<RelDataType> argTypes =
SqlTypeUtil.deriveType(callBinding, callBinding.operands());
// allows empty map
if (argTypes.size() == 0) {
return true;
}
// the size of map arg types must be even.
if (argTypes.size() % 2 > 0) {
throw callBinding.newValidationError(RESOURCE.mapRequiresEvenArgCount());
}
final Pair<@Nullable RelDataType, @Nullable RelDataType> componentType =
getComponentTypes(
callBinding.getTypeFactory(), argTypes);
// check key type & value type
if (null == componentType.left || null == componentType.right) {
if (throwOnFailure) {
throw callBinding.newValidationError(RESOURCE.needSameTypeParameter());
}
return false;
}
return true;
}

/**
* Extract the key type and value type of arg types.
*/
private static Pair<@Nullable RelDataType, @Nullable RelDataType> getComponentTypes(
RelDataTypeFactory typeFactory,
List<RelDataType> argTypes) {
return Pair.of(
typeFactory.leastRestrictive(Util.quotientList(argTypes, 2, 0)),
typeFactory.leastRestrictive(Util.quotientList(argTypes, 2, 1)));
}
}

/** Operand type-checker that accepts period types. Examples:
*
* <ul>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,7 @@ public enum BuiltInMethod {
MAP_VALUES(SqlFunctions.class, "mapValues", Map.class),
MAP_FROM_ARRAYS(SqlFunctions.class, "mapFromArrays", List.class, List.class),
MAP_FROM_ENTRIES(SqlFunctions.class, "mapFromEntries", List.class),
MAP_FUNCTION(SqlFunctions.class, "mapFunction", Object[].class),
STR_TO_MAP(SqlFunctions.class, "strToMap", String.class, String.class, String.class),
SELECTIVITY(Selectivity.class, "getSelectivity", RexNode.class),
UNIQUE_KEYS(UniqueKeys.class, "getUniqueKeys", boolean.class),
Expand Down
1 change: 1 addition & 0 deletions site/_docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -2770,6 +2770,7 @@ BigQuery's type system uses confusingly different names for types and functions:
| b | TO_HEX(binary) | Converts *binary* into a hexadecimal varchar
| b | FROM_HEX(varchar) | Converts a hexadecimal-encoded *varchar* into bytes
| b o | LTRIM(string) | Returns *string* with all blanks removed from the start
| s | MAP(key, value [, key, value]*) | Returns a map with the given key/value pairs
| s | MAP_CONCAT(map [, map]*) | Concatenates one or more maps. If any input argument is `NULL` the function returns `NULL`. Note that calcite is using the LAST_WIN strategy
| s | MAP_ENTRIES(map) | Returns the entries of the *map* as an array, the order of the entries is not defined
| s | MAP_KEYS(map) | Returns the keys of the *map* as an array, the order of the entries is not defined
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6146,6 +6146,32 @@ private static Matcher<SqlNode> isCharLiteral(String s) {
.ok("(MAP[])");
}

@Test void testMapFunction() {
expr("map()").ok("MAP()");
// parser allows odd elements; validator will reject it
expr("map(1)").ok("MAP(1)");
expr("map(1, 'x', 2, 'y')")
.ok("MAP(1, 'x', 2, 'y')");
// with space
expr("map (1, 'x', 2, 'y')")
.ok("MAP(1, 'x', 2, 'y')");
}

@Test void testMapQueryConstructor() {
// parser allows odd elements; validator will reject it
sql("SELECT map(SELECT 1)")
.ok("SELECT (MAP ((SELECT 1)))");
sql("SELECT map(SELECT 1, 2)")
.ok("SELECT (MAP ((SELECT 1, 2)))");
sql("SELECT map(SELECT T.x, T.y FROM (VALUES(1, 2)) AS T(x, y))")
.ok("SELECT (MAP ((SELECT `T`.`X`, `T`.`Y`\n"
+ "FROM (VALUES (ROW(1, 2))) AS `T` (`X`, `Y`))))");
sql("SELECT map(1, ^SELECT^ x FROM (VALUES(1)) x)")
.fails("(?s)Incorrect syntax near the keyword 'SELECT'.*");
sql("SELECT map(SELECT x FROM (VALUES(1)) x, ^SELECT^ x FROM (VALUES(1)) x)")
.fails("(?s)Incorrect syntax near the keyword 'SELECT' at.*");
}

@Test void testVisitSqlInsertWithSqlShuttle() {
final String sql = "insert into emps select * from emps";
final SqlNode sqlNode = sql(sql).node();
Expand Down
102 changes: 102 additions & 0 deletions testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6722,6 +6722,7 @@ private static void checkIf(SqlOperatorFixture f) {

/** Tests {@code MAP_CONCAT} function from Spark. */
@Test void testMapConcatFunc() {
// 1. check with std map constructor, map[k, v ...]
final SqlOperatorFixture f0 = fixture();
f0.setFor(SqlLibraryOperators.MAP_CONCAT);
f0.checkFails("^map_concat(map['foo', 1], map['bar', 2])^",
Expand Down Expand Up @@ -6766,11 +6767,44 @@ private static void checkIf(SqlOperatorFixture f) {
// test operands not in same type family.
f.checkFails("^map_concat(map[1, null], array[1])^",
"Parameters must be of the same type", false);

// 2. check with map function, map(k, v ...)
final SqlOperatorFixture f1 = fixture()
.setFor(SqlLibraryOperators.MAP_CONCAT)
.withLibrary(SqlLibrary.SPARK);
f1.checkScalar("map_concat(map('foo', 1), map('bar', 2))", "{foo=1, bar=2}",
"(CHAR(3) NOT NULL, INTEGER NOT NULL) MAP NOT NULL");
f1.checkScalar("map_concat(map('foo', 1), map('bar', 2), map('foo', 2))", "{foo=2, bar=2}",
"(CHAR(3) NOT NULL, INTEGER NOT NULL) MAP NOT NULL");
f1.checkScalar("map_concat(map(null, 1), map(null, 2))", "{null=2}",
"(NULL, INTEGER NOT NULL) MAP NOT NULL");
f1.checkScalar("map_concat(map(1, 2), map(1, null))", "{1=null}",
"(INTEGER NOT NULL, INTEGER) MAP NOT NULL");
// test zero arg, but it should return empty map.
f1.checkScalar("map_concat()", "{}",
"(VARCHAR NOT NULL, VARCHAR NOT NULL) MAP");

// after calcite supports cast(null as map<string, int>), it should add these tests.
if (TODO) {
f1.checkNull("map_concat(map('foo', 1), cast(null as map<string, int>))");
f1.checkType("map_concat(map('foo', 1), cast(null as map<string, int>))",
"(VARCHAR NOT NULL, INTEGER NOT NULL) MAP");
f1.checkNull("map_concat(cast(null as map<string, int>), map['foo', 1])");
f1.checkType("map_concat(cast(null as map<string, int>), map['foo', 1])",
"(VARCHAR NOT NULL, INTEGER NOT NULL) MAP");
}
f1.checkFails("^map_concat(map('foo', 1), null)^",
"Function 'MAP_CONCAT' should all be of type map, "
+ "but it is 'NULL'", false);
// test operands not in same type family.
f1.checkFails("^map_concat(map(1, null), array[1])^",
"Parameters must be of the same type", false);
}


/** Tests {@code MAP_ENTRIES} function from Spark. */
@Test void testMapEntriesFunc() {
// 1. check with std map constructor, map[k, v ...]
final SqlOperatorFixture f0 = fixture();
f0.setFor(SqlLibraryOperators.MAP_ENTRIES);
f0.checkFails("^map_entries(map['foo', 1, 'bar', 2])^",
Expand All @@ -6796,10 +6830,20 @@ private static void checkIf(SqlOperatorFixture f) {
"RecordType(INTEGER f0, BIGINT NOT NULL f1) NOT NULL ARRAY NOT NULL");
f.checkScalar("map_entries(map[1, cast(1 as decimal), null, 2])", "[{1, 1}, {null, 2}]",
"RecordType(INTEGER f0, DECIMAL(19, 0) NOT NULL f1) NOT NULL ARRAY NOT NULL");

// 2. check with map function, map(k, v ...)
final SqlOperatorFixture f1 = fixture()
.setFor(SqlLibraryOperators.MAP_ENTRIES)
.withLibrary(SqlLibrary.SPARK);
f1.checkScalar("map_entries(map('foo', 1, 'bar', 2))", "[{foo, 1}, {bar, 2}]",
"RecordType(CHAR(3) NOT NULL f0, INTEGER NOT NULL f1) NOT NULL ARRAY NOT NULL");
f1.checkScalar("map_entries(map('foo', 1, null, 2))", "[{foo, 1}, {null, 2}]",
"RecordType(CHAR(3) f0, INTEGER NOT NULL f1) NOT NULL ARRAY NOT NULL");
}

/** Tests {@code MAP_KEYS} function from Spark. */
@Test void testMapKeysFunc() {
// 1. check with std map constructor, map[k, v ...]
final SqlOperatorFixture f0 = fixture();
f0.setFor(SqlLibraryOperators.MAP_KEYS);
f0.checkFails("^map_keys(map['foo', 1, 'bar', 2])^",
Expand All @@ -6825,10 +6869,20 @@ private static void checkIf(SqlOperatorFixture f) {
"INTEGER ARRAY NOT NULL");
f.checkScalar("map_keys(map[1, cast(1 as decimal), null, 2])", "[1, null]",
"INTEGER ARRAY NOT NULL");

// 2. check with map function, map(k, v ...)
final SqlOperatorFixture f1 = fixture()
.setFor(SqlLibraryOperators.MAP_KEYS)
.withLibrary(SqlLibrary.SPARK);
f1.checkScalar("map_keys(map('foo', 1, 'bar', 2))", "[foo, bar]",
"CHAR(3) NOT NULL ARRAY NOT NULL");
f1.checkScalar("map_keys(map('foo', 1, null, 2))", "[foo, null]",
"CHAR(3) ARRAY NOT NULL");
}

/** Tests {@code MAP_VALUES} function from Spark. */
@Test void testMapValuesFunc() {
// 1. check with std map constructor, map[k, v ...]
final SqlOperatorFixture f0 = fixture();
f0.setFor(SqlLibraryOperators.MAP_VALUES);
f0.checkFails("^map_values(map['foo', 1, 'bar', 2])^",
Expand All @@ -6839,6 +6893,15 @@ private static void checkIf(SqlOperatorFixture f) {
"INTEGER NOT NULL ARRAY NOT NULL");
f.checkScalar("map_values(map['foo', 1, 'bar', cast(null as integer)])", "[1, null]",
"INTEGER ARRAY NOT NULL");

// 2. check with map function, map(k, v ...)
final SqlOperatorFixture f1 = fixture()
.setFor(SqlLibraryOperators.MAP_VALUES)
.withLibrary(SqlLibrary.SPARK);
f1.checkScalar("map_values(map('foo', 1, 'bar', 2))", "[1, 2]",
"INTEGER NOT NULL ARRAY NOT NULL");
f1.checkScalar("map_values(map('foo', 1, 'bar', cast(null as integer)))", "[1, null]",
"INTEGER ARRAY NOT NULL");
}

/** Tests {@code MAP_FROM_ARRAYS} function from Spark. */
Expand Down Expand Up @@ -10446,6 +10509,45 @@ private static void checkArrayConcatAggFuncFails(SqlOperatorFixture t) {
"{1=1, 2=2}", "(BIGINT NOT NULL, SMALLINT NOT NULL) MAP NOT NULL");
}

@Test void testMapFunction() {
final SqlOperatorFixture f = fixture();
f.setFor(SqlLibraryOperators.MAP, VmName.EXPAND);

f.checkFails("^Map()^",
"No match found for function signature "
+ "MAP\\(\\)", false);
f.checkFails("^Map(1, 'x')^",
"No match found for function signature "
+ "MAP\\(<NUMERIC>, <CHARACTER>\\)", false);
f.checkFails("^map(1, 'x', 2, 'x')^",
"No match found for function signature "
+ "MAP\\(<NUMERIC>, <CHARACTER>, <NUMERIC>, <CHARACTER>\\)", false);

final SqlOperatorFixture f1 = f.withLibrary(SqlLibrary.SPARK);
f1.checkFails("^Map(1)^",
"Map requires an even number of arguments", false);
f1.checkFails("^Map(1, 'x', 2)^",
"Map requires an even number of arguments", false);
f1.checkFails("^map(1, 1, 2, 'x')^",
"Parameters must be of the same type", false);
// this behavior is different from std MapValueConstructor
f1.checkScalar("map()",
"{}",
"(UNKNOWN NOT NULL, UNKNOWN NOT NULL) MAP NOT NULL");
f1.checkScalar("map('washington', null)",
"{washington=null}",
"(CHAR(10) NOT NULL, NULL) MAP NOT NULL");
f1.checkScalar("map('washington', 1)",
"{washington=1}",
"(CHAR(10) NOT NULL, INTEGER NOT NULL) MAP NOT NULL");
f1.checkScalar("map('washington', 1, 'washington', 2)",
"{washington=2}",
"(CHAR(10) NOT NULL, INTEGER NOT NULL) MAP NOT NULL");
f1.checkScalar("map('washington', 1, 'obama', 44)",
"{washington=1, obama=44}",
"(CHAR(10) NOT NULL, INTEGER NOT NULL) MAP NOT NULL");
}

@Test void testCeilFunc() {
final SqlOperatorFixture f = fixture();
f.setFor(SqlStdOperatorTable.CEIL, VM_FENNEL);
Expand Down

0 comments on commit 9d77112

Please sign in to comment.