Skip to content

Commit

Permalink
Translate XOR, shift left, shift right
Browse files Browse the repository at this point in the history
Closes #18795
  • Loading branch information
roji committed Nov 17, 2019
1 parent 293dab3 commit a887fce
Show file tree
Hide file tree
Showing 12 changed files with 261 additions and 105 deletions.
58 changes: 31 additions & 27 deletions src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,6 @@ private SqlExpression ApplyTypeMappingOnSqlBinary(
var left = sqlBinaryExpression.Left;
var right = sqlBinaryExpression.Right;

Type resultType;
CoreTypeMapping resultTypeMapping;
CoreTypeMapping inferredTypeMapping;
switch (sqlBinaryExpression.OperatorType)
{
case ExpressionType.Equal:
Expand All @@ -156,49 +153,56 @@ private SqlExpression ApplyTypeMappingOnSqlBinary(
case ExpressionType.LessThanOrEqual:
case ExpressionType.NotEqual:
{
inferredTypeMapping = ExpressionExtensions.InferTypeMapping(left, right)
?? _typeMappingSource.FindMapping(left.Type);
resultType = typeof(bool);
resultTypeMapping = _boolTypeMapping;
var inferredTypeMapping = ExpressionExtensions.InferTypeMapping(left, right)
?? _typeMappingSource.FindMapping(left.Type);
return new SqlBinaryExpression(
sqlBinaryExpression.OperatorType,
ApplyTypeMapping(left, inferredTypeMapping),
ApplyTypeMapping(right, inferredTypeMapping),
typeof(bool),
_boolTypeMapping);
}
break;

case ExpressionType.AndAlso:
case ExpressionType.OrElse:
{
inferredTypeMapping = _boolTypeMapping;
resultType = typeof(bool);
resultTypeMapping = _boolTypeMapping;
}
break;
return new SqlBinaryExpression(
sqlBinaryExpression.OperatorType,
ApplyTypeMapping(left, _boolTypeMapping),
ApplyTypeMapping(right, _boolTypeMapping),
typeof(bool),
_boolTypeMapping);

case ExpressionType.Add:
case ExpressionType.Subtract:
case ExpressionType.Multiply:
case ExpressionType.Divide:
case ExpressionType.Modulo:
case ExpressionType.LeftShift:
case ExpressionType.RightShift:
case ExpressionType.Coalesce:
case ExpressionType.And:
case ExpressionType.Or:
case ExpressionType.ExclusiveOr:
{
inferredTypeMapping = typeMapping ?? ExpressionExtensions.InferTypeMapping(left, right);
resultType = left.Type;
resultTypeMapping = inferredTypeMapping;
var inferredTypeMapping = typeMapping ?? ExpressionExtensions.InferTypeMapping(left, right);
return new SqlBinaryExpression(
sqlBinaryExpression.OperatorType,
ApplyTypeMapping(left, inferredTypeMapping),
ApplyTypeMapping(right, inferredTypeMapping),
left.Type,
inferredTypeMapping);
}
break;

case ExpressionType.LeftShift:
case ExpressionType.RightShift:
return new SqlBinaryExpression(
sqlBinaryExpression.OperatorType,
ApplyTypeMapping(left, typeMapping),
ApplyDefaultTypeMapping(right),
left.Type,
typeMapping);

default:
throw new InvalidOperationException("Incorrect operatorType for SqlBinaryExpression");
}

return new SqlBinaryExpression(
sqlBinaryExpression.OperatorType,
ApplyTypeMapping(left, inferredTypeMapping),
ApplyTypeMapping(right, inferredTypeMapping),
resultType,
resultTypeMapping);
}

/// <summary>
Expand Down
5 changes: 4 additions & 1 deletion src/EFCore.Relational/Query/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ private static readonly Regex _composableSql
{ ExpressionType.Divide, " / " },
{ ExpressionType.Modulo, " % " },
{ ExpressionType.And, " & " },
{ ExpressionType.Or, " | " }
{ ExpressionType.Or, " | " },
{ ExpressionType.ExclusiveOr, " ^ " },
{ ExpressionType.LeftShift, " << " },
{ ExpressionType.RightShift, " >> " }
};

public QuerySqlGenerator(QuerySqlGeneratorDependencies dependencies)
Expand Down
54 changes: 30 additions & 24 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,6 @@ private SqlExpression ApplyTypeMappingOnSqlBinary(
var left = sqlBinaryExpression.Left;
var right = sqlBinaryExpression.Right;

Type resultType;
RelationalTypeMapping resultTypeMapping;
RelationalTypeMapping inferredTypeMapping;
switch (sqlBinaryExpression.OperatorType)
{
case ExpressionType.Equal:
Expand All @@ -155,21 +152,24 @@ private SqlExpression ApplyTypeMappingOnSqlBinary(
case ExpressionType.LessThanOrEqual:
case ExpressionType.NotEqual:
{
inferredTypeMapping = ExpressionExtensions.InferTypeMapping(left, right)
var inferredTypeMapping = ExpressionExtensions.InferTypeMapping(left, right)
?? _typeMappingSource.FindMapping(left.Type);
resultType = typeof(bool);
resultTypeMapping = _boolTypeMapping;
break;
return new SqlBinaryExpression(
sqlBinaryExpression.OperatorType,
ApplyTypeMapping(left, inferredTypeMapping),
ApplyTypeMapping(right, inferredTypeMapping),
typeof(bool),
_boolTypeMapping);
}

case ExpressionType.AndAlso:
case ExpressionType.OrElse:
{
inferredTypeMapping = _boolTypeMapping;
resultType = typeof(bool);
resultTypeMapping = _boolTypeMapping;
break;
}
return new SqlBinaryExpression(
sqlBinaryExpression.OperatorType,
ApplyTypeMapping(left, _boolTypeMapping),
ApplyTypeMapping(right, _boolTypeMapping),
typeof(bool),
_boolTypeMapping);

case ExpressionType.Add:
case ExpressionType.Subtract:
Expand All @@ -179,23 +179,29 @@ private SqlExpression ApplyTypeMappingOnSqlBinary(
case ExpressionType.Coalesce:
case ExpressionType.And:
case ExpressionType.Or:
case ExpressionType.ExclusiveOr:
{
inferredTypeMapping = typeMapping ?? ExpressionExtensions.InferTypeMapping(left, right);
resultType = left.Type;
resultTypeMapping = inferredTypeMapping;
break;
var inferredTypeMapping = typeMapping ?? ExpressionExtensions.InferTypeMapping(left, right);
return new SqlBinaryExpression(
sqlBinaryExpression.OperatorType,
ApplyTypeMapping(left, inferredTypeMapping),
ApplyTypeMapping(right, inferredTypeMapping),
left.Type,
inferredTypeMapping);
}

case ExpressionType.LeftShift:
case ExpressionType.RightShift:
return new SqlBinaryExpression(
sqlBinaryExpression.OperatorType,
ApplyTypeMapping(left, typeMapping),
ApplyDefaultTypeMapping(right),
left.Type,
typeMapping);

default:
throw new InvalidOperationException("Incorrect operatorType for SqlBinaryExpression");
}

return new SqlBinaryExpression(
sqlBinaryExpression.OperatorType,
ApplyTypeMapping(left, inferredTypeMapping),
ApplyTypeMapping(right, inferredTypeMapping),
resultType,
resultTypeMapping);
}

public virtual RelationalTypeMapping GetTypeMappingForValue(object value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ public class SqlBinaryExpression : SqlExpression
ExpressionType.GreaterThanOrEqual,
ExpressionType.Equal,
ExpressionType.NotEqual,
//ExpressionType.ExclusiveOr,
ExpressionType.Coalesce
ExpressionType.ExclusiveOr,
ExpressionType.Coalesce,
//ExpressionType.ArrayIndex,
//ExpressionType.RightShift,
//ExpressionType.LeftShift,
ExpressionType.RightShift,
ExpressionType.LeftShift
};

private static ExpressionType VerifyOperator(ExpressionType operatorType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,6 @@ private static readonly HashSet<string> _dateTimeDataTypes
"datetimeoffset"
};

private static readonly HashSet<ExpressionType> _arithmeticOperatorTypes
= new HashSet<ExpressionType>
{
ExpressionType.Add,
ExpressionType.Subtract,
ExpressionType.Multiply,
ExpressionType.Divide,
ExpressionType.Modulo
};

// TODO: Possibly make this protected in base
private readonly ISqlExpressionFactory _sqlExpressionFactory;

Expand All @@ -47,17 +37,29 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
{
var visitedExpression = (SqlExpression)base.VisitBinary(binaryExpression);

if (visitedExpression == null)
if (visitedExpression is SqlBinaryExpression sqlBinary)
{
return null;
switch (sqlBinary.OperatorType)
{
case ExpressionType.LeftShift:
case ExpressionType.RightShift:
return null;

case ExpressionType.Add:
case ExpressionType.Subtract:
case ExpressionType.Multiply:
case ExpressionType.Divide:
case ExpressionType.Modulo:
if (_dateTimeDataTypes.Contains(GetProviderType(sqlBinary.Left))
|| _dateTimeDataTypes.Contains(GetProviderType(sqlBinary.Right)))
{
return null;
}
break;
}
}

return visitedExpression is SqlBinaryExpression sqlBinary
&& _arithmeticOperatorTypes.Contains(sqlBinary.OperatorType)
&& (_dateTimeDataTypes.Contains(GetProviderType(sqlBinary.Left))
|| _dateTimeDataTypes.Contains(GetProviderType(sqlBinary.Right)))
? null
: visitedExpression;
return visitedExpression;
}

public override SqlExpression TranslateLongCount(Expression expression = null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,17 +107,22 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
{
var visitedExpression = (SqlExpression)base.VisitBinary(binaryExpression);

if (visitedExpression == null)
if (visitedExpression is SqlBinaryExpression sqlBinary)
{
return null;
if (sqlBinary.OperatorType == ExpressionType.ExclusiveOr)
{
return null;
}

if (_restrictedBinaryExpressions.TryGetValue(sqlBinary.OperatorType, out var restrictedTypes)
&& (restrictedTypes.Contains(GetProviderType(sqlBinary.Left))
|| restrictedTypes.Contains(GetProviderType(sqlBinary.Right))))
{
return null;
}
}

return visitedExpression is SqlBinaryExpression sqlBinary
&& _restrictedBinaryExpressions.TryGetValue(sqlBinary.OperatorType, out var restrictedTypes)
&& (restrictedTypes.Contains(GetProviderType(sqlBinary.Left))
|| restrictedTypes.Contains(GetProviderType(sqlBinary.Right)))
? null
: visitedExpression;
return visitedExpression;
}

public override SqlExpression TranslateAverage(Expression expression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2368,6 +2368,36 @@ FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND ((c[""OrderID""] | 10248) = 10248))");
}

public override async Task Where_bitwise_binary_xor(bool async)
{
await base.Where_bitwise_binary_xor(async);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND ((c[""OrderID""] ^ 1) = 10249))");
}

public override async Task Where_shift_left(bool async)
{
await base.Where_shift_left(async);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND ((c[""OrderID""] << 1) = 20496))");
}

public override async Task Where_shift_right(bool async)
{
await base.Where_shift_right(async);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND ((c[""OrderID""] >> 1) = 5124))");
}

[ConditionalFact(Skip = "Issue #17246")]
public override void Select_bitwise_or_with_logical_or()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3651,6 +3651,36 @@ public virtual Task Where_bitwise_binary_or(bool async)
entryCount: 1);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_bitwise_binary_xor(bool async)
{
return AssertQuery(
async,
ss => ss.Set<Order>().Where(o => (o.OrderID ^ 1) == 10249),
entryCount: 1);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_shift_left(bool async)
{
return AssertQuery(
async,
ss => ss.Set<Order>().Where(o => (o.OrderID << 1) == 20496),
entryCount: 1);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_shift_right(bool async)
{
return AssertQuery(
async,
ss => ss.Set<Order>().Where(o => (o.OrderID >> 1) == 5124),
entryCount: 2);
}

[ConditionalFact]
public virtual void Select_bitwise_or_with_logical_or()
{
Expand Down
Loading

0 comments on commit a887fce

Please sign in to comment.