Skip to content

Commit

Permalink
Fix to #19731 - Discuss default null propagation strategy for functions
Browse files Browse the repository at this point in the history
- obsoleteing SqlFunction methods and ctors that don't specify nullability propagation info explicitly,
- correcting nullability for LongCount,
- changing 'NullResultAllowed' to 'IsNullable' to be consistent with ColumnExpression,
- other minor perf improvements & cleanup.

Fixes #19731
  • Loading branch information
maumar committed Feb 21, 2020
1 parent 1700908 commit 8f1063c
Show file tree
Hide file tree
Showing 52 changed files with 189 additions and 187 deletions.
12 changes: 6 additions & 6 deletions src/EFCore.Relational/Query/ISqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ SqlFunctionExpression Function(
SqlFunctionExpression Function(
[NotNull] string name,
[NotNull] IEnumerable<SqlExpression> arguments,
bool nullResultAllowed,
bool nullable,
[NotNull] IEnumerable<bool> argumentsPropagateNullability,
[NotNull] Type returnType,
[CanBeNull] RelationalTypeMapping typeMapping = null);
Expand All @@ -143,7 +143,7 @@ SqlFunctionExpression Function(
[CanBeNull] string schema,
[NotNull] string name,
[NotNull] IEnumerable<SqlExpression> arguments,
bool nullResultAllowed,
bool nullable,
[NotNull] IEnumerable<bool> argumentsPropagateNullability,
[NotNull] Type returnType,
[CanBeNull] RelationalTypeMapping typeMapping = null);
Expand All @@ -152,29 +152,29 @@ SqlFunctionExpression Function(
[CanBeNull] SqlExpression instance,
[NotNull] string name,
[NotNull] IEnumerable<SqlExpression> arguments,
bool nullResultAllowed,
bool nullable,
bool instancePropagatesNullability,
[NotNull] IEnumerable<bool> argumentsPropagateNullability,
[NotNull] Type returnType,
[CanBeNull] RelationalTypeMapping typeMapping = null);

SqlFunctionExpression Function(
[NotNull] string name,
bool nullResultAllowed,
bool nullable,
[NotNull] Type returnType,
[CanBeNull] RelationalTypeMapping typeMapping = null);

SqlFunctionExpression Function(
[NotNull] string schema,
[NotNull] string name,
bool nullResultAllowed,
bool nullable,
[NotNull] Type returnType,
[CanBeNull] RelationalTypeMapping typeMapping = null);

SqlFunctionExpression Function(
[CanBeNull] SqlExpression instance,
[NotNull] string name,
bool nullResultAllowed,
bool nullable,
bool instancePropagatesNullability,
[NotNull] Type returnType,
[CanBeNull] RelationalTypeMapping typeMapping = null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1287,7 +1287,7 @@ protected virtual SqlExpression ProcessNullNotNull(
sqlUnaryExpression.TypeMapping));
}

if (!sqlFunctionExpression.NullResultAllowed)
if (!sqlFunctionExpression.IsNullable)
{
// when we know that function can't be nullable:
// non_nullable_function() is null-> false
Expand All @@ -1306,39 +1306,33 @@ protected virtual SqlExpression ProcessNullNotNull(
nullabilityPropagationElements.Add(sqlFunctionExpression.Instance);
}

for (var i = 0; i < sqlFunctionExpression.Arguments.Count; i++)
if (!sqlFunctionExpression.IsNiladic)
{
if (sqlFunctionExpression.ArgumentsPropagateNullability[i])
for (var i = 0; i < sqlFunctionExpression.Arguments.Count; i++)
{
nullabilityPropagationElements.Add(sqlFunctionExpression.Arguments[i]);
if (sqlFunctionExpression.ArgumentsPropagateNullability[i])
{
nullabilityPropagationElements.Add(sqlFunctionExpression.Arguments[i]);
}
}
}

// function(a, b) IS NULL -> a IS NULL || b IS NULL
// function(a, b) IS NOT NULL -> a IS NOT NULL && b IS NOT NULL
if (nullabilityPropagationElements.Count > 0)
{
var result = ProcessNullNotNull(
SqlExpressionFactory.MakeUnary(
sqlUnaryExpression.OperatorType,
nullabilityPropagationElements[0],
sqlUnaryExpression.Type,
sqlUnaryExpression.TypeMapping),
operandNullable: null);

foreach (var element in nullabilityPropagationElements.Skip(1))
{
result = SimplifyLogicalSqlBinaryExpression(
var result = nullabilityPropagationElements
.Select(e => ProcessNullNotNull(
SqlExpressionFactory.MakeUnary(
sqlUnaryExpression.OperatorType,
e,
sqlUnaryExpression.Type,
sqlUnaryExpression.TypeMapping),
operandNullable: null))
.Aggregate((r, e) => SimplifyLogicalSqlBinaryExpression(
sqlUnaryExpression.OperatorType == ExpressionType.Equal
? SqlExpressionFactory.OrElse(
result,
ProcessNullNotNull(
SqlExpressionFactory.IsNull(element),
operandNullable: null))
: SqlExpressionFactory.AndAlso(
result,
ProcessNullNotNull(
SqlExpressionFactory.IsNotNull(element),
operandNullable: null)));
}
? SqlExpressionFactory.OrElse(r, e)
: SqlExpressionFactory.AndAlso(r, e)));

return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public virtual SqlExpression Translate(
dbFunction.Schema,
dbFunction.Name,
arguments,
nullResultAllowed: true,
nullable: true,
argumentsPropagateNullability: arguments.Select(a => false).ToList(),
method.ReturnType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,15 @@ public virtual SqlExpression TranslateAverage([NotNull] Expression expression)
SqlExpressionFactory.Function(
"AVG",
new[] { sqlExpression },
nullResultAllowed: true,
nullable: true,
argumentsPropagateNullability: new[] { false },
typeof(double)),
sqlExpression.Type,
sqlExpression.TypeMapping)
: (SqlExpression)SqlExpressionFactory.Function(
"AVG",
new[] { sqlExpression },
nullResultAllowed: true,
nullable: true,
argumentsPropagateNullability: new[] { false },
sqlExpression.Type,
sqlExpression.TypeMapping);
Expand All @@ -127,7 +127,7 @@ public virtual SqlExpression TranslateCount([CanBeNull] Expression expression =
SqlExpressionFactory.Function(
"COUNT",
new[] { SqlExpressionFactory.Fragment("*") },
nullResultAllowed: false,
nullable: false,
argumentsPropagateNullability: new[] { false },
typeof(int)));
}
Expand All @@ -144,7 +144,7 @@ public virtual SqlExpression TranslateLongCount([CanBeNull] Expression expressio
SqlExpressionFactory.Function(
"COUNT",
new[] { SqlExpressionFactory.Fragment("*") },
nullResultAllowed: false,
nullable: false,
argumentsPropagateNullability: new[] { false },
typeof(long)));
}
Expand All @@ -162,7 +162,7 @@ public virtual SqlExpression TranslateMax([NotNull] Expression expression)
? SqlExpressionFactory.Function(
"MAX",
new[] { sqlExpression },
nullResultAllowed: true,
nullable: true,
argumentsPropagateNullability: new[] { false },
sqlExpression.Type,
sqlExpression.TypeMapping)
Expand All @@ -182,7 +182,7 @@ public virtual SqlExpression TranslateMin([NotNull] Expression expression)
? SqlExpressionFactory.Function(
"MIN",
new[] { sqlExpression },
nullResultAllowed: true,
nullable: true,
argumentsPropagateNullability: new[] { false },
sqlExpression.Type,
sqlExpression.TypeMapping)
Expand Down Expand Up @@ -210,15 +210,15 @@ public virtual SqlExpression TranslateSum([NotNull] Expression expression)
SqlExpressionFactory.Function(
"SUM",
new[] { sqlExpression },
nullResultAllowed: true,
nullable: true,
argumentsPropagateNullability: new[] { false },
typeof(double)),
inputType,
sqlExpression.TypeMapping)
: (SqlExpression)SqlExpressionFactory.Function(
"SUM",
new[] { sqlExpression },
nullResultAllowed: true,
nullable: true,
argumentsPropagateNullability: new[] { false },
inputType,
sqlExpression.TypeMapping);
Expand Down
42 changes: 23 additions & 19 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ public virtual SqlFunctionExpression Coalesce(SqlExpression left, SqlExpression
return SqlFunctionExpression.Create(
"COALESCE",
typeMappedArguments,
nullResultAllowed: true,
nullable: true,
// COALESCE is handled separately since it's only nullable if *both* arguments are null
argumentsPropagateNullability: new[] { false, false },
resultType,
Expand Down Expand Up @@ -487,21 +487,24 @@ public virtual CaseExpression Case(IReadOnlyList<CaseWhenClause> whenClauses, Sq
return new CaseExpression(typeMappedWhenClauses, elseResult);
}

[Obsolete("Use overload that explicitly specifies value for 'argumentsPropagateNullability' argument.")]
public virtual SqlFunctionExpression Function(
string name,
IEnumerable<SqlExpression> arguments,
Type returnType,
RelationalTypeMapping typeMapping = null)
=> Function(name, arguments, nullResultAllowed: true, argumentsPropagateNullability: arguments.Select(a => false), returnType, typeMapping);
=> Function(name, arguments, nullable: true, argumentsPropagateNullability: arguments.Select(a => false), returnType, typeMapping);

[Obsolete("Use overload that explicitly specifies value for 'argumentsPropagateNullability' argument.")]
public virtual SqlFunctionExpression Function(
string schema,
string name,
IEnumerable<SqlExpression> arguments,
Type returnType,
RelationalTypeMapping typeMapping = null)
=> Function(schema, name, arguments, nullResultAllowed: true, argumentsPropagateNullability: arguments.Select(a => false), returnType, typeMapping);
=> Function(schema, name, arguments, nullable: true, argumentsPropagateNullability: arguments.Select(a => false), returnType, typeMapping);

[Obsolete("Use overload that explicitly specifies values for 'instancePropagatesNullability' and 'argumentsPropagateNullability' arguments.")]
public virtual SqlFunctionExpression Function(
SqlExpression instance,
string name,
Expand All @@ -512,25 +515,26 @@ public virtual SqlFunctionExpression Function(
instance,
name,
arguments,
nullResultAllowed: true,
nullable: true,
instancePropagatesNullability: false,
argumentsPropagateNullability: arguments.Select(a => false),
returnType,
typeMapping);

public virtual SqlFunctionExpression Function(string name, Type returnType, RelationalTypeMapping typeMapping = null)
=> Function(name, nullResultAllowed: true, returnType, typeMapping);
=> Function(name, nullable: true, returnType, typeMapping);

public virtual SqlFunctionExpression Function(string schema, string name, Type returnType, RelationalTypeMapping typeMapping = null)
=> Function(schema, name, nullResultAllowed: true, returnType, typeMapping);
=> Function(schema, name, nullable: true, returnType, typeMapping);

[Obsolete("Use overload that explicitly specifies value for 'instancePropagatesNullability' argument.")]
public virtual SqlFunctionExpression Function(SqlExpression instance, string name, Type returnType, RelationalTypeMapping typeMapping = null)
=> Function(instance, name, nullResultAllowed: true, instancePropagatesNullability: false, returnType, typeMapping);
=> Function(instance, name, nullable: true, instancePropagatesNullability: false, returnType, typeMapping);

public virtual SqlFunctionExpression Function(
string name,
IEnumerable<SqlExpression> arguments,
bool nullResultAllowed,
bool nullable,
IEnumerable<bool> argumentsPropagateNullability,
Type returnType,
RelationalTypeMapping typeMapping = null)
Expand All @@ -549,7 +553,7 @@ public virtual SqlFunctionExpression Function(
return SqlFunctionExpression.Create(
name,
typeMappedArguments,
nullResultAllowed,
nullable,
argumentsPropagateNullability,
returnType,
typeMapping);
Expand All @@ -559,7 +563,7 @@ public virtual SqlFunctionExpression Function(
string schema,
string name,
IEnumerable<SqlExpression> arguments,
bool nullResultAllowed,
bool nullable,
IEnumerable<bool> argumentsPropagateNullability,
Type returnType,
RelationalTypeMapping typeMapping = null)
Expand All @@ -578,7 +582,7 @@ public virtual SqlFunctionExpression Function(
schema,
name,
typeMappedArguments,
nullResultAllowed,
nullable,
argumentsPropagateNullability,
returnType,
typeMapping);
Expand All @@ -588,7 +592,7 @@ public virtual SqlFunctionExpression Function(
SqlExpression instance,
string name,
IEnumerable<SqlExpression> arguments,
bool nullResultAllowed,
bool nullable,
bool instancePropagatesNullability,
IEnumerable<bool> argumentsPropagateNullability,
Type returnType,
Expand All @@ -609,34 +613,34 @@ public virtual SqlFunctionExpression Function(
instance,
name,
typeMappedArguments,
nullResultAllowed,
nullable,
instancePropagatesNullability,
argumentsPropagateNullability,
returnType,
typeMapping);
}

public virtual SqlFunctionExpression Function(string name, bool nullResultAllowed, Type returnType, RelationalTypeMapping typeMapping = null)
public virtual SqlFunctionExpression Function(string name, bool nullable, Type returnType, RelationalTypeMapping typeMapping = null)
{
Check.NotEmpty(name, nameof(name));
Check.NotNull(returnType, nameof(returnType));

return SqlFunctionExpression.CreateNiladic(name, nullResultAllowed, returnType, typeMapping);
return SqlFunctionExpression.CreateNiladic(name, nullable, returnType, typeMapping);
}

public virtual SqlFunctionExpression Function(string schema, string name, bool nullResultAllowed, Type returnType, RelationalTypeMapping typeMapping = null)
public virtual SqlFunctionExpression Function(string schema, string name, bool nullable, Type returnType, RelationalTypeMapping typeMapping = null)
{
Check.NotEmpty(schema, nameof(schema));
Check.NotEmpty(name, nameof(name));
Check.NotNull(returnType, nameof(returnType));

return SqlFunctionExpression.CreateNiladic(schema, name, nullResultAllowed, returnType, typeMapping);
return SqlFunctionExpression.CreateNiladic(schema, name, nullable, returnType, typeMapping);
}

public virtual SqlFunctionExpression Function(
SqlExpression instance,
string name,
bool nullResultAllowed,
bool nullable,
bool instancePropagatesNullability,
Type returnType,
RelationalTypeMapping typeMapping = null)
Expand All @@ -647,7 +651,7 @@ public virtual SqlFunctionExpression Function(
return SqlFunctionExpression.CreateNiladic(
ApplyDefaultTypeMapping(instance),
name,
nullResultAllowed,
nullable,
instancePropagatesNullability,
returnType,
typeMapping);
Expand Down
Loading

0 comments on commit 8f1063c

Please sign in to comment.