Skip to content

Commit

Permalink
Translate SequenceEqual for Sqlite and SqlServer byte arrays (#19594)
Browse files Browse the repository at this point in the history
  • Loading branch information
svengeance authored and roji committed Jan 28, 2020
1 parent 14428d2 commit b4ae282
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Collections.Generic;
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;

namespace Microsoft.EntityFrameworkCore.Query.Internal
{
public class ByteArraySequenceEqualTranslator: IMethodCallTranslator
{
private readonly ISqlExpressionFactory _sqlExpressionFactory;

public ByteArraySequenceEqualTranslator([NotNull] ISqlExpressionFactory sqlExpressionFactory)
{
_sqlExpressionFactory = sqlExpressionFactory;
}

public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method, IReadOnlyList<SqlExpression> arguments)
{
if (method.IsGenericMethod
&& method.GetGenericMethodDefinition().Equals(EnumerableMethods.SequenceEqual)
&& arguments[0].Type == typeof(byte[]))
{
return _sqlExpressionFactory.Equal(arguments[0], arguments[1]);
}

return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ public RelationalMethodCallTranslatorProvider([NotNull] RelationalMethodCallTran
new LikeTranslator(sqlExpressionFactory),
new EnumHasFlagTranslator(sqlExpressionFactory),
new GetValueOrDefaultTranslator(sqlExpressionFactory),
new ComparisonTranslator(sqlExpressionFactory)
new ComparisonTranslator(sqlExpressionFactory),
new ByteArraySequenceEqualTranslator(sqlExpressionFactory)
});
_sqlExpressionFactory = sqlExpressionFactory;
}
Expand Down
3 changes: 3 additions & 0 deletions src/Shared/EnumerableMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ internal static class EnumerableMethods
public static MethodInfo AnyWithoutPredicate { get; }
public static MethodInfo AnyWithPredicate { get; }
public static MethodInfo Contains { get; }
public static MethodInfo SequenceEqual { get; }

public static MethodInfo ToList { get; }
public static MethodInfo ToArray { get; }
Expand Down Expand Up @@ -151,6 +152,8 @@ static EnumerableMethods()
&& IsFunc(mi.GetParameters()[1].ParameterType));
Contains = enumerableMethods.Single(
mi => mi.Name == nameof(Enumerable.Contains) && mi.GetParameters().Length == 2);
SequenceEqual = enumerableMethods.Single(
mi => mi.Name == nameof(Enumerable.SequenceEqual) && mi.GetParameters().Length == 2);

ToList = enumerableMethods.Single(
mi => mi.Name == nameof(Enumerable.ToList) && mi.GetParameters().Length == 1);
Expand Down
11 changes: 11 additions & 0 deletions test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7548,6 +7548,17 @@ public virtual Task Projecting_required_string_column_compared_to_null_parameter
ss => ss.Set<Gear>().Select(g => g.Nickname == nullParameter));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_filter_by_SequenceEqual(bool async)
{
var byteArrayParam = new byte[] { 0x04, 0x05, 0x06, 0x07, 0x08 };

return AssertQuery(
async,
ss => ss.Set<Squad>().Where(s => s.Banner5.SequenceEqual(byteArrayParam)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Group_by_nullable_property_HasValue_and_project_the_grouping_key(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7527,6 +7527,17 @@ FROM [Gears] AS [g]
WHERE [g].[Discriminator] IN (N'Gear', N'Officer')");
}

public override async Task Byte_array_filter_by_SequenceEqual(bool isAsync)
{
await base.Byte_array_filter_by_SequenceEqual(isAsync);

AssertSql(@"@__byteArrayParam_0='0x0405060708' (Size = 5)
SELECT [s].[Id], [s].[Banner], [s].[Banner5], [s].[InternalNumber], [s].[Name]
FROM [Squads] AS [s]
WHERE [s].[Banner5] = @__byteArrayParam_0");
}

public override async Task Group_by_nullable_property_HasValue_and_project_the_grouping_key(bool async)
{
await base.Group_by_nullable_property_HasValue_and_project_the_grouping_key(async);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,17 @@ SELECT COUNT(*)
WHERE length(""s"".""Banner"") = length(@__byteArrayParam)");
}

public override async Task Byte_array_filter_by_SequenceEqual(bool async)
{
await base.Byte_array_filter_by_SequenceEqual(async);

AssertSql(@"@__byteArrayParam_0='0x0405060708' (Size = 5) (DbType = String)
SELECT ""s"".""Id"", ""s"".""Banner"", ""s"".""Banner5"", ""s"".""InternalNumber"", ""s"".""Name""
FROM ""Squads"" AS ""s""
WHERE ""s"".""Banner5"" = @__byteArrayParam_0");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
}
Expand Down

0 comments on commit b4ae282

Please sign in to comment.