From c20b140dfbffdfc8d3faa61b27faea3256ff372e Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Tue, 4 Jan 2022 21:36:54 -0800 Subject: [PATCH] Query: Update column expression correctly when lifting joins from group by aggregate subquery When replacing columns, we used the outer select expression which had additional joins from previous term whose aliases match tables in current join and it got updated with wrong table. The fix is to utilize the original tables when replacing columns. These column replacement is to map the columns from initial tables of group by from subquery to outer group by query. Resolves #27083 --- .../SqlExpressions/SelectExpression.Helper.cs | 21 ++-- .../Query/SqlExpressions/SelectExpression.cs | 5 +- .../Query/SimpleQueryTestBase.cs | 98 +++++++++++++++++++ .../Query/SimpleQuerySqlServerTest.cs | 16 +++ 4 files changed, 130 insertions(+), 10 deletions(-) diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs index d5320431d84..ed599055dc7 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs @@ -810,9 +810,8 @@ private sealed class CloningExpressionVisitor : ExpressionVisitor } // Now that we have SelectExpression, we visit all components and update table references inside columns - newSelectExpression = - (SelectExpression)new ColumnExpressionReplacingExpressionVisitor(selectExpression, newSelectExpression) - .Visit(newSelectExpression); + newSelectExpression = (SelectExpression)new ColumnExpressionReplacingExpressionVisitor( + selectExpression, newSelectExpression._tableReferences).Visit(newSelectExpression); return newSelectExpression; } @@ -826,10 +825,11 @@ private sealed class ColumnExpressionReplacingExpressionVisitor : ExpressionVisi private readonly SelectExpression _oldSelectExpression; private readonly Dictionary _newTableReferences; - public ColumnExpressionReplacingExpressionVisitor(SelectExpression oldSelectExpression, SelectExpression newSelectExpression) + public ColumnExpressionReplacingExpressionVisitor( + SelectExpression oldSelectExpression, IEnumerable newTableReferences) { _oldSelectExpression = oldSelectExpression; - _newTableReferences = newSelectExpression._tableReferences.ToDictionary(e => e.Alias); + _newTableReferences = newTableReferences.ToDictionary(e => e.Alias); } [return: NotNullIfNotNull("expression")] @@ -894,8 +894,14 @@ public GroupByAggregateLiftingExpressionVisitor(SelectExpression selectExpressio if (initialTableCounts > 0) { // If there are no initial table then this is not correlated grouping subquery + // We only replace columns from initial tables. + // Additional tables may have been added to outer from other terms which may end up matching on table alias var columnExpressionReplacingExpressionVisitor = - new ColumnExpressionReplacingExpressionVisitor(subquery, _selectExpression); + AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27083", out var enabled2) && enabled2 + ? new ColumnExpressionReplacingExpressionVisitor( + subquery, _selectExpression._tableReferences) + : new ColumnExpressionReplacingExpressionVisitor( + subquery, _selectExpression._tableReferences.Take(initialTableCounts)); if (subquery._tables.Count != initialTableCounts) { // If subquery has more tables then we expanded join on it. @@ -931,7 +937,8 @@ private void CopyOverOwnedJoinInSameTable(SelectExpression target, SelectExpress { if (target._projection.Count != source._projection.Count) { - var columnExpressionReplacingExpressionVisitor = new ColumnExpressionReplacingExpressionVisitor(source, target); + var columnExpressionReplacingExpressionVisitor = new ColumnExpressionReplacingExpressionVisitor( + source, target._tableReferences); var minProjectionCount = Math.Min(target._projection.Count, source._projection.Count); var initialProjectionCount = 0; for (var i = 0; i < minProjectionCount; i++) diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs index 98ecfad6e56..fddea73a0f7 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs @@ -622,9 +622,8 @@ static Expression RemoveConvert(Expression expression) if (querySplittingBehavior == QuerySplittingBehavior.SplitQuery) { var outerSelectExpression = (SelectExpression)cloningExpressionVisitor!.Visit(baseSelectExpression!); - innerSelectExpression = - (SelectExpression)new ColumnExpressionReplacingExpressionVisitor(this, outerSelectExpression) - .Visit(innerSelectExpression); + innerSelectExpression = (SelectExpression)new ColumnExpressionReplacingExpressionVisitor( + this, outerSelectExpression._tableReferences).Visit(innerSelectExpression); if (outerSelectExpression.Limit != null || outerSelectExpression.Offset != null diff --git a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs index 013cd5922a4..970ea53e607 100644 --- a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs @@ -598,5 +598,103 @@ protected class OrderItem public DateTime? ShippingDate { get; set; } public DateTime? CancellationDate { get; set; } } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task GroupBy_Aggregate_over_navigations_repeated(bool async) + { + var contextFactory = await InitializeAsync(seed: c => c.Seed()); + using var context = contextFactory.CreateContext(); + + var query = context + .Set() + .Where(x => x.OrderId != null) + .GroupBy(x => x.OrderId) + .Select(x => new + { + HourlyRate = x.Min(f => f.Order.HourlyRate), + CustomerId = x.Min(f => f.Project.Customer.Id), + CustomerName = x.Min(f => f.Project.Customer.Name), + }); + + var timeSheets = async + ? await query.ToListAsync() + : query.ToList(); + + Assert.Equal(2, timeSheets.Count); + } + + protected class Context27083 : DbContext + { + public Context27083(DbContextOptions options) + : base(options) + { + } + + public DbSet TimeSheets { get; set; } + public DbSet Customers { get; set; } + + public void Seed() + { + var customerA = new Customer { Name = "Customer A" }; + var customerB = new Customer { Name = "Customer B" }; + + var projectA = new Project { Customer = customerA }; + var projectB = new Project { Customer = customerB }; + + var orderA1 = new Order { Number = "A1", Customer = customerA, HourlyRate = 10 }; + var orderA2 = new Order { Number = "A2", Customer = customerA, HourlyRate = 11 }; + var orderB1 = new Order { Number = "B1", Customer = customerB, HourlyRate = 20 }; + + var timeSheetA = new TimeSheet { Order = orderA1, Project = projectA }; + var timeSheetB = new TimeSheet { Order = orderB1, Project = projectB }; + + AddRange(customerA, customerB); + AddRange(projectA, projectB); + AddRange(orderA1, orderA2, orderB1); + AddRange(timeSheetA, timeSheetB); + SaveChanges(); + } + } + + protected class Customer + { + public int Id { get; set; } + + public string Name { get; set; } + + public List Projects { get; set; } + public List Orders { get; set; } + } + + protected class Order + { + public int Id { get; set; } + public string Number { get; set; } + + public int CustomerId { get; set; } + public Customer Customer { get; set; } + + public int HourlyRate { get; set; } + } + + protected class Project + { + public int Id { get; set; } + public int CustomerId { get; set; } + + public Customer Customer { get; set; } + } + + protected class TimeSheet + { + public int Id { get; set; } + + public int ProjectId { get; set; } + public Project Project { get; set; } + + public int? OrderId { get; set; } + public Order Order { get; set; } + } } } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs index 6c327967dc4..e2d9572bf33 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs @@ -147,5 +147,21 @@ GROUP BY [o0].[OrderId] WHERE [o].[OrderId] = @__orderId_0 ORDER BY [o].[OrderId]"); } + + public override async Task GroupBy_Aggregate_over_navigations_repeated(bool async) + { + await base.GroupBy_Aggregate_over_navigations_repeated(async); + + AssertSql( + @"SELECT MIN([o].[HourlyRate]) AS [HourlyRate], MIN([c].[Id]) AS [CustomerId], MIN([c0].[Name]) AS [CustomerName] +FROM [TimeSheets] AS [t] +LEFT JOIN [Order] AS [o] ON [t].[OrderId] = [o].[Id] +INNER JOIN [Project] AS [p] ON [t].[ProjectId] = [p].[Id] +INNER JOIN [Customers] AS [c] ON [p].[CustomerId] = [c].[Id] +INNER JOIN [Project] AS [p0] ON [t].[ProjectId] = [p0].[Id] +INNER JOIN [Customers] AS [c0] ON [p0].[CustomerId] = [c0].[Id] +WHERE [t].[OrderId] IS NOT NULL +GROUP BY [t].[OrderId]"); + } } }