Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Push task-ifying extract-method return type to the code gen stage. #76670

Merged
merged 8 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,31 +91,16 @@ protected override ImmutableArray<StatementSyntax> GetInitialStatementsForMethod
Contract.ThrowIfFalse(this.SelectionResult.IsExtractMethodOnExpression);

// special case for array initializer
var returnType = AnalyzerResult.ReturnType;
var containingScope = this.SelectionResult.GetContainingScope();

ExpressionSyntax expression;
if (returnType.TypeKind == TypeKind.Array && containingScope is InitializerExpressionSyntax)
{
var typeSyntax = returnType.GenerateTypeSyntax();
var returnType = AnalyzerResult.CoreReturnType;
var containingScope = (ExpressionSyntax)this.SelectionResult.GetContainingScope();

expression = SyntaxFactory.ArrayCreationExpression(typeSyntax as ArrayTypeSyntax, containingScope as InitializerExpressionSyntax);
}
else
{
expression = containingScope as ExpressionSyntax;
}
var expression = returnType.TypeKind == TypeKind.Array && containingScope is InitializerExpressionSyntax initializerExpression
? SyntaxFactory.ArrayCreationExpression((ArrayTypeSyntax)returnType.GenerateTypeSyntax(), initializerExpression)
: containingScope;

if (AnalyzerResult.HasReturnType)
{
return [SyntaxFactory.ReturnStatement(
WrapInCheckedExpressionIfNeeded(expression))];
}
else
{
return [SyntaxFactory.ExpressionStatement(
WrapInCheckedExpressionIfNeeded(expression))];
}
return AnalyzerResult.CoreReturnType.SpecialType != SpecialType.System_Void
? [SyntaxFactory.ReturnStatement(WrapInCheckedExpressionIfNeeded(expression))]
: [SyntaxFactory.ExpressionStatement(WrapInCheckedExpressionIfNeeded(expression))];
}

private ExpressionSyntax WrapInCheckedExpressionIfNeeded(ExpressionSyntax expression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ protected override IMethodSymbol GenerateMethodDefinition(
attributes: [],
accessibility: Accessibility.Private,
modifiers: CreateMethodModifiers(),
returnType: AnalyzerResult.ReturnType,
returnType: this.GetFinalReturnType(),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only when we're actually generating the final method declaration do we get teh type, potentially wrapped with 'Task'. Everything else sees the 'core' type and can ask questions of that.

refKind: AnalyzerResult.ReturnsByRef ? RefKind.Ref : RefKind.None,
explicitInterfaceImplementations: default,
name: _methodName.ToString(),
Expand Down Expand Up @@ -209,7 +209,7 @@ private bool ShouldPutUnsafeModifier()
private DeclarationModifiers CreateMethodModifiers()
{
var isUnsafe = ShouldPutUnsafeModifier();
var isAsync = this.SelectionResult.CreateAsyncMethod();
var isAsync = this.SelectionResult.ContainsAwaitExpression();
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

renamed for clarity

var isStatic = !AnalyzerResult.UseInstanceMember;
var isReadOnly = AnalyzerResult.ShouldBeReadOnly;

Expand Down Expand Up @@ -604,23 +604,20 @@ protected override ExpressionSyntax CreateCallSignature()
}

var invocation = (ExpressionSyntax)InvocationExpression(methodExpression, ArgumentList([.. arguments]));
if (this.SelectionResult.CreateAsyncMethod())

// If we're extracting any code that contained an 'await' then we'll have to await the new method we're
// calling as well. If we also see any use of .ConfigureAwait(false) in the extracted code, keep that
// pattern on the await expression we produce.
if (this.SelectionResult.ContainsAwaitExpression())
{
if (this.SelectionResult.ShouldCallConfigureAwaitFalse())
if (this.SelectionResult.ContainsConfigureAwaitFalse())
{
if (AnalyzerResult.ReturnType.GetMembers().Any(static x => x is IMethodSymbol
{
Name: nameof(Task.ConfigureAwait),
Parameters: [{ Type.SpecialType: SpecialType.System_Boolean }],
}))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: i removed the check that the final type we're extracting actually has a .ConfigureAwait(bool) method on it. We're only adding .ConfigureAwait(false) because we saw it in the user's original code. So we can trust they know what pattern they're following and that we should follow it as well. This also makes things work if the ConfigureAwait is an extension method.

{
invocation = InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
invocation,
IdentifierName(nameof(Task.ConfigureAwait))),
ArgumentList([Argument(LiteralExpression(SyntaxKind.FalseLiteralExpression))]));
}
invocation = InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
invocation,
IdentifierName(nameof(Task.ConfigureAwait))),
ArgumentList([Argument(LiteralExpression(SyntaxKind.FalseLiteralExpression))]));
}

invocation = AwaitExpression(invocation);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,13 @@ private sealed class StatementResult(
: CSharpSelectionResult(document, selectionType, finalSpan)
{
public override bool ContainingScopeHasAsyncKeyword()
{
var node = GetContainingScope();

return node switch
=> GetContainingScope() switch
{
MethodDeclarationSyntax method => method.Modifiers.Any(SyntaxKind.AsyncKeyword),
LocalFunctionStatementSyntax localFunction => localFunction.Modifiers.Any(SyntaxKind.AsyncKeyword),
AnonymousFunctionExpressionSyntax anonymousFunction => anonymousFunction.AsyncKeyword != default,
_ => false,
};
}

public override SyntaxNode GetContainingScope()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.LanguageService;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.ExtractMethod;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Text;
using Roslyn.Utilities;
Expand Down Expand Up @@ -68,27 +66,6 @@ protected override SyntaxNode GetNodeForDataFlowAnalysis()
: node;
}

protected override bool UnderAnonymousOrLocalMethod(SyntaxToken token, SyntaxToken firstToken, SyntaxToken lastToken)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these are no longer needed as they were hacks on how we previously determined certain syntactic facts about the code being extracted.

=> IsUnderAnonymousOrLocalMethod(token, firstToken, lastToken);

public static bool IsUnderAnonymousOrLocalMethod(SyntaxToken token, SyntaxToken firstToken, SyntaxToken lastToken)
{
for (var current = token.Parent; current != null; current = current.Parent)
{
if (current is MemberDeclarationSyntax)
return false;

if (current is AnonymousFunctionExpressionSyntax or LocalFunctionStatementSyntax)
{
// make sure the selection contains the lambda
return firstToken.SpanStart <= current.GetFirstToken().SpanStart &&
current.GetLastToken().Span.End <= lastToken.Span.End;
}
}

return false;
}

public override StatementSyntax GetFirstStatementUnderContainer()
{
Contract.ThrowIfTrue(IsExtractMethodOnExpression);
Expand Down
110 changes: 39 additions & 71 deletions src/Features/Core/Portable/ExtractMethod/MethodExtractor.Analyzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ protected virtual bool IsReadOutside(ISymbol symbol, HashSet<ISymbol> readOutsid
public AnalyzerResult Analyze()
{
// do data flow analysis
var model = this.SemanticDocument.SemanticModel;
var model = this.SemanticModel;
var dataFlowAnalysisData = this.SelectionResult.GetDataFlowAnalysis();

// build symbol map for the identifiers used inside of the selection
Expand Down Expand Up @@ -127,8 +127,6 @@ public AnalyzerResult Analyze()

var (variables, returnType, returnsByRef) = GetSignatureInformation(variableInfoMap);

(returnType, var awaitTaskReturn) = AdjustReturnType(returnType);

// collect method type variable used in selected code
var sortedMap = new SortedDictionary<int, ITypeParameterSymbol>();
var typeParametersInConstraintList = GetMethodTypeParametersInConstraintList(variableInfoMap, symbolMap, sortedMap);
Expand All @@ -144,69 +142,12 @@ public AnalyzerResult Analyze()
variables,
returnType,
returnsByRef,
awaitTaskReturn,
instanceMemberIsUsed,
shouldBeReadOnly,
endOfSelectionReachable,
operationStatus);
}

private (ITypeSymbol typeSymbol, bool awaitTaskReturn) AdjustReturnType(ITypeSymbol returnType)
{
// if selection contains await which is not under async lambda or anonymous delegate,
// change return type to be wrapped in Task
var shouldPutAsyncModifier = SelectionResult.CreateAsyncMethod();
if (shouldPutAsyncModifier)
return WrapReturnTypeInTask(returnType);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we no longer do the wrapping step here. we hold off on that until we need the final type for the method decl to have.


// unwrap task if needed
return (UnwrapTaskIfNeeded(returnType), awaitTaskReturn: false);
}

private ITypeSymbol UnwrapTaskIfNeeded(ITypeSymbol returnType)
{
// nothing to unwrap
if (SelectionResult.ContainingScopeHasAsyncKeyword() &&
ContainsReturnStatementInSelectedCode())
{
var originalDefinition = returnType.OriginalDefinition;

// see whether it needs to be unwrapped
var model = this.SemanticDocument.SemanticModel;
var taskType = model.Compilation.TaskType();
if (originalDefinition.Equals(taskType))
return model.Compilation.GetSpecialType(SpecialType.System_Void);

var genericTaskType = model.Compilation.TaskOfTType();
if (originalDefinition.Equals(genericTaskType))
return ((INamedTypeSymbol)returnType).TypeArguments[0];
}

// nothing to unwrap
return returnType;
}

private (ITypeSymbol returnType, bool awaitTaskReturn) WrapReturnTypeInTask(ITypeSymbol returnType)
{
var compilation = this.SemanticModel.Compilation;
var taskType = compilation.TaskType();

// convert void to Task type
if (taskType is object && returnType.Equals(compilation.GetSpecialType(SpecialType.System_Void)))
return (taskType, awaitTaskReturn: true);

if (!SelectionResult.IsExtractMethodOnExpression && ContainsReturnStatementInSelectedCode())
return (returnType, awaitTaskReturn: false);

var genericTaskType = compilation.TaskOfTType();

// okay, wrap the return type in Task<T>
if (genericTaskType is object)
returnType = genericTaskType.Construct(returnType);

return (returnType, awaitTaskReturn: false);
}

private (ImmutableArray<VariableInfo> finalOrderedVariableInfos, ITypeSymbol returnType, bool returnsByRef)
GetSignatureInformation(Dictionary<ISymbol, VariableInfo> symbolMap)
{
Expand All @@ -217,7 +158,7 @@ private ITypeSymbol UnwrapTaskIfNeeded(ITypeSymbol returnType)
// check whether current selection contains return statement
var (returnType, returnsByRef) = SelectionResult.GetReturnTypeInfo(this.CancellationToken);

return (allVariableInfos, returnType, returnsByRef);
return (allVariableInfos, UnwrapTaskIfNeeded(returnType), returnsByRef);
}
else
{
Expand All @@ -230,9 +171,34 @@ private ITypeSymbol UnwrapTaskIfNeeded(ITypeSymbol returnType)
return (finalOrderedVariableInfos, returnType, returnsByRef: false);
}

ITypeSymbol UnwrapTaskIfNeeded(ITypeSymbol returnType)
{
if (this.SelectionResult.ContainingScopeHasAsyncKeyword())
{
// We compute the desired return type for the extract method from our own return type. But for
// the purposes of manipulating the return type, we need to get to the underlying type if this
// was wrapped in a Task in an explicitly 'async' method. In other words, if we're in an `async
// Task<int>` method, then we want the extract method to return `int`. Note: we will possibly
// then wrap that as `Task<int>` again if we see that we extracted out any await-expressions.

var compilation = this.SemanticModel.Compilation;
var knownTaskTypes = new KnownTaskTypes(compilation);

// Map from `Task/ValueTask` to `void`
if (returnType.Equals(knownTaskTypes.TaskType) || returnType.Equals(knownTaskTypes.ValueTaskType))
return compilation.GetSpecialType(SpecialType.System_Void);

// Map from `Task<T>/ValueTask<T>` to `T`
if (returnType.OriginalDefinition.Equals(knownTaskTypes.TaskOfTType) || returnType.OriginalDefinition.Equals(knownTaskTypes.ValueTaskOfTType))
return returnType.GetTypeArguments().Single();
}

return returnType;
}

ITypeSymbol GetReturnType(ImmutableArray<VariableInfo> variablesToUseAsReturnValue)
{
var compilation = this.SemanticDocument.SemanticModel.Compilation;
var compilation = this.SemanticModel.Compilation;

if (variablesToUseAsReturnValue.IsEmpty)
return compilation.GetSpecialType(SpecialType.System_Void);
Expand Down Expand Up @@ -289,19 +255,21 @@ private OperationStatus GetOperationStatus(
? OperationStatus.LocalFunctionCallWithoutDeclaration
: OperationStatus.SucceededStatus;

return readonlyFieldStatus.With(anonymousTypeStatus)
.With(unsafeAddressStatus)
.With(asyncRefOutParameterStatus)
.With(variableMapStatus)
.With(localFunctionStatus);
return readonlyFieldStatus
.With(anonymousTypeStatus)
.With(unsafeAddressStatus)
.With(asyncRefOutParameterStatus)
.With(variableMapStatus)
.With(localFunctionStatus);
}

private OperationStatus CheckAsyncMethodRefOutParameters(IList<VariableInfo> parameters)
{
if (SelectionResult.CreateAsyncMethod())
if (SelectionResult.ContainsAwaitExpression())
{
var names = parameters.Where(v => v is { UseAsReturnValue: false, ParameterModifier: ParameterBehavior.Out or ParameterBehavior.Ref })
.Select(p => p.Name ?? string.Empty);
var names = parameters
.Where(v => v is { UseAsReturnValue: false, ParameterModifier: ParameterBehavior.Out or ParameterBehavior.Ref })
.Select(p => p.Name ?? string.Empty);

if (names.Any())
return new OperationStatus(succeeded: true, string.Format(FeaturesResources.Asynchronous_method_cannot_have_ref_out_parameters_colon_bracket_0_bracket, string.Join(", ", names)));
Expand Down Expand Up @@ -345,7 +313,7 @@ private ImmutableArray<VariableInfo> MarkVariableInfosToUseAsReturnValueIfPossib
// return values of the method since we can't actually have out/ref with an async method.
var outRefCount = numberOfOutParameters + numberOfRefParameters;
if (outRefCount > 0 &&
this.SelectionResult.CreateAsyncMethod() &&
this.SelectionResult.ContainsAwaitExpression() &&
this.SyntaxFacts.SupportsTupleDeconstruction(this.SemanticDocument.Document.Project.ParseOptions!))
{
var result = new FixedSizeArrayBuilder<VariableInfo>(variableInfo.Length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@

#nullable disable

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Collections.ObjectModel;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.PooledObjects;
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis.ExtractMethod;

Expand All @@ -27,7 +26,6 @@ protected sealed class AnalyzerResult(
ImmutableArray<VariableInfo> variables,
ITypeSymbol returnType,
bool returnsByRef,
bool awaitTaskReturn,
bool instanceMemberIsUsed,
bool shouldBeReadOnly,
bool endOfSelectionReachable,
Expand All @@ -53,11 +51,10 @@ protected sealed class AnalyzerResult(
public bool EndOfSelectionReachable { get; } = endOfSelectionReachable;

/// <summary>
/// flag to show whether task return type is due to await
/// Initial computed return type for the extract method. This does not include any wrapping in a type like
/// <see cref="Task{TResult}"/> for async methods.
/// </summary>
public bool AwaitTaskReturn { get; } = awaitTaskReturn;

public ITypeSymbol ReturnType { get; } = returnType;
public ITypeSymbol CoreReturnType { get; } = returnType;
public bool ReturnsByRef { get; } = returnsByRef;

/// <summary>
Expand All @@ -67,14 +64,6 @@ protected sealed class AnalyzerResult(

public ImmutableArray<VariableInfo> Variables { get; } = variables;

public bool HasReturnType
{
get
{
return ReturnType.SpecialType != SpecialType.System_Void && !AwaitTaskReturn;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: i genuinely don't know what this expression meant before. especially the AwaitTaskReturn part. This is much clearer now. There is a core return type (prior to any task wrapping) and it can be checked. To know if there is a return type or not, it can be checked against System.Void.

}
}

public ImmutableArray<VariableInfo> GetVariablesToSplitOrMoveIntoMethodDefinition()
{
return Variables.WhereAsArray(
Expand Down
Loading
Loading