Skip to content

Commit

Permalink
[CALCITE-5836] Implement Rel2Sql for MERGE
Browse files Browse the repository at this point in the history
  • Loading branch information
macroguo-ghy committed Sep 10, 2023
1 parent a7e3f7e commit 2250bd3
Show file tree
Hide file tree
Showing 5 changed files with 231 additions and 1 deletion.
1 change: 1 addition & 0 deletions core/src/main/java/org/apache/calcite/plan/RelOptUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -2138,6 +2138,7 @@ public static RelDataType createDmlRowType(
case INSERT:
case DELETE:
case UPDATE:
case MERGE:
return typeFactory.createStructType(
PairList.of(AvaticaConnection.ROWCOUNT_COLUMN_NAME,
typeFactory.createSqlType(SqlTypeName.BIGINT)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,7 @@ private static Meta.StatementType getStatementType(SqlKind kind) {
case INSERT:
case DELETE:
case UPDATE:
case MERGE:
return Meta.StatementType.IS_DML;
default:
return Meta.StatementType.SELECT;
Expand Down Expand Up @@ -667,6 +668,7 @@ <T> CalciteSignature<T> prepare2_(
case INSERT:
case DELETE:
case UPDATE:
case MERGE:
case EXPLAIN:
// FIXME: getValidatedNodeType is wrong for DML
x = RelOptUtil.createDmlRowType(sqlNode.getKind(), typeFactory);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlMatchRecognize;
import org.apache.calcite.sql.SqlMerge;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlSampleSpec;
Expand Down Expand Up @@ -1084,7 +1085,62 @@ public Result visit(TableModify modify) {

return result(sqlDelete, input.clauses, modify, null);
}
case MERGE:
case MERGE: {
final Result input = visitInput(modify, 0);
final SqlSelect select = input.asSelect();
// When querying with both the `WHEN MATCHED THEN UPDATE` and
// `WHEN NOT MATCHED THEN INSERT` clauses, the selectList consists of three parts:
// the insert expression, the target table reference, and the update expression.
// When querying with the `WHEN MATCHED THEN UPDATE` clause, the selectList will not
// include the update expression.
// However, when querying with the `WHEN NOT MATCHED THEN INSERT` clause,
// the expression list will only contain the insert expression.
final SqlNodeList selectList = SqlUtil.stripListAs(select.getSelectList());
final SqlJoin join = requireNonNull((SqlJoin) select.getFrom());
final SqlNode condition = requireNonNull(join.getCondition());
final SqlNode source = join.getLeft();

SqlUpdate update = null;
final List<String> updateColumnList =
requireNonNull(modify.getUpdateColumnList(),
() -> "modify.getUpdateColumnList() is null for " + modify);
final int nUpdateFiled = updateColumnList.size();
if (nUpdateFiled != 0) {
final SqlNodeList expressionList =
Util.last(selectList, nUpdateFiled).stream()
.collect(SqlNodeList.toList());
update =
new SqlUpdate(POS, sqlTargetTable,
identifierList(updateColumnList),
expressionList,
condition, null, null);
}

final RelDataType targetRowType = modify.getTable().getRowType();
final int nTargetFiled = targetRowType.getFieldCount();
final int nInsertFiled = nUpdateFiled == 0
? selectList.size() : selectList.size() - nTargetFiled - nUpdateFiled;
SqlInsert insert = null;
if (nInsertFiled != 0) {
final SqlNodeList expressionList =
Util.first(selectList, nInsertFiled).stream()
.collect(SqlNodeList.toList());
final SqlNode valuesCall =
SqlStdOperatorTable.VALUES.createCall(expressionList);
final SqlNodeList columnList = targetRowType.getFieldNames().stream()
.map(f -> new SqlIdentifier(f, POS))
.collect(SqlNodeList.toList());
insert = new SqlInsert(POS, SqlNodeList.EMPTY, sqlTargetTable, valuesCall, columnList);
}

final SqlNode target = join.getRight();
final SqlNode targetTableAlias = target.getKind() == SqlKind.AS
? ((SqlCall) target).operand(1) : null;
final SqlMerge merge =
new SqlMerge(POS, sqlTargetTable, condition, source, update, insert, null,
(SqlIdentifier) targetTableAlias);
return result(merge, input.clauses, modify, null);
}
default:
throw new AssertionError("not implemented: " + modify);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7027,6 +7027,136 @@ private void checkLiteral2(String expression, String expected) {
.ok(expected);
}

@Test void testMerge() {
final String sql1 = "merge into \"DEPT\" as \"t\"\n"
+ "using \"DEPT\" as \"s\"\n"
+ "on \"t\".\"DEPTNO\" = \"s\".\"DEPTNO\"\n"
+ "when matched then\n"
+ "update set \"DNAME\" = \"s\".\"DNAME\"\n"
+ "when not matched then\n"
+ "insert (DEPTNO, DNAME, LOC)\n"
+ "values (\"s\".\"DEPTNO\" + 1, lower(\"s\".\"DNAME\"), upper(\"s\".\"LOC\"))";
final String expected1 = "MERGE INTO \"SCOTT\".\"DEPT\" AS \"DEPT0\"\n"
+ "USING \"SCOTT\".\"DEPT\"\n"
+ "ON \"DEPT\".\"DEPTNO\" = \"DEPT0\".\"DEPTNO\"\n"
+ "WHEN MATCHED THEN UPDATE SET \"DNAME\" = \"DEPT\".\"DNAME\"\n"
+ "WHEN NOT MATCHED THEN INSERT (\"DEPTNO\", \"DNAME\", \"LOC\") "
+ "VALUES CAST(\"DEPT\".\"DEPTNO\" + 1 AS TINYINT),\n"
+ "LOWER(\"DEPT\".\"DNAME\"),\n"
+ "UPPER(\"DEPT\".\"LOC\")";
sql(sql1)
.schema(CalciteAssert.SchemaSpec.JDBC_SCOTT)
.ok(expected1);

// without insert columns
final String sql2 = "merge into \"DEPT\" as \"t\"\n"
+ "using \"DEPT\" as \"s\"\n"
+ "on \"t\".\"DEPTNO\" = \"s\".\"DEPTNO\"\n"
+ "when matched then\n"
+ "update set \"DNAME\" = \"s\".\"DNAME\"\n"
+ "when not matched then insert\n"
+ "values (\"s\".\"DEPTNO\" + 1, lower(\"s\".\"DNAME\"), upper(\"s\".\"LOC\"))";
final String expected2 = expected1;
sql(sql2)
.schema(CalciteAssert.SchemaSpec.JDBC_SCOTT)
.ok(expected2);

// reorder insert columns
final String sql3 = "merge into \"DEPT\" as \"t\"\n"
+ "using \"DEPT\" as \"s\"\n"
+ "on \"t\".\"DEPTNO\" = \"s\".\"DEPTNO\"\n"
+ "when matched then\n"
+ "update set \"DNAME\" = \"s\".\"DNAME\"\n"
+ "when not matched then\n"
+ "insert (DEPTNO, LOC, DNAME)\n"
+ "values (\"s\".\"DEPTNO\" + 1, lower(\"s\".\"DNAME\"), 'abc')";
final String expected3 = "MERGE INTO \"SCOTT\".\"DEPT\" AS \"DEPT0\"\n"
+ "USING \"SCOTT\".\"DEPT\"\n"
+ "ON \"DEPT\".\"DEPTNO\" = \"DEPT0\".\"DEPTNO\"\n"
+ "WHEN MATCHED THEN UPDATE SET \"DNAME\" = \"DEPT\".\"DNAME\"\n"
+ "WHEN NOT MATCHED THEN INSERT (\"DEPTNO\", \"DNAME\", \"LOC\") "
+ "VALUES CAST(\"DEPT\".\"DEPTNO\" + 1 AS TINYINT),\n"
+ "'abc',\n"
+ "LOWER(\"DEPT\".\"DNAME\")";
sql(sql3)
.schema(CalciteAssert.SchemaSpec.JDBC_SCOTT)
.ok(expected3);

// without WHEN NOT MATCHED THEN
final String sql4 = "merge into \"DEPT\" as \"t\"\n"
+ "using \"DEPT\" as \"s\"\n"
+ "on \"t\".\"DEPTNO\" = \"s\".\"DEPTNO\"\n"
+ "when matched then\n"
+ "update set \"DNAME\" = \"s\".\"DNAME\"";
final String expected4 = "MERGE INTO \"SCOTT\".\"DEPT\" AS \"DEPT0\"\n"
+ "USING \"SCOTT\".\"DEPT\"\n"
+ "ON \"DEPT\".\"DEPTNO\" = \"DEPT0\".\"DEPTNO\"\n"
+ "WHEN MATCHED THEN UPDATE SET \"DNAME\" = \"DEPT\".\"DNAME\"";
sql(sql4)
.schema(CalciteAssert.SchemaSpec.JDBC_SCOTT)
.ok(expected4);

// without WHEN MATCHED THEN
final String sql5 = "merge into \"DEPT\" as \"t\"\n"
+ "using \"DEPT\" as \"s\"\n"
+ "on \"t\".\"DEPTNO\" = \"s\".\"DEPTNO\"\n"
+ "when not matched then\n"
+ "insert (DEPTNO, DNAME, LOC)\n"
+ "values (\"s\".\"DEPTNO\" + 1, lower(\"s\".\"DNAME\"), upper(\"s\".\"LOC\"))";
final String expected5 = "MERGE INTO \"SCOTT\".\"DEPT\" AS \"DEPT0\"\n"
+ "USING \"SCOTT\".\"DEPT\"\n"
+ "ON \"DEPT\".\"DEPTNO\" = \"DEPT0\".\"DEPTNO\"\n"
+ "WHEN NOT MATCHED THEN INSERT (\"DEPTNO\", \"DNAME\", \"LOC\") "
+ "VALUES CAST(\"DEPT\".\"DEPTNO\" + 1 AS TINYINT),\n"
+ "LOWER(\"DEPT\".\"DNAME\"),\n"
+ "UPPER(\"DEPT\".\"LOC\")";
sql(sql5)
.schema(CalciteAssert.SchemaSpec.JDBC_SCOTT)
.ok(expected5);

// using query
final String sql6 = "merge into \"DEPT\" as \"t\"\n"
+ "using (select * from \"DEPT\" where \"DEPTNO\" <> 5) as \"s\"\n"
+ "on \"t\".\"DEPTNO\" = \"s\".\"DEPTNO\"\n"
+ "when not matched then\n"
+ "insert (DEPTNO, DNAME, LOC)\n"
+ "values (\"s\".\"DEPTNO\" + 1, lower(\"s\".\"DNAME\"), upper(\"s\".\"LOC\"))";
final String expected6 = "MERGE INTO \"SCOTT\".\"DEPT\" AS \"DEPT0\"\n"
+ "USING (SELECT *\n"
+ "FROM \"SCOTT\".\"DEPT\"\n"
+ "WHERE CAST(\"DEPTNO\" AS INTEGER) <> 5) AS \"t0\"\n"
+ "ON \"t0\".\"DEPTNO\" = \"DEPT0\".\"DEPTNO\"\n"
+ "WHEN NOT MATCHED THEN INSERT (\"DEPTNO\", \"DNAME\", \"LOC\") "
+ "VALUES CAST(\"t0\".\"DEPTNO\" + 1 AS TINYINT),\n"
+ "LOWER(\"t0\".\"DNAME\"),\n"
+ "UPPER(\"t0\".\"LOC\")";
sql(sql6)
.schema(CalciteAssert.SchemaSpec.JDBC_SCOTT)
.ok(expected6);

final String sql7 = "merge into \"DEPT\" as \"t\"\n"
+ "using (select * from (values (1, 'name', 'loc'))) as \"s\"(\"a\", \"b\", \"c\")\n"
+ "on \"t\".\"DEPTNO\" = \"s\".\"a\"\n"
+ "when matched then\n"
+ "update set \"DNAME\" = 'abc'"
+ "when not matched then\n"
+ "insert (DEPTNO, DNAME, LOC)\n"
+ "values (\"s\".\"a\" + 1, lower(\"s\".\"b\"), upper(\"s\".\"c\"))";
final String expected7 = "MERGE INTO \"SCOTT\".\"DEPT\" AS \"t1\"\n"
+ "USING (SELECT *\n"
+ "FROM (VALUES (1, 'name', 'loc')) "
+ "AS \"t\" (\"EXPR$0\", \"EXPR$1\", \"EXPR$2\")) AS \"t0\"\n"
+ "ON \"t0\".\"EXPR$0\" = \"t1\".\"DEPTNO0\"\n"
+ "WHEN MATCHED THEN UPDATE SET \"DNAME\" = 'abc'\n"
+ "WHEN NOT MATCHED THEN INSERT (\"DEPTNO\", \"DNAME\", \"LOC\") "
+ "VALUES CAST(\"t0\".\"EXPR$0\" + 1 AS TINYINT),\n"
+ "LOWER(\"t0\".\"EXPR$1\"),\n"
+ "UPPER(\"t0\".\"EXPR$2\")";
sql(sql7)
.schema(CalciteAssert.SchemaSpec.JDBC_SCOTT)
.ok(expected7);
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-5265">[CALCITE-5265]
* JDBC adapter sometimes adds unnecessary parentheses around SELECT in INSERT</a>. */
Expand Down
41 changes: 41 additions & 0 deletions core/src/test/java/org/apache/calcite/test/JdbcAdapterTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1106,6 +1106,47 @@ private LockWrapper exclusiveCleanDb(Connection c) throws SQLException {
.returns("C=null\nC=null\nC=null\nC=null\nC=null\nC=null\nC=null\n");
}

@Test void testMerge() throws Exception {
final String sql = "merge into \"foodmart\".\"expense_fact\"\n"
+ "using (values(666, 42)) as vals(store_id, amount)\n"
+ "on \"expense_fact\".\"store_id\" = vals.store_id\n"
+ "when matched then update\n"
+ "set \"amount\" = vals.amount\n"
+ "when not matched then insert\n"
+ "values (vals.store_id, 666, TIMESTAMP '1997-01-01 00:00:00', 666, '666', 666,"
+ " vals.amount)";
final String explain = "PLAN=JdbcToEnumerableConverter\n"
+ " JdbcTableModify(table=[[foodmart, expense_fact]], operation=[MERGE],"
+ " updateColumnList=[[amount]], flattened=[false])\n"
+ " JdbcProject(STORE_ID=[$0], $f1=[666], $f2=[1997-01-01 00:00:00], $f3=[666],"
+ " $f4=['666'], $f5=[666], AMOUNT=[CAST($1):DECIMAL(10, 4) NOT NULL], store_id=[$2],"
+ " account_id=[$3], exp_date=[$4], time_id=[$5], category_id=[$6], currency_id=[$7],"
+ " amount=[$8], AMOUNT0=[$1])\n"
+ " JdbcJoin(condition=[=($2, $0)], joinType=[left])\n"
+ " JdbcValues(tuples=[[{ 666, 42 }]])\n"
+ " JdbcTableScan(table=[[foodmart, expense_fact]])\n";
final String jdbcSql = "MERGE INTO \"foodmart\".\"expense_fact\"\n"
+ "USING (VALUES (666, 42)) AS \"t\" (\"STORE_ID\", \"AMOUNT\")\n"
+ "ON \"t\".\"STORE_ID\" = \"expense_fact\".\"store_id\"\n"
+ "WHEN MATCHED THEN UPDATE SET \"amount\" = \"t\".\"AMOUNT\"\n"
+ "WHEN NOT MATCHED THEN INSERT (\"store_id\", \"account_id\", \"exp_date\", \"time_id\", "
+ "\"category_id\", \"currency_id\", \"amount\") VALUES \"t\".\"STORE_ID\",\n"
+ "666,\nTIMESTAMP '1997-01-01 00:00:00',\n666,\n'666',\n666,\n"
+ "CAST(\"t\".\"AMOUNT\" AS DECIMAL(10, 4))";
final AssertThat that =
CalciteAssert.model(FoodmartSchema.FOODMART_MODEL)
.enable(CalciteAssert.DB == DatabaseInstance.HSQLDB);
that.doWithConnection(connection -> {
try (LockWrapper ignore = exclusiveCleanDb(connection)) {
that.query(sql)
.explainContains(explain)
.planUpdateHasSql(jdbcSql, 1);
} catch (SQLException e) {
throw TestUtil.rethrow(e);
}
});
}

/** Acquires a lock, and releases it when closed. */
static class LockWrapper implements AutoCloseable {
private final Lock lock;
Expand Down

0 comments on commit 2250bd3

Please sign in to comment.