Skip to content

Commit

Permalink
Use IsManagedType for determining which state machine hoisted local…
Browse files Browse the repository at this point in the history
…s to clear (#75841)
  • Loading branch information
jcouv authored Nov 14, 2024
1 parent 789a655 commit 24e78cf
Show file tree
Hide file tree
Showing 6 changed files with 964 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -422,31 +422,16 @@ internal static bool TryUnwrapBoundStateMachineScope(ref BoundStatement statemen
return false;
}

private void AddVariableCleanup(ArrayBuilder<BoundExpression> cleanup, FieldSymbol field)
{
if (MightContainReferences(field.Type))
{
cleanup.Add(F.AssignmentExpression(F.Field(F.This(), field), F.NullOrDefault(field.Type)));
}
}

/// <summary>
/// Might the given type be, or contain, managed references? This is used to determine which
/// fields allocated to temporaries should be cleared when the underlying variable goes out of scope, so
/// Clear fields allocated to temporaries when the underlying variable goes out of scope, so
/// that they do not cause unnecessary object retention.
/// </summary>
private bool MightContainReferences(TypeSymbol type)
private void AddVariableCleanup(ArrayBuilder<BoundExpression> cleanup, FieldSymbol field)
{
if (type.IsReferenceType || type.TypeKind == TypeKind.TypeParameter) return true; // type parameter or reference type
if (type.TypeKind != TypeKind.Struct) return false; // enums, etc
if (type.SpecialType == SpecialType.System_TypedReference) return true;
if (type.SpecialType.CanOptimizeBehavior()) return false; // int, etc
if (!type.IsFromCompilation(this.CompilationState.ModuleBuilderOpt.Compilation)) return true; // perhaps from ref assembly
foreach (var f in _emptyStructTypeCache.GetStructInstanceFields(type))
if (field.Type.IsManagedTypeNoUseSiteDiagnostics)
{
if (MightContainReferences(f.Type)) return true;
cleanup.Add(F.AssignmentExpression(F.Field(F.This(), field), F.NullOrDefault(field.Type)));
}
return false;
}

private StateMachineFieldSymbol GetOrAllocateReusableHoistedField(TypeSymbol type, out bool reused, LocalSymbol local = null)
Expand Down
326 changes: 326 additions & 0 deletions src/Compilers/CSharp/Test/Emit/CodeGen/CodeGenAsyncIteratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9417,5 +9417,331 @@ public static async IAsyncEnumerable<S<T>> M<T>(Task t, T t1, T t2) where T : un
expectedOutput: ExecutionConditionUtil.IsMonoOrCoreClr ? "42False43" : null,
verify: ExecutionConditionUtil.IsMonoOrCoreClr ? Verification.Passes : Verification.Skipped).VerifyDiagnostics();
}

[Fact, WorkItem("https://github.com/dotnet/roslyn/issues/75666")]
public void AddVariableCleanup_IntLocal()
{
string src = """
using System.Reflection;
var values = C.Produce();
await foreach (int value in values) { }
System.Console.Write(((int)values.GetType().GetField("<values2>5__2", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(values)));
class C
{
public static async System.Collections.Generic.IAsyncEnumerable<int> Produce()
{
int values2 = 42;
await System.Threading.Tasks.Task.CompletedTask;
yield return values2;
}
}
""";
CompileAndVerify(src, expectedOutput: ExpectedOutput("42"), verify: Verification.Skipped, targetFramework: TargetFramework.Net80).VerifyDiagnostics();
}

[Fact, WorkItem("https://github.com/dotnet/roslyn/issues/75666")]
public void AddVariableCleanup_StringLocal()
{
string src = """
using System.Reflection;
var values = C.Produce();
await foreach (int value in values) { }
System.Console.Write(((string)values.GetType().GetField("<values2>5__2", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(values)));
class C
{
public static async System.Collections.Generic.IAsyncEnumerable<int> Produce()
{
string values2 = "value ";
await System.Threading.Tasks.Task.CompletedTask;
yield return 1;
System.Console.Write(values2);
}
}
""";
// Note: hoisted top-level local does not get cleared when exiting normally
CompileAndVerify(src, expectedOutput: ExpectedOutput("value value"), verify: Verification.Skipped, targetFramework: TargetFramework.Net80).VerifyDiagnostics();
}

[Fact, WorkItem("https://github.com/dotnet/roslyn/issues/75666")]
public void AddVariableCleanup_NestedStringLocal()
{
string src = """
using System.Reflection;
var tcs = new System.Threading.Tasks.TaskCompletionSource();
var values = C.Produce(true, tcs.Task);
var enumerator = values.GetAsyncEnumerator();
assert(await enumerator.MoveNextAsync());
assert(enumerator.Current == 1);
assert(await enumerator.MoveNextAsync());
assert(enumerator.Current == 2);
_ = enumerator.MoveNextAsync();
System.Console.Write(((string)values.GetType().GetField("<values2>5__2", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(values)) is null);
void assert(bool b)
{
if (!b) throw new System.Exception();
}
class C
{
public static async System.Collections.Generic.IAsyncEnumerable<int> Produce(bool b, System.Threading.Tasks.Task task)
{
while (b)
{
string values2 = "value ";
yield return 1;
System.Console.Write(values2);
b = false;
}
yield return 2;
await task;
yield return 3;
}
}
""";
// Note: hoisted nested local gets cleared when exiting nested scope normally
CompileAndVerify(src, expectedOutput: ExpectedOutput("value True"), verify: Verification.Skipped, targetFramework: TargetFramework.Net80).VerifyDiagnostics();
}

[Fact, WorkItem("https://github.com/dotnet/roslyn/issues/75666")]
public void AddVariableCleanup_NestedLocalWithStructFromAnotherCompilation()
{
var libSrc = """
public struct S
{
public int field;
public override string ToString() => field.ToString();
}
""";
var libComp = CreateCompilation(libSrc, targetFramework: TargetFramework.Net80);
string src = """
using System.Reflection;
var tcs = new System.Threading.Tasks.TaskCompletionSource();
var values = C.Produce(true, tcs.Task);
var enumerator = values.GetAsyncEnumerator();
assert(await enumerator.MoveNextAsync());
assert(enumerator.Current == 1);
assert(await enumerator.MoveNextAsync());
assert(enumerator.Current == 2);
_ = enumerator.MoveNextAsync();
System.Console.Write(((S)values.GetType().GetField("<values2>5__2", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(values)));
void assert(bool b)
{
if (!b) throw new System.Exception();
}
class C
{
public static async System.Collections.Generic.IAsyncEnumerable<int> Produce(bool b, System.Threading.Tasks.Task task)
{
while (b)
{
S values2 = new S { field = 42 };
yield return 1;
System.Console.Write(values2);
b = false;
}
yield return 2;
await task;
yield return 3;
}
}
""";
var verifier = CompileAndVerify(src, expectedOutput: ExpectedOutput("4242"), references: [libComp.EmitToImageReference()],
verify: Verification.Skipped, targetFramework: TargetFramework.Net80).VerifyDiagnostics();

verifier.VerifyIL("C.<Produce>d__0.System.Runtime.CompilerServices.IAsyncStateMachine.MoveNext()", """
{
// Code size 437 (0x1b5)
.maxstack 3
.locals init (int V_0,
S V_1,
System.Runtime.CompilerServices.TaskAwaiter V_2,
C.<Produce>d__0 V_3,
System.Exception V_4)
IL_0000: ldarg.0
IL_0001: ldfld "int C.<Produce>d__0.<>1__state"
IL_0006: stloc.0
.try
{
IL_0007: ldloc.0
IL_0008: ldc.i4.s -6
IL_000a: sub
IL_000b: switch (
IL_0144,
IL_00bd,
IL_0072,
IL_002c,
IL_002c,
IL_002c,
IL_010e)
IL_002c: ldarg.0
IL_002d: ldfld "bool C.<Produce>d__0.<>w__disposeMode"
IL_0032: brfalse.s IL_0039
IL_0034: leave IL_0181
IL_0039: ldarg.0
IL_003a: ldc.i4.m1
IL_003b: dup
IL_003c: stloc.0
IL_003d: stfld "int C.<Produce>d__0.<>1__state"
IL_0042: br.s IL_009f
IL_0044: ldarg.0
IL_0045: ldloca.s V_1
IL_0047: initobj "S"
IL_004d: ldloca.s V_1
IL_004f: ldc.i4.s 42
IL_0051: stfld "int S.field"
IL_0056: ldloc.1
IL_0057: stfld "S C.<Produce>d__0.<values2>5__2"
IL_005c: ldarg.0
IL_005d: ldc.i4.1
IL_005e: stfld "int C.<Produce>d__0.<>2__current"
IL_0063: ldarg.0
IL_0064: ldc.i4.s -4
IL_0066: dup
IL_0067: stloc.0
IL_0068: stfld "int C.<Produce>d__0.<>1__state"
IL_006d: leave IL_01a8
IL_0072: ldarg.0
IL_0073: ldc.i4.m1
IL_0074: dup
IL_0075: stloc.0
IL_0076: stfld "int C.<Produce>d__0.<>1__state"
IL_007b: ldarg.0
IL_007c: ldfld "bool C.<Produce>d__0.<>w__disposeMode"
IL_0081: brfalse.s IL_0088
IL_0083: leave IL_0181
IL_0088: ldarg.0
IL_0089: ldfld "S C.<Produce>d__0.<values2>5__2"
IL_008e: box "S"
IL_0093: call "void System.Console.Write(object)"
IL_0098: ldarg.0
IL_0099: ldc.i4.0
IL_009a: stfld "bool C.<Produce>d__0.b"
IL_009f: ldarg.0
IL_00a0: ldfld "bool C.<Produce>d__0.b"
IL_00a5: brtrue.s IL_0044
IL_00a7: ldarg.0
IL_00a8: ldc.i4.2
IL_00a9: stfld "int C.<Produce>d__0.<>2__current"
IL_00ae: ldarg.0
IL_00af: ldc.i4.s -5
IL_00b1: dup
IL_00b2: stloc.0
IL_00b3: stfld "int C.<Produce>d__0.<>1__state"
IL_00b8: leave IL_01a8
IL_00bd: ldarg.0
IL_00be: ldc.i4.m1
IL_00bf: dup
IL_00c0: stloc.0
IL_00c1: stfld "int C.<Produce>d__0.<>1__state"
IL_00c6: ldarg.0
IL_00c7: ldfld "bool C.<Produce>d__0.<>w__disposeMode"
IL_00cc: brfalse.s IL_00d3
IL_00ce: leave IL_0181
IL_00d3: ldarg.0
IL_00d4: ldfld "System.Threading.Tasks.Task C.<Produce>d__0.task"
IL_00d9: callvirt "System.Runtime.CompilerServices.TaskAwaiter System.Threading.Tasks.Task.GetAwaiter()"
IL_00de: stloc.2
IL_00df: ldloca.s V_2
IL_00e1: call "bool System.Runtime.CompilerServices.TaskAwaiter.IsCompleted.get"
IL_00e6: brtrue.s IL_012a
IL_00e8: ldarg.0
IL_00e9: ldc.i4.0
IL_00ea: dup
IL_00eb: stloc.0
IL_00ec: stfld "int C.<Produce>d__0.<>1__state"
IL_00f1: ldarg.0
IL_00f2: ldloc.2
IL_00f3: stfld "System.Runtime.CompilerServices.TaskAwaiter C.<Produce>d__0.<>u__1"
IL_00f8: ldarg.0
IL_00f9: stloc.3
IL_00fa: ldarg.0
IL_00fb: ldflda "System.Runtime.CompilerServices.AsyncIteratorMethodBuilder C.<Produce>d__0.<>t__builder"
IL_0100: ldloca.s V_2
IL_0102: ldloca.s V_3
IL_0104: call "void System.Runtime.CompilerServices.AsyncIteratorMethodBuilder.AwaitUnsafeOnCompleted<System.Runtime.CompilerServices.TaskAwaiter, C.<Produce>d__0>(ref System.Runtime.CompilerServices.TaskAwaiter, ref C.<Produce>d__0)"
IL_0109: leave IL_01b4
IL_010e: ldarg.0
IL_010f: ldfld "System.Runtime.CompilerServices.TaskAwaiter C.<Produce>d__0.<>u__1"
IL_0114: stloc.2
IL_0115: ldarg.0
IL_0116: ldflda "System.Runtime.CompilerServices.TaskAwaiter C.<Produce>d__0.<>u__1"
IL_011b: initobj "System.Runtime.CompilerServices.TaskAwaiter"
IL_0121: ldarg.0
IL_0122: ldc.i4.m1
IL_0123: dup
IL_0124: stloc.0
IL_0125: stfld "int C.<Produce>d__0.<>1__state"
IL_012a: ldloca.s V_2
IL_012c: call "void System.Runtime.CompilerServices.TaskAwaiter.GetResult()"
IL_0131: ldarg.0
IL_0132: ldc.i4.3
IL_0133: stfld "int C.<Produce>d__0.<>2__current"
IL_0138: ldarg.0
IL_0139: ldc.i4.s -6
IL_013b: dup
IL_013c: stloc.0
IL_013d: stfld "int C.<Produce>d__0.<>1__state"
IL_0142: leave.s IL_01a8
IL_0144: ldarg.0
IL_0145: ldc.i4.m1
IL_0146: dup
IL_0147: stloc.0
IL_0148: stfld "int C.<Produce>d__0.<>1__state"
IL_014d: ldarg.0
IL_014e: ldfld "bool C.<Produce>d__0.<>w__disposeMode"
IL_0153: pop
IL_0154: leave.s IL_0181
}
catch System.Exception
{
IL_0156: stloc.s V_4
IL_0158: ldarg.0
IL_0159: ldc.i4.s -2
IL_015b: stfld "int C.<Produce>d__0.<>1__state"
IL_0160: ldarg.0
IL_0161: ldc.i4.0
IL_0162: stfld "int C.<Produce>d__0.<>2__current"
IL_0167: ldarg.0
IL_0168: ldflda "System.Runtime.CompilerServices.AsyncIteratorMethodBuilder C.<Produce>d__0.<>t__builder"
IL_016d: call "void System.Runtime.CompilerServices.AsyncIteratorMethodBuilder.Complete()"
IL_0172: ldarg.0
IL_0173: ldflda "System.Threading.Tasks.Sources.ManualResetValueTaskSourceCore<bool> C.<Produce>d__0.<>v__promiseOfValueOrEnd"
IL_0178: ldloc.s V_4
IL_017a: call "void System.Threading.Tasks.Sources.ManualResetValueTaskSourceCore<bool>.SetException(System.Exception)"
IL_017f: leave.s IL_01b4
}
IL_0181: ldarg.0
IL_0182: ldc.i4.s -2
IL_0184: stfld "int C.<Produce>d__0.<>1__state"
IL_0189: ldarg.0
IL_018a: ldc.i4.0
IL_018b: stfld "int C.<Produce>d__0.<>2__current"
IL_0190: ldarg.0
IL_0191: ldflda "System.Runtime.CompilerServices.AsyncIteratorMethodBuilder C.<Produce>d__0.<>t__builder"
IL_0196: call "void System.Runtime.CompilerServices.AsyncIteratorMethodBuilder.Complete()"
IL_019b: ldarg.0
IL_019c: ldflda "System.Threading.Tasks.Sources.ManualResetValueTaskSourceCore<bool> C.<Produce>d__0.<>v__promiseOfValueOrEnd"
IL_01a1: ldc.i4.0
IL_01a2: call "void System.Threading.Tasks.Sources.ManualResetValueTaskSourceCore<bool>.SetResult(bool)"
IL_01a7: ret
IL_01a8: ldarg.0
IL_01a9: ldflda "System.Threading.Tasks.Sources.ManualResetValueTaskSourceCore<bool> C.<Produce>d__0.<>v__promiseOfValueOrEnd"
IL_01ae: ldc.i4.1
IL_01af: call "void System.Threading.Tasks.Sources.ManualResetValueTaskSourceCore<bool>.SetResult(bool)"
IL_01b4: ret
}
""");
}
}
}
Loading

0 comments on commit 24e78cf

Please sign in to comment.