Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Make SqlExpression.Type to be reference type or non-nullable v… #21726

Merged
merged 1 commit into from
Jul 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/EFCore.Relational/Query/Internal/ComparisonTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method
SqlExpression right = null;
if (method.Name == nameof(string.Compare)
&& arguments.Count == 2
&& arguments[0].Type.UnwrapNullableType() == arguments[1].Type.UnwrapNullableType())
&& arguments[0].Type == arguments[1].Type)
{
left = arguments[0];
right = arguments[1];
}
else if (method.Name == nameof(string.CompareTo)
&& arguments.Count == 1
&& instance != null
&& instance.Type.UnwrapNullableType() == arguments[0].Type.UnwrapNullableType())
&& instance.Type == arguments[0].Type)
{
left = instance;
right = arguments[0];
Expand Down
11 changes: 3 additions & 8 deletions src/EFCore.Relational/Query/Internal/EnumHasFlagTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,9 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method
if (Equals(method, _methodInfo))
{
var argument = arguments[0];
if (instance.Type.UnwrapNullableType() != argument.Type.UnwrapNullableType())
{
return null;
}

return _sqlExpressionFactory.Equal(
_sqlExpressionFactory.And(instance, argument),
argument);
return instance.Type != argument.Type
? null
: (SqlExpression)_sqlExpressionFactory.Equal(_sqlExpressionFactory.And(instance, argument), argument);
}

return null;
Expand Down
2 changes: 1 addition & 1 deletion src/EFCore.Relational/Query/Internal/EqualsTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method
if (left != null
&& right != null)
{
if (left.Type.UnwrapNullableType() == right.Type.UnwrapNullableType())
if (left.Type == right.Type)
{
return _sqlExpressionFactory.Equal(left, right);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,7 @@ private SqlExpression SimplifyUnaryExpression(
switch (operatorType)
{
case ExpressionType.Not
when type == typeof(bool)
|| type == typeof(bool?):
when type == typeof(bool):
{
switch (operand)
{
Expand Down Expand Up @@ -250,7 +249,7 @@ private SqlExpression SimplifyNullNotNullExpression(
typeMapping);

case ColumnExpression columnOperand
when !columnOperand.IsNullable:
when !columnOperand.IsNullable:
return _sqlExpressionFactory.Constant(operatorType == ExpressionType.NotEqual, typeMapping);

case SqlUnaryExpression sqlUnaryOperand:
Expand Down Expand Up @@ -298,7 +297,7 @@ private SqlExpression SimplifyNullNotNullExpression(
break;

case SqlFunctionExpression sqlFunctionExpression
when sqlFunctionExpression.IsBuiltIn
when sqlFunctionExpression.IsBuiltIn
&& string.Equals("COALESCE", sqlFunctionExpression.Name, StringComparison.OrdinalIgnoreCase):
// for coalesce:
// (a ?? b) == null -> a == null && b == null
Expand Down
4 changes: 2 additions & 2 deletions src/EFCore.Relational/Query/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ protected override Expression VisitSqlParameter(SqlParameterExpression sqlParame
sqlParameterExpression.Name,
parameterNameInCommand,
sqlParameterExpression.TypeMapping,
sqlParameterExpression.Type.IsNullableType());
sqlParameterExpression.IsNullable);
}

_relationalCommandBuilder
Expand Down Expand Up @@ -706,7 +706,7 @@ protected override Expression VisitSqlUnary(SqlUnaryExpression sqlUnaryExpressio
}

case ExpressionType.Not
when sqlUnaryExpression.Type.UnwrapNullableType() == typeof(bool):
when sqlUnaryExpression.Type == typeof(bool):
{
_relationalCommandBuilder.Append("NOT (");
Visit(sqlUnaryExpression.Operand);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
Expand Down Expand Up @@ -83,7 +84,7 @@ public virtual SqlExpression Translate(
arguments,
nullable: true,
argumentsPropagateNullability: arguments.Select(a => false).ToList(),
method.ReturnType);
method.ReturnType.UnwrapNullableType());
}

return _sqlExpressionFactory.Function(
Expand All @@ -92,7 +93,7 @@ public virtual SqlExpression Translate(
arguments,
nullable: true,
argumentsPropagateNullability: arguments.Select(a => false).ToList(),
method.ReturnType);
method.ReturnType.UnwrapNullableType());
}

return _plugins.Concat(_translators)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ public virtual SqlExpression TranslateAverage([NotNull] Expression expression)
: CoreStrings.TranslationFailedWithDetails(expression.Print(), TranslationErrorDetails));
}

var inputType = sqlExpression.Type.UnwrapNullableType();
var inputType = sqlExpression.Type;
if (inputType == typeof(int)
|| inputType == typeof(long))
{
Expand Down Expand Up @@ -305,7 +305,7 @@ public virtual SqlExpression TranslateSum([NotNull] Expression expression)
: CoreStrings.TranslationFailedWithDetails(expression.Print(), TranslationErrorDetails));
}

var inputType = sqlExpression.Type.UnwrapNullableType();
var inputType = sqlExpression.Type;

return inputType == typeof(float)
? _sqlExpressionFactory.Convert(
Expand Down Expand Up @@ -1313,9 +1313,7 @@ protected override Expression VisitExtension(Expression extensionExpression)
Check.NotNull(extensionExpression, nameof(extensionExpression));

if (extensionExpression is SqlExpression sqlExpression
&& !(extensionExpression is SqlFragmentExpression)
&& !(extensionExpression is SqlFunctionExpression sqlFunctionExpression
&& sqlFunctionExpression.Type.IsQueryableType()))
&& !(extensionExpression is SqlFragmentExpression))
{
if (sqlExpression.TypeMapping == null)
{
Expand Down
4 changes: 2 additions & 2 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ private SqlExpression ApplyTypeMappingOnSqlUnary(
case ExpressionType.Equal:
case ExpressionType.NotEqual:
case ExpressionType.Not
when sqlUnaryExpression.Type.UnwrapNullableType() == typeof(bool):
when sqlUnaryExpression.Type == typeof(bool):
{
resultTypeMapping = _boolTypeMapping;
resultType = typeof(bool);
Expand Down Expand Up @@ -444,7 +444,7 @@ public virtual SqlUnaryExpression Convert(SqlExpression operand, Type type, Rela
Check.NotNull(operand, nameof(operand));
Check.NotNull(type, nameof(type));

return MakeUnary(ExpressionType.Convert, operand, type, typeMapping);
return MakeUnary(ExpressionType.Convert, operand, type.UnwrapNullableType(), typeMapping);
}

/// <inheritdoc />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ internal ColumnExpression(IProperty property, IColumnBase column, TableExpressio
: this(
column.Name,
table,
property.ClrType,
property.ClrType.UnwrapNullableType(),
column.PropertyMappings.First(m => m.Property == property).TypeMapping,
nullable || column.IsNullable)
{
Expand All @@ -51,7 +51,7 @@ private static bool IsNullableProjection(ProjectionExpression projectionExpressi
};

private ColumnExpression(string name, TableExpressionBase table, Type type, RelationalTypeMapping typeMapping, bool nullable)
: base(nullable ? type.MakeNullable() : type, typeMapping)
: base(type, typeMapping)
{
Check.NotEmpty(name, nameof(name));
Check.NotNull(table, nameof(table));
Expand All @@ -70,6 +70,7 @@ private ColumnExpression(string name, TableExpressionBase table, Type type, Rela
/// The table from which column is being referenced.
/// </summary>
public TableExpressionBase Table { get; }

/// <summary>
/// The bool value indicating if this column can have null values.
/// </summary>
Expand All @@ -88,7 +89,7 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
/// </summary>
/// <returns> A new expression which has <see cref="IsNullable"/> property set to true. </returns>
public ColumnExpression MakeNullable()
=> new ColumnExpression(Name, Table, Type.MakeNullable(), TypeMapping, true);
=> new ColumnExpression(Name, Table, Type, TypeMapping, true);

/// <inheritdoc />
protected override void Print(ExpressionPrinter expressionPrinter)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public class SqlConstantExpression : SqlExpression
/// <param name="constantExpression"> A <see cref="ConstantExpression"/>. </param>
/// <param name="typeMapping"> The <see cref="RelationalTypeMapping"/> associated with the expression. </param>
public SqlConstantExpression([NotNull] ConstantExpression constantExpression, [CanBeNull] RelationalTypeMapping typeMapping)
: base(Check.NotNull(constantExpression, nameof(constantExpression)).Type, typeMapping)
: base(Check.NotNull(constantExpression, nameof(constantExpression)).Type.UnwrapNullableType(), typeMapping)
{
_constantExpression = constantExpression;
}
Expand Down
5 changes: 5 additions & 0 deletions src/EFCore.Relational/Query/SqlExpressions/SqlExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.Utilities;

namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions
{
Expand All @@ -27,6 +28,10 @@ public abstract class SqlExpression : Expression, IPrintableExpression
/// <param name="typeMapping"> The <see cref="RelationalTypeMapping"/> associated with the expression. </param>
protected SqlExpression([NotNull] Type type, [CanBeNull] RelationalTypeMapping typeMapping)
{
Check.NotNull(type, nameof(type));

Check.DebugAssert(!type.IsNullableValueType(), "SqlExpression.Type must be reference type or non-nullable value type");

Type = type;
TypeMapping = typeMapping;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,22 @@ public sealed class SqlParameterExpression : SqlExpression
private readonly ParameterExpression _parameterExpression;

internal SqlParameterExpression(ParameterExpression parameterExpression, RelationalTypeMapping typeMapping)
: base(parameterExpression.Type, typeMapping)
: base(parameterExpression.Type.UnwrapNullableType(), typeMapping)
{
_parameterExpression = parameterExpression;
IsNullable = parameterExpression.Type.IsNullableType();
}

/// <summary>
/// The name of the parameter.
/// </summary>
public string Name => _parameterExpression.Name;

/// <summary>
/// The bool value indicating if this parameter can have null values.
/// </summary>
public bool IsNullable { get; }

/// <summary>
/// Applies supplied type mapping to this expression.
/// </summary>
Expand Down
2 changes: 1 addition & 1 deletion src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1644,7 +1644,7 @@ private SqlExpression ProcessNullNotNull(SqlUnaryExpression sqlUnaryExpression,
private static bool IsLogicalNot(SqlUnaryExpression sqlUnaryExpression)
=> sqlUnaryExpression != null
&& sqlUnaryExpression.OperatorType == ExpressionType.Not
&& sqlUnaryExpression.Type.UnwrapNullableType() == typeof(bool);
&& sqlUnaryExpression.Type == typeof(bool);

// ?a == ?b -> [(a == b) && (a != null && b != null)] || (a == null && b == null))
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ private SqlExpression BuildCompareToExpression(SqlExpression sqlExpression)
private SqlExpression SimplifyNegatedBinary(SqlExpression sqlExpression)
=> sqlExpression is SqlUnaryExpression sqlUnaryExpression
&& sqlUnaryExpression.OperatorType == ExpressionType.Not
&& sqlUnaryExpression.Type.UnwrapNullableType() == typeof(bool)
&& sqlUnaryExpression.Type == typeof(bool)
&& sqlUnaryExpression.Operand is SqlBinaryExpression sqlBinaryOperand
&& (sqlBinaryOperand.OperatorType == ExpressionType.Equal || sqlBinaryOperand.OperatorType == ExpressionType.NotEqual)
? _sqlExpressionFactory.MakeBinary(
Expand Down Expand Up @@ -351,7 +351,7 @@ protected override Expression VisitSqlUnary(SqlUnaryExpression sqlUnaryExpressio
switch (sqlUnaryExpression.OperatorType)
{
case ExpressionType.Not
when sqlUnaryExpression.Type.UnwrapNullableType() == typeof(bool):
when sqlUnaryExpression.Type == typeof(bool):
{
_isSearchCondition = true;
resultCondition = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,15 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method
argumentsPropagateNullability: new[] { true },
typeof(long));

return _sqlExpressionFactory.Convert(result, method.ReturnType);
return _sqlExpressionFactory.Convert(result, method.ReturnType.UnwrapNullableType());
}

return _sqlExpressionFactory.Function(
"DATALENGTH",
arguments.Skip(1),
nullable: true,
argumentsPropagateNullability: new[] { true },
method.ReturnType);
method.ReturnType.UnwrapNullableType());
}

return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method
return method.Name == nameof(ToString)
&& arguments.Count == 0
&& instance != null
&& _typeMapping.TryGetValue(instance.Type.UnwrapNullableType(), out var storeType)
&& _typeMapping.TryGetValue(instance.Type, out var storeType)
? _sqlExpressionFactory.Function(
"CONVERT",
new[] { _sqlExpressionFactory.Fragment(storeType), instance },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ protected override ShapedQueryExpression TranslateThenBy(ShapedQueryExpression s
}

private static Type GetProviderType(SqlExpression expression)
=> (expression.TypeMapping?.Converter?.ProviderClrType
?? expression.TypeMapping?.ClrType
?? expression.Type).UnwrapNullableType();
=> expression.TypeMapping?.Converter?.ProviderClrType
?? expression.TypeMapping?.ClrType
?? expression.Type;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,9 @@ public override SqlExpression TranslateSum(Expression expression)
private static Type GetProviderType(SqlExpression expression)
=> expression == null
? null
: (expression.TypeMapping?.Converter?.ProviderClrType
?? expression.TypeMapping?.ClrType
?? expression.Type).UnwrapNullableType();
: expression.TypeMapping?.Converter?.ProviderClrType
?? expression.TypeMapping?.ClrType
?? expression.Type;

private static bool AreOperandsDecimals(SqlBinaryExpression sqlExpression) => GetProviderType(sqlExpression.Left) == typeof(decimal)
&& GetProviderType(sqlExpression.Right) == typeof(decimal);
Expand Down
Loading