Skip to content

Commit

Permalink
Query: Relational: Throw exception when translating Count with predic…
Browse files Browse the repository at this point in the history
…ate on Grouping

Translate for InMemory
  • Loading branch information
smitpatel committed Oct 23, 2019
1 parent be0e5ec commit f434755
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,25 @@ private static Expression GetSelector(MethodCallExpression methodCallExpression,
throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print()));
}

private Expression GetPredicate(MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression)
{
if (methodCallExpression.Arguments.Count == 1)
{
return null;
}

if (methodCallExpression.Arguments.Count == 2)
{
var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote();
return ReplacingExpressionVisitor.Replace(
selectorLambda.Parameters[0],
groupByShaperExpression.ElementSelector,
selectorLambda.Body);
}

throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print()));
}

protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
if (methodCallExpression.Method.IsGenericMethod
Expand All @@ -276,29 +295,22 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
&& methodCallExpression.Arguments.Count > 0
&& methodCallExpression.Arguments[0] is InMemoryGroupByShaperExpression groupByShaperExpression)
{
switch (methodCallExpression.Method.Name)
var methodName = methodCallExpression.Method.Name;
switch (methodName)
{
case nameof(Enumerable.Average):
case nameof(Enumerable.Max):
case nameof(Enumerable.Min):
case nameof(Enumerable.Sum):
{
var translation = Translate(GetSelector(methodCallExpression, groupByShaperExpression));
if (translation == null)
{
return null;
}

var selector = Expression.Lambda(translation, groupByShaperExpression.ValueBufferParameter);
MethodInfo getMethod()
=> methodCallExpression.Method.Name switch
{
nameof(Enumerable.Average) => InMemoryLinqOperatorProvider.GetAverageWithSelector(selector.ReturnType),
nameof(Enumerable.Max) => InMemoryLinqOperatorProvider.GetMaxWithSelector(selector.ReturnType),
nameof(Enumerable.Min) => InMemoryLinqOperatorProvider.GetMinWithSelector(selector.ReturnType),
nameof(Enumerable.Sum) => InMemoryLinqOperatorProvider.GetSumWithSelector(selector.ReturnType),
_ => throw new InvalidOperationException("Invalid Aggregate Operator encountered."),
};
var method = getMethod();
var method = GetMethod();
method = method.GetGenericArguments().Length == 2
? method.MakeGenericMethod(typeof(ValueBuffer), selector.ReturnType)
: method.MakeGenericMethod(typeof(ValueBuffer));
Expand All @@ -307,14 +319,47 @@ MethodInfo getMethod()
groupByShaperExpression.GroupingParameter,
selector);

MethodInfo GetMethod()
=> methodName switch
{
nameof(Enumerable.Average) => InMemoryLinqOperatorProvider.GetAverageWithSelector(selector.ReturnType),
nameof(Enumerable.Max) => InMemoryLinqOperatorProvider.GetMaxWithSelector(selector.ReturnType),
nameof(Enumerable.Min) => InMemoryLinqOperatorProvider.GetMinWithSelector(selector.ReturnType),
nameof(Enumerable.Sum) => InMemoryLinqOperatorProvider.GetSumWithSelector(selector.ReturnType),
_ => throw new InvalidOperationException("Invalid Aggregate Operator encountered."),
};
}

case nameof(Enumerable.Count):
return Expression.Call(
InMemoryLinqOperatorProvider.CountWithoutPredicate.MakeGenericMethod(typeof(ValueBuffer)),
groupByShaperExpression.GroupingParameter);
case nameof(Enumerable.LongCount):
{
var predicate = GetPredicate(methodCallExpression, groupByShaperExpression);
if (predicate == null)
{
return Expression.Call(
(string.Equals(methodName, nameof(Enumerable.Count))
? InMemoryLinqOperatorProvider.CountWithoutPredicate
: InMemoryLinqOperatorProvider.LongCountWithoutPredicate)
.MakeGenericMethod(typeof(ValueBuffer)),
groupByShaperExpression.GroupingParameter);
}

var translation = Translate(predicate);
if (translation == null)
{
return null;
}

predicate = Expression.Lambda(translation, groupByShaperExpression.ValueBufferParameter);

return Expression.Call(
InMemoryLinqOperatorProvider.LongCountWithoutPredicate.MakeGenericMethod(typeof(ValueBuffer)),
groupByShaperExpression.GroupingParameter);
(string.Equals(methodName, nameof(Enumerable.Count))
? InMemoryLinqOperatorProvider.CountWithoutPredicate
: InMemoryLinqOperatorProvider.LongCountWithoutPredicate)
.MakeGenericMethod(typeof(ValueBuffer)),
groupByShaperExpression.GroupingParameter,
predicate);
}

default:
throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,24 @@ public virtual SqlExpression TranslateAverage(Expression expression)

public virtual SqlExpression TranslateCount(Expression expression = null)
{
// TODO: Translate Count with predicate for GroupBy
if (expression != null)
{
// TODO: Translate Count with predicate for GroupBy
return null;
}

return _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Function("COUNT", new[] { _sqlExpressionFactory.Fragment("*") }, typeof(int)));
}

public virtual SqlExpression TranslateLongCount(Expression expression = null)
{
// TODO: Translate Count with predicate for GroupBy
if (expression != null)
{
// TODO: Translate Count with predicate for GroupBy
return null;
}

return _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Function("COUNT", new[] { _sqlExpressionFactory.Fragment("*") }, typeof(long)));
}
Expand Down Expand Up @@ -288,6 +298,25 @@ private Expression GetSelector(MethodCallExpression methodCallExpression, GroupB
throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print()));
}

private Expression GetPredicate(MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression)
{
if (methodCallExpression.Arguments.Count == 1)
{
return null;
}

if (methodCallExpression.Arguments.Count == 2)
{
var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote();
return ReplacingExpressionVisitor.Replace(
selectorLambda.Parameters[0],
groupByShaperExpression.ElementSelector,
selectorLambda.Body);
}

throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print()));
}

protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
// EF.Property case
Expand All @@ -310,8 +339,8 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
var translatedAggregate = methodCallExpression.Method.Name switch
{
nameof(Enumerable.Average) => TranslateAverage(GetSelector(methodCallExpression, groupByShaperExpression)),
nameof(Enumerable.Count) => TranslateCount(),
nameof(Enumerable.LongCount) => TranslateLongCount(),
nameof(Enumerable.Count) => TranslateCount(GetPredicate(methodCallExpression, groupByShaperExpression)),
nameof(Enumerable.LongCount) => TranslateLongCount(GetPredicate(methodCallExpression, groupByShaperExpression)),
nameof(Enumerable.Max) => TranslateMax(GetSelector(methodCallExpression, groupByShaperExpression)),
nameof(Enumerable.Min) => TranslateMin(GetSelector(methodCallExpression, groupByShaperExpression)),
nameof(Enumerable.Sum) => TranslateSum(GetSelector(methodCallExpression, groupByShaperExpression)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,12 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)

public override SqlExpression TranslateLongCount(Expression expression = null)
{
// TODO: Translate Count with predicate for GroupBy
if (expression != null)
{
// TODO: Translate Count with predicate for GroupBy
return null;
}

return _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Function("COUNT_BIG", new[] { _sqlExpressionFactory.Fragment("*") }, typeof(long)));
}
Expand Down
18 changes: 18 additions & 0 deletions test/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1904,6 +1904,24 @@ public virtual Task GroupBy_with_order_by_skip_and_another_order_by(bool isAsync
);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_Property_Select_Count_with_predicate(bool isAsync)
{
return AssertQueryScalar(
isAsync,
ss => ss.Set<Order>().GroupBy(o => o.CustomerID).Select(g => g.Count(o => o.OrderID < 10300)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_Property_Select_LongCount_with_predicate(bool isAsync)
{
return AssertQueryScalar(
isAsync,
ss => ss.Set<Order>().GroupBy(o => o.CustomerID).Select(g => g.LongCount(o => o.OrderID < 10300)));
}

#endregion

#region GroupByWithoutAggregate
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.TestUtilities;
Expand Down Expand Up @@ -2115,6 +2116,18 @@ OFFSET @__p_0 ROWS
GROUP BY [t].[CustomerID]");
}

public override Task GroupBy_Property_Select_Count_with_predicate(bool isAsync)
{
return Assert.ThrowsAsync<InvalidOperationException>(
() => base.GroupBy_Property_Select_Count_with_predicate(isAsync));
}

public override Task GroupBy_Property_Select_LongCount_with_predicate(bool isAsync)
{
return Assert.ThrowsAsync<InvalidOperationException>(
() => base.GroupBy_Property_Select_LongCount_with_predicate(isAsync));
}

public override async Task GroupBy_with_grouping_key_using_Like(bool isAsync)
{
await base.GroupBy_with_grouping_key_using_Like(isAsync);
Expand Down
15 changes: 15 additions & 0 deletions test/EFCore.Sqlite.FunctionalTests/Query/GroupByQuerySqliteTest.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.TestUtilities;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.EntityFrameworkCore.Query
Expand All @@ -15,5 +18,17 @@ public GroupByQuerySqliteTest(NorthwindQuerySqliteFixture<NoopModelCustomizer> f
Fixture.TestSqlLoggerFactory.Clear();
//Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper);
}

public override Task GroupBy_Property_Select_Count_with_predicate(bool isAsync)
{
return Assert.ThrowsAsync<InvalidOperationException>(
() => base.GroupBy_Property_Select_Count_with_predicate(isAsync));
}

public override Task GroupBy_Property_Select_LongCount_with_predicate(bool isAsync)
{
return Assert.ThrowsAsync<InvalidOperationException>(
() => base.GroupBy_Property_Select_LongCount_with_predicate(isAsync));
}
}
}

0 comments on commit f434755

Please sign in to comment.