From 267e2262a059be2973ea0ede3ffac02d9721c58c Mon Sep 17 00:00:00 2001 From: Alex Zaytsev Date: Wed, 10 Aug 2022 20:13:31 +1200 Subject: [PATCH] Add support for nested non aggregating group by operators (#3087) Fixes #3076 --- .../Async/Linq/ByMethod/GroupByTests.cs | 23 +++++++++++++ .../Linq/ByMethod/GroupByTests.cs | 23 +++++++++++++ .../GroupBy/NonAggregatingGroupByRewriter.cs | 33 +++++++++++-------- .../ResultOperators/NonAggregatingGroupBy.cs | 8 +++-- .../ProcessNonAggregatingGroupBy.cs | 3 +- 5 files changed, 74 insertions(+), 16 deletions(-) diff --git a/src/NHibernate.Test/Async/Linq/ByMethod/GroupByTests.cs b/src/NHibernate.Test/Async/Linq/ByMethod/GroupByTests.cs index f23b8f2d14d..e2c553cfbcc 100644 --- a/src/NHibernate.Test/Async/Linq/ByMethod/GroupByTests.cs +++ b/src/NHibernate.Test/Async/Linq/ByMethod/GroupByTests.cs @@ -851,6 +851,29 @@ public async Task GroupByComputedValueFromNestedObjectSelectAsync() Assert.AreEqual(2155, orderGroups.Sum(g => g.Count)); } + [Test(Description="GH-3076")] + public async Task NestedNonAggregateGroupByAsync() + { + var list = await (db.OrderLines + .GroupBy(x => new { x.Order.OrderId, x.Product.ProductId }) // this works fine + .GroupBy(x => x.Key.ProductId) // exception: "A recognition error occurred" + .ToListAsync()); + + Assert.That(list, Has.Count.EqualTo(77)); + } + + [Test(Description="GH-3076")] + public async Task NestedNonAggregateGroupBySelectAsync() + { + var list = await (db.OrderLines + .GroupBy(x => new { x.Order.OrderId, x.Product.ProductId }) // this works fine + .GroupBy(x => x.Key.ProductId) // exception: "A recognition error occurred" + .Select(x => new { ProductId = x }) + .ToListAsync()); + + Assert.That(list, Has.Count.EqualTo(77)); + } + private static void CheckGrouping(IEnumerable> groupedItems, Func groupBy) { var used = new HashSet(); diff --git a/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs b/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs index 2b5bab7bcab..92f8457c0c5 100644 --- a/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs +++ b/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs @@ -840,6 +840,29 @@ public void GroupByComputedValueFromNestedObjectSelect() Assert.AreEqual(2155, orderGroups.Sum(g => g.Count)); } + [Test(Description="GH-3076")] + public void NestedNonAggregateGroupBy() + { + var list = db.OrderLines + .GroupBy(x => new { x.Order.OrderId, x.Product.ProductId }) // this works fine + .GroupBy(x => x.Key.ProductId) // exception: "A recognition error occurred" + .ToList(); + + Assert.That(list, Has.Count.EqualTo(77)); + } + + [Test(Description="GH-3076")] + public void NestedNonAggregateGroupBySelect() + { + var list = db.OrderLines + .GroupBy(x => new { x.Order.OrderId, x.Product.ProductId }) // this works fine + .GroupBy(x => x.Key.ProductId) // exception: "A recognition error occurred" + .Select(x => new { ProductId = x }) + .ToList(); + + Assert.That(list, Has.Count.EqualTo(77)); + } + private static void CheckGrouping(IEnumerable> groupedItems, Func groupBy) { var used = new HashSet(); diff --git a/src/NHibernate/Linq/GroupBy/NonAggregatingGroupByRewriter.cs b/src/NHibernate/Linq/GroupBy/NonAggregatingGroupByRewriter.cs index 90bc4c1884c..7231d3c0414 100644 --- a/src/NHibernate/Linq/GroupBy/NonAggregatingGroupByRewriter.cs +++ b/src/NHibernate/Linq/GroupBy/NonAggregatingGroupByRewriter.cs @@ -1,4 +1,5 @@ using System; +using System.Linq; using System.Linq.Expressions; using NHibernate.Linq.ResultOperators; using Remotion.Linq; @@ -13,22 +14,23 @@ public static class NonAggregatingGroupByRewriter { public static void ReWrite(QueryModel queryModel) { - if (queryModel.ResultOperators.Count == 1 - && queryModel.ResultOperators[0] is GroupResultOperator + if (queryModel.ResultOperators.Count > 0 + && queryModel.ResultOperators.All(r => r is GroupResultOperator) && IsNonAggregatingGroupBy(queryModel)) { - var resultOperator = (GroupResultOperator)queryModel.ResultOperators[0]; - queryModel.ResultOperators.Clear(); - queryModel.ResultOperators.Add(new NonAggregatingGroupBy(resultOperator)); + for (var i = 0; i < queryModel.ResultOperators.Count; i++) + { + var resultOperator = (GroupResultOperator) queryModel.ResultOperators[i]; + queryModel.ResultOperators[i] = new NonAggregatingGroupBy(resultOperator); + } + return; } - var subQueryExpression = queryModel.MainFromClause.FromExpression as SubQueryExpression; - - if ((subQueryExpression != null) - && (subQueryExpression.QueryModel.ResultOperators.Count == 1) - && (subQueryExpression.QueryModel.ResultOperators[0] is GroupResultOperator) - && (IsNonAggregatingGroupBy(queryModel))) + if (queryModel.MainFromClause.FromExpression is SubQueryExpression subQueryExpression + && subQueryExpression.QueryModel.ResultOperators.Count > 0 + && subQueryExpression.QueryModel.ResultOperators.All(r => r is GroupResultOperator) + && IsNonAggregatingGroupBy(queryModel)) { FlattenSubQuery(subQueryExpression, queryModel); } @@ -58,7 +60,12 @@ private static void FlattenSubQuery(SubQueryExpression subQueryExpression, Query throw new NotImplementedException(); } - queryModel.ResultOperators.Add(new NonAggregatingGroupBy((GroupResultOperator) subQueryModel.ResultOperators[0])); + for (var i = 0; i < subQueryModel.ResultOperators.Count; i++) + { + var resultOperator = new NonAggregatingGroupBy((GroupResultOperator) subQueryModel.ResultOperators[i]); + queryModel.ResultOperators.Add(resultOperator); + } + queryModel.ResultOperators.Add(clientSideSelect); } @@ -103,4 +110,4 @@ public ClientSideSelect2(LambdaExpression selectClause) SelectClause = selectClause; } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/ResultOperators/NonAggregatingGroupBy.cs b/src/NHibernate/Linq/ResultOperators/NonAggregatingGroupBy.cs index 93b86f9f171..69fd5a3772d 100644 --- a/src/NHibernate/Linq/ResultOperators/NonAggregatingGroupBy.cs +++ b/src/NHibernate/Linq/ResultOperators/NonAggregatingGroupBy.cs @@ -1,4 +1,5 @@ using Remotion.Linq.Clauses.ResultOperators; +using Remotion.Linq.Clauses.StreamedData; namespace NHibernate.Linq.ResultOperators { @@ -9,6 +10,9 @@ public NonAggregatingGroupBy(GroupResultOperator groupBy) GroupBy = groupBy; } - public GroupResultOperator GroupBy { get; private set; } + public GroupResultOperator GroupBy { get; } + + public override IStreamedDataInfo GetOutputDataInfo(IStreamedDataInfo inputInfo) => + GroupBy.GetOutputDataInfo(inputInfo); } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessNonAggregatingGroupBy.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessNonAggregatingGroupBy.cs index 21c82a87eb0..133f2f23572 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessNonAggregatingGroupBy.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessNonAggregatingGroupBy.cs @@ -4,6 +4,7 @@ using NHibernate.Linq.ResultOperators; using NHibernate.Util; using Remotion.Linq.Clauses.ExpressionVisitors; +using Remotion.Linq.Clauses.StreamedData; namespace NHibernate.Linq.Visitors.ResultOperatorProcessors { @@ -11,7 +12,7 @@ public class ProcessNonAggregatingGroupBy : IResultOperatorProcessor