diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountIntoUnionAllTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountIntoUnionAllTest.java index aa9f7fb98451d9..bb35cb5e26f38f 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountIntoUnionAllTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountIntoUnionAllTest.java @@ -1,3 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; @@ -31,8 +48,8 @@ void testPushCountStar() { .rewrite() .matches( logicalAggregate( - logicalUnion(logicalAggregate().when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class)) - ,logicalAggregate().when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class))) + logicalUnion(logicalAggregate().when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class)), + logicalAggregate().when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class))) ).when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class)) ); String sql2 = "select count(*) from (select id,a from t1 union all select id,a from t1 where id>10) t;"; @@ -41,8 +58,8 @@ void testPushCountStar() { .rewrite() .matches( logicalAggregate( - logicalUnion(logicalAggregate().when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class)) - ,logicalAggregate().when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class))) + logicalUnion(logicalAggregate().when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class)), + logicalAggregate().when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class))) ).when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class)) ); } @@ -55,8 +72,8 @@ void testPushCountColumn() { .rewrite() .matches( logicalAggregate( - logicalUnion(logicalAggregate().when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class)) - ,logicalAggregate().when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class))) + logicalUnion(logicalAggregate().when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class)), + logicalAggregate().when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class))) ).when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class)) ); } @@ -70,8 +87,8 @@ void testPushCountColumnWithGroupBy() { .matches( logicalAggregate( logicalUnion(logicalAggregate().when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class) - && agg.getGroupByExpressions().size() == 1) - ,logicalAggregate().when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class) + && agg.getGroupByExpressions().size() == 1), + logicalAggregate().when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class) && agg.getGroupByExpressions().size() == 1)) ).when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class)) ); @@ -86,8 +103,8 @@ void testPush2CountColumn() { .matches( logicalAggregate( logicalUnion(logicalAggregate().when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class) - && agg.getGroupByExpressions().size() == 1) - ,logicalAggregate().when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class) + && agg.getGroupByExpressions().size() == 1), + logicalAggregate().when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class) && agg.getGroupByExpressions().size() == 1)) ).when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class)) );