Skip to content

Commit

Permalink
Implement sum and average aggregation for decimal in SQLite (#33721)
Browse files Browse the repository at this point in the history
  • Loading branch information
ranma42 authored May 29, 2024
1 parent 7acafb2 commit d8fc38e
Show file tree
Hide file tree
Showing 8 changed files with 272 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,30 @@ protected virtual SqlExpression VisitRegexp(
return regexpExpression.Update(match, pattern);
}

/// <inheritdoc/>
protected override SqlExpression VisitSqlFunction(
SqlFunctionExpression sqlFunctionExpression,
bool allowOptimizedExpansion,
out bool nullable)
{
var result = base.VisitSqlFunction(sqlFunctionExpression, allowOptimizedExpansion, out nullable);

if (result is SqlFunctionExpression resultFunctionExpression
&& resultFunctionExpression.IsBuiltIn
&& string.Equals(resultFunctionExpression.Name, "ef_sum", StringComparison.OrdinalIgnoreCase))
{
nullable = false;

var sqlExpressionFactory = Dependencies.SqlExpressionFactory;
return sqlExpressionFactory.Coalesce(
result,
sqlExpressionFactory.Constant(0, resultFunctionExpression.TypeMapping),
resultFunctionExpression.TypeMapping);
}

return result;
}

#pragma warning disable EF1001
/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,14 @@ public SqliteQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpress
var averageArgumentType = GetProviderType(averageSqlExpression);
if (averageArgumentType == typeof(decimal))
{
throw new NotSupportedException(
SqliteStrings.AggregateOperationNotSupported(
nameof(Queryable.Average), averageArgumentType.ShortDisplayName()));
averageSqlExpression = CombineTerms(source, averageSqlExpression);
return _sqlExpressionFactory.Function(
"ef_avg",
[averageSqlExpression],
nullable: true,
argumentsPropagateNullability: [false],
averageSqlExpression.Type,
averageSqlExpression.TypeMapping);
}

break;
Expand Down Expand Up @@ -100,8 +105,14 @@ public SqliteQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpress
var sumArgumentType = GetProviderType(sumSqlExpression);
if (sumArgumentType == typeof(decimal))
{
throw new NotSupportedException(
SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Sum), sumArgumentType.ShortDisplayName()));
sumSqlExpression = CombineTerms(source, sumSqlExpression);
return _sqlExpressionFactory.Function(
"ef_sum",
[sumSqlExpression],
nullable: true,
argumentsPropagateNullability: [false],
sumSqlExpression.Type,
sumSqlExpression.TypeMapping);
}

break;
Expand All @@ -115,4 +126,21 @@ public SqliteQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpress
=> expression.TypeMapping?.Converter?.ProviderClrType
?? expression.TypeMapping?.ClrType
?? expression.Type;

private SqlExpression CombineTerms(EnumerableExpression enumerableExpression, SqlExpression sqlExpression)
{
if (enumerableExpression.Predicate != null)
{
sqlExpression = _sqlExpressionFactory.Case(
new List<CaseWhenClause> { new(enumerableExpression.Predicate, sqlExpression) },
elseResult: null);
}

if (enumerableExpression.IsDistinct)
{
sqlExpression = new DistinctExpression(sqlExpression);
}

return sqlExpression;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,25 @@ private void InitializeDbConnection(DbConnection connection)
name: "ef_negate",
(decimal? m) => -m,
isDeterministic: true);

sqliteConnection.CreateAggregate(
"ef_avg",
seed: (0m, 0ul),
((decimal sum, ulong count) acc, decimal? value) => value is null
? acc
: (acc.sum + value.Value, acc.count + 1),
((decimal sum, ulong count) acc) => acc.count == 0
? default(decimal?)
: acc.sum / acc.count,
isDeterministic: true);

sqliteConnection.CreateAggregate(
"ef_sum",
seed: null,
(decimal? sum, decimal? value) => value is null
? sum
: sum is null ? value : sum.Value + value.Value,
isDeterministic: true);
}
else
{
Expand Down
22 changes: 10 additions & 12 deletions test/EFCore.Sqlite.FunctionalTests/BuiltInDataTypesSqliteTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,7 @@ public virtual void Cant_query_Max_of_converted_types()
}

[ConditionalFact]
public virtual void Cant_query_Average_of_converted_types()
public virtual void Can_query_Average_of_converted_types()
{
using var context = CreateContext();
context.Add(
Expand All @@ -958,15 +958,14 @@ public virtual void Cant_query_Average_of_converted_types()
context.SaveChanges();

Assert.Equal(
SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Average), typeof(decimal).ShortDisplayName()),
Assert.Throws<NotSupportedException>(
() => context.Set<BuiltInNullableDataTypes>()
.Where(e => e.PartitionId == 202)
.Average(e => e.TestNullableDecimal)).Message);
1.000000000000002m,
context.Set<BuiltInNullableDataTypes>()
.Where(e => e.PartitionId == 202)
.Average(e => e.TestNullableDecimal));
}

[ConditionalFact]
public virtual void Cant_query_Sum_of_converted_types()
public virtual void Can_query_Sum_of_converted_types()
{
using var context = CreateContext();
context.Add(
Expand All @@ -988,11 +987,10 @@ public virtual void Cant_query_Sum_of_converted_types()
context.SaveChanges();

Assert.Equal(
SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Sum), typeof(decimal).ShortDisplayName()),
Assert.Throws<NotSupportedException>(
() => context.Set<BuiltInDataTypes>()
.Where(e => e.PartitionId == 203)
.Sum(e => e.TestDecimal)).Message);
2.000000000000002m,
context.Set<BuiltInDataTypes>()
.Where(e => e.PartitionId == 203)
.Sum(e => e.TestDecimal));
}

[ConditionalFact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,68 @@ INSERT INTO ZeroKey VALUES (NULL)
""");

public override async Task Average_with_cast()
=> Assert.Equal(
SqliteStrings.AggregateOperationNotSupported("Average", "decimal"),
(await Assert.ThrowsAsync<NotSupportedException>(base.Average_with_cast)).Message);
{
await base.Average_with_cast();

AssertSql(
"""
SELECT "p"."Id", "p"."DecimalColumn", "p"."DoubleColumn", "p"."FloatColumn", "p"."IntColumn", "p"."LongColumn", "p"."NullableDecimalColumn", "p"."NullableDoubleColumn", "p"."NullableFloatColumn", "p"."NullableIntColumn", "p"."NullableLongColumn", "p"."Price"
FROM "Prices" AS "p"
""",
//
"""
SELECT ef_avg("p"."Price")
FROM "Prices" AS "p"
""",
//
"""
SELECT AVG(CAST("p"."IntColumn" AS REAL))
FROM "Prices" AS "p"
""",
//
"""
SELECT AVG(CAST("p"."NullableIntColumn" AS REAL))
FROM "Prices" AS "p"
""",
//
"""
SELECT AVG(CAST("p"."LongColumn" AS REAL))
FROM "Prices" AS "p"
""",
//
"""
SELECT AVG(CAST("p"."NullableLongColumn" AS REAL))
FROM "Prices" AS "p"
""",
//
"""
SELECT CAST(AVG("p"."FloatColumn") AS REAL)
FROM "Prices" AS "p"
""",
//
"""
SELECT CAST(AVG("p"."NullableFloatColumn") AS REAL)
FROM "Prices" AS "p"
""",
//
"""
SELECT AVG("p"."DoubleColumn")
FROM "Prices" AS "p"
""",
//
"""
SELECT AVG("p"."NullableDoubleColumn")
FROM "Prices" AS "p"
""",
//
"""
SELECT ef_avg("p"."DecimalColumn")
FROM "Prices" AS "p"
""",
//
"""
SELECT ef_avg("p"."NullableDecimalColumn")
FROM "Prices" AS "p"
""");
}
}
17 changes: 13 additions & 4 deletions test/EFCore.Sqlite.FunctionalTests/Query/Ef6GroupBySqliteTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@ public Ef6GroupBySqliteTest(Ef6GroupBySqliteFixture fixture, ITestOutputHelper t
}

public override async Task Average_Grouped_from_LINQ_101(bool async)
=> Assert.Equal(
SqliteStrings.AggregateOperationNotSupported("Average", "decimal"),
(await Assert.ThrowsAsync<NotSupportedException>(
() => base.Average_Grouped_from_LINQ_101(async))).Message);
{
await base.Average_Grouped_from_LINQ_101(async);

AssertSql(
"""
SELECT "p"."Category", ef_avg("p"."UnitPrice") AS "AveragePrice"
FROM "ProductForLinq" AS "p"
GROUP BY "p"."Category"
""");
}

public override async Task Max_Grouped_from_LINQ_101(bool async)
=> Assert.Equal(
Expand Down Expand Up @@ -49,6 +55,9 @@ public override async Task Group_Join_from_LINQ_101(bool async)
(await Assert.ThrowsAsync<InvalidOperationException>(
() => base.Group_Join_from_LINQ_101(async))).Message);

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

public class Ef6GroupBySqliteFixture : Ef6GroupByFixtureBase, ITestSqlLoggerFactory
{
public TestSqlLoggerFactory TestSqlLoggerFactory
Expand Down
Loading

0 comments on commit d8fc38e

Please sign in to comment.