From db6369f06aa1fdca7130174aab6df13b2ec99e33 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Tue, 16 May 2023 11:53:19 -0700 Subject: [PATCH] Clean up some code after the interface inheritance work Do some code cleanup: - Remove unused usings - Delete dead code - Unnest types - Remove outdated TODOs - Fix incorrect diagnostics - Refactor ComMethodContext to cache some of its work and prevent multiple calculations. I recommend reviewing with whitespace changes disabled. --- .../ComInterfaceGenerator/AttributeInfo.cs | 22 -- .../ComInterfaceAndMethodsContext.cs | 31 +- .../ComInterfaceContext.cs | 65 ++--- .../ComInterfaceGenerator.cs | 68 +---- .../ComInterfaceGeneratorHelpers.cs | 2 - .../ComInterfaceGenerator/ComInterfaceInfo.cs | 211 +++++++------- .../ComInterfaceGenerator/ComMethodContext.cs | 270 ++++++++++-------- .../ComInterfaceGenerator/ComMethodInfo.cs | 155 +++++----- .../GeneratedStubCodeContext.cs | 14 + .../IncrementalMethodStubGenerationContext.cs | 1 - .../IncrementalValuesProviderExtensions.cs | 2 - .../gen/ComInterfaceGenerator/InlinedTypes.cs | 124 -------- .../ManagedToNativeVTableMethodGenerator.cs | 1 - .../ComInterfaceDispatchMarshallerFactory.cs | 1 - .../SkippedStubContext.cs | 10 + .../UnmanagedToManagedStubGenerator.cs | 2 - .../UnreachableException.cs | 2 - .../VirtualMethodPointerStubGenerator.cs | 1 - .../VtableIndexStubGenerator.cs | 2 - .../VtableIndexStubGeneratorHelpers.cs | 2 - 20 files changed, 418 insertions(+), 568 deletions(-) delete mode 100644 src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/AttributeInfo.cs create mode 100644 src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/GeneratedStubCodeContext.cs delete mode 100644 src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/InlinedTypes.cs create mode 100644 src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/SkippedStubContext.cs diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/AttributeInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/AttributeInfo.cs deleted file mode 100644 index 8443100fa46828..00000000000000 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/AttributeInfo.cs +++ /dev/null @@ -1,22 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Linq; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; - -namespace Microsoft.Interop -{ - /// - /// Provides the info necessary for copying an attribute from user code to generated code. - /// - internal sealed record AttributeInfo(ManagedTypeInfo Type, SequenceEqualImmutableArray Arguments) - { - internal static AttributeInfo From(AttributeData attribute) - { - var type = ManagedTypeInfo.CreateTypeInfoForTypeSymbol(attribute.AttributeClass); - var args = attribute.ConstructorArguments.Select(ca => ca.ToCSharpString()); - return new(type, args.ToSequenceEqualImmutableArray()); - } - } -} diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs index 82f1e2a5987063..dd6797f1a49f1d 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceAndMethodsContext.cs @@ -7,26 +7,23 @@ namespace Microsoft.Interop { - public sealed partial class ComInterfaceGenerator + /// + /// Represents an interface and all of the methods that need to be generated for it (methods declared on the interface and methods inherited from base interfaces). + /// + internal sealed record ComInterfaceAndMethodsContext(ComInterfaceContext Interface, SequenceEqualImmutableArray Methods) { + // Change Calc all methods to return an ordered list of all the methods and the data in comInterfaceandMethodsContext + // Have a step that runs CalculateMethodStub on each of them. + // Call GroupMethodsByInterfaceForGeneration + /// - /// Represents an interface and all of the methods that need to be generated for it (methods declared on the interface and methods inherited from base interfaces). + /// COM methods that are declared on the attributed interface declaration. /// - private sealed record ComInterfaceAndMethodsContext(ComInterfaceContext Interface, SequenceEqualImmutableArray Methods) - { - // Change Calc all methods to return an ordered list of all the methods and the data in comInterfaceandMethodsContext - // Have a step that runs CalculateMethodStub on each of them. - // Call GroupMethodsByInterfaceForGeneration - - /// - /// COM methods that are declared on the attributed interface declaration. - /// - public IEnumerable DeclaredMethods => Methods.Where(m => !m.IsInheritedMethod); + public IEnumerable DeclaredMethods => Methods.Where(m => !m.IsInheritedMethod); - /// - /// COM methods that are declared on an interface the interface inherits from. - /// - public IEnumerable ShadowingMethods => Methods.Where(m => m.IsInheritedMethod); - } + /// + /// COM methods that are declared on an interface the interface inherits from. + /// + public IEnumerable ShadowingMethods => Methods.Where(m => m.IsInheritedMethod); } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceContext.cs index 5b570c1f0455ac..0eee960740269a 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceContext.cs @@ -7,51 +7,48 @@ namespace Microsoft.Interop { - public sealed partial class ComInterfaceGenerator + internal sealed record ComInterfaceContext(ComInterfaceInfo Info, ComInterfaceContext? Base) { - private sealed record ComInterfaceContext(ComInterfaceInfo Info, ComInterfaceContext? Base) + /// + /// Takes a list of ComInterfaceInfo, and creates a list of ComInterfaceContext. + /// + public static ImmutableArray GetContexts(ImmutableArray data, CancellationToken _) { - /// - /// Takes a list of ComInterfaceInfo, and creates a list of ComInterfaceContext. - /// - public static ImmutableArray GetContexts(ImmutableArray data, CancellationToken _) + Dictionary symbolToInterfaceInfoMap = new(); + var accumulator = ImmutableArray.CreateBuilder(data.Length); + foreach (var iface in data) { - Dictionary symbolToInterfaceInfoMap = new(); - var accumulator = ImmutableArray.CreateBuilder(data.Length); - foreach (var iface in data) + symbolToInterfaceInfoMap.Add(iface.ThisInterfaceKey, iface); + } + Dictionary symbolToContextMap = new(); + + foreach (var iface in data) + { + accumulator.Add(AddContext(iface)); + } + return accumulator.MoveToImmutable(); + + ComInterfaceContext AddContext(ComInterfaceInfo iface) + { + if (symbolToContextMap.TryGetValue(iface.ThisInterfaceKey, out var cachedValue)) { - symbolToInterfaceInfoMap.Add(iface.ThisInterfaceKey, iface); + return cachedValue; } - Dictionary symbolToContextMap = new(); - foreach (var iface in data) + if (iface.BaseInterfaceKey is null) { - accumulator.Add(AddContext(iface)); + var baselessCtx = new ComInterfaceContext(iface, null); + symbolToContextMap[iface.ThisInterfaceKey] = baselessCtx; + return baselessCtx; } - return accumulator.MoveToImmutable(); - ComInterfaceContext AddContext(ComInterfaceInfo iface) + if (!symbolToContextMap.TryGetValue(iface.BaseInterfaceKey, out var baseContext)) { - if (symbolToContextMap.TryGetValue(iface.ThisInterfaceKey, out var cachedValue)) - { - return cachedValue; - } - - if (iface.BaseInterfaceKey is null) - { - var baselessCtx = new ComInterfaceContext(iface, null); - symbolToContextMap[iface.ThisInterfaceKey] = baselessCtx; - return baselessCtx; - } - - if (!symbolToContextMap.TryGetValue(iface.BaseInterfaceKey, out var baseContext)) - { - baseContext = AddContext(symbolToInterfaceInfoMap[iface.BaseInterfaceKey]); - } - var ctx = new ComInterfaceContext(iface, baseContext); - symbolToContextMap[iface.ThisInterfaceKey] = ctx; - return ctx; + baseContext = AddContext(symbolToInterfaceInfoMap[iface.BaseInterfaceKey]); } + var ctx = new ComInterfaceContext(iface, baseContext); + symbolToContextMap[iface.ThisInterfaceKey] = ctx; + return ctx; } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index 303a20aa638490..6a985fcc3b6a8c 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -2,9 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Collections.Generic; using System.Collections.Immutable; -using System.Collections.Specialized; using System.IO; using System.Linq; using System.Reflection; @@ -19,14 +17,6 @@ namespace Microsoft.Interop [Generator] public sealed partial class ComInterfaceGenerator : IIncrementalGenerator { - private sealed record class GeneratedStubCodeContext( - ManagedTypeInfo OriginalDefiningType, - ContainingSyntaxContext ContainingSyntaxContext, - SyntaxEquivalentNode Stub, - SequenceEqualImmutableArray Diagnostics) : GeneratedMethodContextBase(OriginalDefiningType, Diagnostics); - - private sealed record SkippedStubContext(ManagedTypeInfo OriginalDefiningType) : GeneratedMethodContextBase(OriginalDefiningType, new(ImmutableArray.Empty)); - public static class StepNames { public const string CalculateStubInformation = nameof(CalculateStubInformation); @@ -103,11 +93,9 @@ public void Initialize(IncrementalGeneratorInitializationContext context) { var ((data, symbolMap), env) = param; return new ComMethodContext( - data.Method.OriginalDeclaringInterface, - data.TypeKeyOwner, - data.Method.MethodInfo, - data.Method.Index, - CalculateStubInformation(data.Method.MethodInfo.Syntax, symbolMap[data.Method.MethodInfo], data.Method.Index, env, data.TypeKeyOwner.Info.Type, ct)); + data.Method, + data.OwningInterface, + CalculateStubInformation(data.Method.MethodInfo.Syntax, symbolMap[data.Method.MethodInfo], data.Method.Index, env, data.OwningInterface.Info.Type, ct)); }).WithTrackingName(StepNames.CalculateStubInformation); var interfaceAndMethodsContexts = comMethodContexts @@ -117,7 +105,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // Generate the code for the managed-to-unmanaged stubs and the diagnostics from code-generation. context.RegisterDiagnostics(interfaceAndMethodsContexts - .SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.GetManagedToUnmanagedStub().Diagnostics))); + .SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.ManagedToUnmanagedStub.Diagnostics))); var managedToNativeInterfaceImplementations = interfaceAndMethodsContexts .Select(GenerateImplementationInterface) .WithTrackingName(StepNames.GenerateManagedToNativeInterfaceImplementation) @@ -126,7 +114,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // Generate the code for the unmanaged-to-managed stubs and the diagnostics from code-generation. context.RegisterDiagnostics(interfaceAndMethodsContexts - .SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.GetNativeToManagedStub().Diagnostics))); + .SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.UnmanagedToManagedStub.Diagnostics))); var nativeToManagedVtableMethods = interfaceAndMethodsContexts .Select(GenerateImplementationVTableMethods) .WithTrackingName(StepNames.GenerateNativeToManagedVTableMethods) @@ -145,11 +133,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .Select((data, ct) => { var context = data.Interface.Info; - var methods = data.ShadowingMethods.Select(m => (MemberDeclarationSyntax)m.GenerateShadow()); + var methods = data.ShadowingMethods.Select(m => m.Shadow); var typeDecl = TypeDeclaration(context.ContainingSyntax.TypeKind, context.ContainingSyntax.Identifier) .WithModifiers(context.ContainingSyntax.Modifiers) .WithTypeParameterList(context.ContainingSyntax.TypeParameters) - .WithMembers(List(methods)); + .WithMembers(List(methods)); return data.Interface.Info.TypeDefinitionContext.WrapMemberInContainingSyntaxWithUnsafeModifier(typeDecl); }) .WithTrackingName(StepNames.GenerateShadowingMethods) @@ -210,33 +198,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) }); } - private static string GenerateMarkerInterfaceSource(ComInterfaceInfo iface) => $$""" - file unsafe class InterfaceInformation : global::System.Runtime.InteropServices.Marshalling.IIUnknownInterfaceType - { - public static global::System.Guid Iid => new(new global::System.ReadOnlySpan(new byte[] { {{string.Join(",", iface.InterfaceId.ToByteArray())}} })); - - private static void** m_vtable; - - public static void** ManagedVirtualMethodTable - { - get - { - if (m_vtable == null) - { - nint* vtable = (nint*)global::System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof({{iface.Type.FullTypeName}}), sizeof(nint) * 3); - global::System.Runtime.InteropServices.ComWrappers.GetIUnknownImpl(out vtable[0], out vtable[1], out vtable[2]); - m_vtable = (void**)vtable; - } - return m_vtable; - } - } - } - - [global::System.Runtime.InteropServices.DynamicInterfaceCastableImplementation] - file interface InterfaceImplementation : {{iface.Type.FullTypeName}} - {} - """; - private static readonly AttributeSyntax s_iUnknownDerivedAttributeTemplate = Attribute( GenericName(TypeNames.IUnknownDerivedAttribute) @@ -251,8 +212,7 @@ private static MemberDeclarationSyntax GenerateIUnknownDerivedAttributeApplicati .WithTypeParameterList(context.ContainingSyntax.TypeParameters) .AddAttributeLists(AttributeList(SingletonSeparatedList(s_iUnknownDerivedAttributeTemplate)))); - // Todo: extract info needed from the IMethodSymbol into MethodInfo and only pass a MethodInfo here - private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ManagedTypeInfo typeKeyOwner, CancellationToken ct) + private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ManagedTypeInfo owningInterface, CancellationToken ct) { ct.ThrowIfCancellationRequested(); INamedTypeSymbol? lcidConversionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.LCIDConversionAttribute); @@ -365,7 +325,7 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M new ComExceptionMarshalling(), ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.ManagedToUnmanaged), ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.UnmanagedToManaged), - typeKeyOwner, + owningInterface, declaringType, generatorDiagnostics.Diagnostics.ToSequenceEqualImmutableArray(), ComInterfaceDispatchMarshallingInfo.Instance); @@ -412,31 +372,32 @@ private static ImmutableArray GroupComContextsFor private static InterfaceDeclarationSyntax GenerateImplementationInterface(ComInterfaceAndMethodsContext interfaceGroup, CancellationToken _) { var definingType = interfaceGroup.Interface.Info.Type; - var shadowImplementations = interfaceGroup.ShadowingMethods.Select(m => (Method: m, ManagedToUnmanagedStub: m.GetManagedToUnmanagedStub())) + var shadowImplementations = interfaceGroup.ShadowingMethods.Select(m => (Method: m, ManagedToUnmanagedStub: m.ManagedToUnmanagedStub)) .Where(p => p.ManagedToUnmanagedStub is GeneratedStubCodeContext) .Select(ctx => ((GeneratedStubCodeContext)ctx.ManagedToUnmanagedStub).Stub.Node .WithExplicitInterfaceSpecifier( ExplicitInterfaceSpecifier(ParseName(definingType.FullTypeName)))); - var inheritedStubs = interfaceGroup.ShadowingMethods.Select(m => m.GenerateUnreachableExceptionStub()); + var inheritedStubs = interfaceGroup.ShadowingMethods.Select(m => m.UnreachableExceptionStub); return ImplementationInterfaceTemplate .AddBaseListTypes(SimpleBaseType(definingType.Syntax)) .WithMembers( List( interfaceGroup.DeclaredMethods - .Select(m => m.GetManagedToUnmanagedStub()) + .Select(m => m.ManagedToUnmanagedStub) .OfType() .Select(ctx => ctx.Stub.Node) .Concat(shadowImplementations) .Concat(inheritedStubs))) .AddAttributeLists(AttributeList(SingletonSeparatedList(Attribute(ParseName(TypeNames.System_Runtime_InteropServices_DynamicInterfaceCastableImplementationAttribute))))); } + private static InterfaceDeclarationSyntax GenerateImplementationVTableMethods(ComInterfaceAndMethodsContext comInterfaceAndMethods, CancellationToken _) { return ImplementationInterfaceTemplate .WithMembers( List( comInterfaceAndMethods.DeclaredMethods - .Select(m => m.GetNativeToManagedStub()) + .Select(m => m.UnmanagedToManagedStub) .OfType() .Select(context => context.Stub.Node))); } @@ -447,6 +408,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTableMethods(Co private static readonly MethodDeclarationSyntax CreateManagedVirtualFunctionTableMethodTemplate = MethodDeclaration(VoidStarStarSyntax, CreateManagedVirtualFunctionTableMethodName) .AddModifiers(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.StaticKeyword)); + private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterfaceAndMethodsContext interfaceMethods, CancellationToken _) { const string vtableLocalName = "vtable"; diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs index e192fcd02d8feb..8853d907103282 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGeneratorHelpers.cs @@ -2,9 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Collections.Generic; using System.Linq; -using System.Text; using Microsoft.CodeAnalysis; namespace Microsoft.Interop diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs index cd747bc6dc1c28..dd81de2507d4cb 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceInfo.cs @@ -2,150 +2,145 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Linq; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; -using Roslyn.Utilities; namespace Microsoft.Interop { - public sealed partial class ComInterfaceGenerator + /// + /// Information about a Com interface, but not its methods. + /// + internal sealed record ComInterfaceInfo( + ManagedTypeInfo Type, + string ThisInterfaceKey, // For associating interfaces to its base + string? BaseInterfaceKey, // For associating interfaces to its base + InterfaceDeclarationSyntax Declaration, + ContainingSyntaxContext TypeDefinitionContext, + ContainingSyntax ContainingSyntax, + Guid InterfaceId) { - /// - /// Information about a Com interface, but not its methods. - /// - private sealed record ComInterfaceInfo( - ManagedTypeInfo Type, - string ThisInterfaceKey, // For associating interfaces to its base - string? BaseInterfaceKey, // For associating interfaces to its base - InterfaceDeclarationSyntax Declaration, - ContainingSyntaxContext TypeDefinitionContext, - ContainingSyntax ContainingSyntax, - Guid InterfaceId) + public static (ComInterfaceInfo? Info, Diagnostic? Diagnostic) From(INamedTypeSymbol symbol, InterfaceDeclarationSyntax syntax) { - public static (ComInterfaceInfo? Info, Diagnostic? Diagnostic) From(INamedTypeSymbol symbol, InterfaceDeclarationSyntax syntax) + // Verify the method has no generic types or defined implementation + // and is not marked static or sealed + if (syntax.TypeParameterList is not null) + { + return (null, Diagnostic.Create( + GeneratorDiagnostics.InvalidAttributedMethodSignature, + syntax.Identifier.GetLocation(), + symbol.Name)); + } + + // Verify that the types the method is declared in are marked partial. + for (SyntaxNode? parentNode = syntax.Parent; parentNode is TypeDeclarationSyntax typeDecl; parentNode = parentNode.Parent) { - // Verify the method has no generic types or defined implementation - // and is not marked static or sealed - if (syntax.TypeParameterList is not null) + if (!typeDecl.Modifiers.Any(SyntaxKind.PartialKeyword)) { return (null, Diagnostic.Create( - GeneratorDiagnostics.InvalidAttributedMethodSignature, + GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingModifiers, syntax.Identifier.GetLocation(), - symbol.Name)); - } - - // Verify that the types the method is declared in are marked partial. - for (SyntaxNode? parentNode = syntax.Parent; parentNode is TypeDeclarationSyntax typeDecl; parentNode = parentNode.Parent) - { - if (!typeDecl.Modifiers.Any(SyntaxKind.PartialKeyword)) - { - return (null, Diagnostic.Create( - GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingModifiers, - syntax.Identifier.GetLocation(), - symbol.Name, - typeDecl.Identifier)); - } + symbol.Name, + typeDecl.Identifier)); } + } - if (!TryGetGuid(symbol, syntax, out Guid? guid, out Diagnostic? guidDiagnostic)) - return (null, guidDiagnostic); + if (!TryGetGuid(symbol, syntax, out Guid? guid, out Diagnostic? guidDiagnostic)) + return (null, guidDiagnostic); - if (!TryGetBaseComInterface(symbol, syntax, out INamedTypeSymbol? baseSymbol, out Diagnostic? baseDiagnostic)) - return (null, baseDiagnostic); + if (!TryGetBaseComInterface(symbol, syntax, out INamedTypeSymbol? baseSymbol, out Diagnostic? baseDiagnostic)) + return (null, baseDiagnostic); - return (new ComInterfaceInfo( - ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol), - symbol.ToDisplayString(), - baseSymbol?.ToDisplayString(), - syntax, - new ContainingSyntaxContext(syntax), - new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList), - guid ?? Guid.Empty), null); - } + return (new ComInterfaceInfo( + ManagedTypeInfo.CreateTypeInfoForTypeSymbol(symbol), + symbol.ToDisplayString(), + baseSymbol?.ToDisplayString(), + syntax, + new ContainingSyntaxContext(syntax), + new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList), + guid ?? Guid.Empty), null); + } - /// - /// Returns true if there is 0 or 1 base Com interfaces (i.e. the inheritance is valid), and returns false when there are 2 or more base Com interfaces and sets . - /// - private static bool TryGetBaseComInterface(INamedTypeSymbol comIface, InterfaceDeclarationSyntax syntax, out INamedTypeSymbol? baseComIface, [NotNullWhen(false)] out Diagnostic? diagnostic) + /// + /// Returns true if there is 0 or 1 base Com interfaces (i.e. the inheritance is valid), and returns false when there are 2 or more base Com interfaces and sets . + /// + private static bool TryGetBaseComInterface(INamedTypeSymbol comIface, InterfaceDeclarationSyntax syntax, out INamedTypeSymbol? baseComIface, [NotNullWhen(false)] out Diagnostic? diagnostic) + { + baseComIface = null; + foreach (var implemented in comIface.Interfaces) { - baseComIface = null; - foreach (var implemented in comIface.Interfaces) + foreach (var attr in implemented.GetAttributes()) { - foreach (var attr in implemented.GetAttributes()) + if (attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute) { - if (attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute) + if (baseComIface is not null) { - if (baseComIface is not null) - { - diagnostic = Diagnostic.Create( - GeneratorDiagnostics.MultipleComInterfaceBaseTypes, - syntax.Identifier.GetLocation(), - comIface.ToDisplayString()); - return false; - } - baseComIface = implemented; + diagnostic = Diagnostic.Create( + GeneratorDiagnostics.MultipleComInterfaceBaseTypes, + syntax.Identifier.GetLocation(), + comIface.ToDisplayString()); + return false; } + baseComIface = implemented; } } - diagnostic = null; - return true; } + diagnostic = null; + return true; + } - /// - /// Returns true and sets if the guid is present. Returns false and sets diagnostic if the guid is not present or is invalid. - /// - private static bool TryGetGuid(INamedTypeSymbol interfaceSymbol, InterfaceDeclarationSyntax syntax, [NotNullWhen(true)] out Guid? guid, [NotNullWhen(false)] out Diagnostic? diagnostic) + /// + /// Returns true and sets if the guid is present. Returns false and sets diagnostic if the guid is not present or is invalid. + /// + private static bool TryGetGuid(INamedTypeSymbol interfaceSymbol, InterfaceDeclarationSyntax syntax, [NotNullWhen(true)] out Guid? guid, [NotNullWhen(false)] out Diagnostic? diagnostic) + { + guid = null; + AttributeData? guidAttr = null; + AttributeData? _ = null; // Interface Attribute Type. We'll always assume IUnkown for now. + foreach (var attr in interfaceSymbol.GetAttributes()) { - guid = null; - AttributeData? guidAttr = null; - AttributeData? _ = null; // Interface Attribute Type. We'll always assume IUnkown for now. - foreach (var attr in interfaceSymbol.GetAttributes()) - { - var attrDisplayString = attr.AttributeClass?.ToDisplayString(); - if (attrDisplayString is TypeNames.System_Runtime_InteropServices_GuidAttribute) - guidAttr = attr; - else if (attrDisplayString is TypeNames.InterfaceTypeAttribute) - _ = attr; - } - - if (guidAttr is not null - && guidAttr.ConstructorArguments.Length == 1 - && guidAttr.ConstructorArguments[0].Value is string guidStr - && Guid.TryParse(guidStr, out var result)) - { - guid = result; - } - - // Assume interfaceType is IUnknown for now - if (guid is null) - { - diagnostic = Diagnostic.Create( - GeneratorDiagnostics.InvalidAttributedInterfaceMissingGuidAttribute, - syntax.Identifier.GetLocation(), - interfaceSymbol.ToDisplayString()); - return false; - } - diagnostic = null; - return true; + var attrDisplayString = attr.AttributeClass?.ToDisplayString(); + if (attrDisplayString is TypeNames.System_Runtime_InteropServices_GuidAttribute) + guidAttr = attr; + else if (attrDisplayString is TypeNames.InterfaceTypeAttribute) + _ = attr; } - public override int GetHashCode() + if (guidAttr is not null + && guidAttr.ConstructorArguments.Length == 1 + && guidAttr.ConstructorArguments[0].Value is string guidStr + && Guid.TryParse(guidStr, out var result)) { - // ContainingSyntax and ContainingSyntaxContext do not implement GetHashCode - return HashCode.Combine(Type, TypeDefinitionContext, InterfaceId); + guid = result; } - public bool Equals(ComInterfaceInfo other) + // Assume interfaceType is IUnknown for now + if (guid is null) { - // ContainingSyntax and ContainingSyntaxContext are not used in the hash code - return Type == other.Type - && TypeDefinitionContext == other.TypeDefinitionContext - && InterfaceId == other.InterfaceId; + diagnostic = Diagnostic.Create( + GeneratorDiagnostics.InvalidAttributedInterfaceMissingGuidAttribute, + syntax.Identifier.GetLocation(), + interfaceSymbol.ToDisplayString()); + return false; } + diagnostic = null; + return true; + } + + public override int GetHashCode() + { + // ContainingSyntax and ContainingSyntaxContext do not implement GetHashCode + return HashCode.Combine(Type, TypeDefinitionContext, InterfaceId); + } + + public bool Equals(ComInterfaceInfo other) + { + // ContainingSyntax and ContainingSyntaxContext are not used in the hash code + return Type == other.Type + && TypeDefinitionContext == other.TypeDefinitionContext + && InterfaceId == other.InterfaceId; } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs index 31708cb9b8a1ad..daf7c149adf922 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodContext.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; @@ -12,155 +13,196 @@ namespace Microsoft.Interop { - public sealed partial class ComInterfaceGenerator + + /// + /// Represents a method, its declaring interface, and its index in the interface's vtable. + /// This type contains all information necessary to generate the corresponding methods in the ComInterfaceGenerator + /// + internal sealed class ComMethodContext : IEquatable { /// - /// Represents a method, its declaring interface, and its index in the interface's vtable. - /// This type contains all information necessary to generate the corresponding methods in the ComInterfaceGenerator + /// A partially constructed that does not have a generated for it yet. + /// can be constructed without a reference to an ISymbol, whereas the requires an ISymbol /// /// /// The interface that originally declared the method in user code /// - /// - /// The interface that this methods is being generated for (may be different that OriginalDeclaringInterface if it is an inherited method) - /// /// The basic information about the method. - /// The index on the interface vtable that points to this method - /// - private sealed record ComMethodContext( + /// The vtable index for the method. + public sealed record Builder(ComInterfaceContext OriginalDeclaringInterface, ComMethodInfo MethodInfo, int Index); + + /// + /// The fully-constructed immutable state for a . + /// + private record struct State( ComInterfaceContext OriginalDeclaringInterface, ComInterfaceContext OwningInterface, ComMethodInfo MethodInfo, - int Index, - IncrementalMethodStubGenerationContext GenerationContext) + IncrementalMethodStubGenerationContext GenerationContext); + + private readonly State _state; + + /// + /// Construct a full method context from the , context, and additional information. + /// + /// The partially constructed context + /// The final owning interface of this method context + /// The generation context for this method + public ComMethodContext(Builder builder, ComInterfaceContext owningInterface, IncrementalMethodStubGenerationContext generationContext) { - /// - /// A partially constructed that does not have a generated for it yet. - /// can be constructed without a reference to an ISymbol, whereas the requires an ISymbol - /// - public sealed record Builder(ComInterfaceContext OriginalDeclaringInterface, ComMethodInfo MethodInfo, int Index); + _state = new State(builder.OriginalDeclaringInterface, owningInterface, builder.MethodInfo, generationContext); + } - public bool IsInheritedMethod => OriginalDeclaringInterface != OwningInterface; + public override bool Equals(object obj) => obj is ComMethodContext other && Equals(other); - public GeneratedMethodContextBase GetManagedToUnmanagedStub() - { - if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional)) - { - return new SkippedStubContext(OriginalDeclaringInterface.Info.Type); - } - var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(GenerationContext); - return new GeneratedStubCodeContext(GenerationContext.TypeKeyOwner, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics)); - } + public override int GetHashCode() => _state.GetHashCode(); + + public bool Equals(ComMethodContext other) => _state.Equals(other); + + public ComInterfaceContext OriginalDeclaringInterface => _state.OriginalDeclaringInterface; + + public ComInterfaceContext OwningInterface => _state.OwningInterface; + + public ComMethodInfo MethodInfo => _state.MethodInfo; + + public IncrementalMethodStubGenerationContext GenerationContext => _state.GenerationContext; - public GeneratedMethodContextBase GetNativeToManagedStub() + public bool IsInheritedMethod => OriginalDeclaringInterface != OwningInterface; + + private GeneratedMethodContextBase? _managedToUnmanagedStub; + + public GeneratedMethodContextBase ManagedToUnmanagedStub => _managedToUnmanagedStub ??= CreateManagedToUnmanagedStub(); + + private GeneratedMethodContextBase CreateManagedToUnmanagedStub() + { + if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional)) { - if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional)) - { - return new SkippedStubContext(GenerationContext.OriginalDefiningType); - } - var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(GenerationContext); - return new GeneratedStubCodeContext(GenerationContext.OriginalDefiningType, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics)); + return new SkippedStubContext(OriginalDeclaringInterface.Info.Type); } + var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(GenerationContext); + return new GeneratedStubCodeContext(GenerationContext.TypeKeyOwner, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics)); + } - public MethodDeclarationSyntax GenerateUnreachableExceptionStub() + private GeneratedMethodContextBase? _unmanagedToManagedStub; + + public GeneratedMethodContextBase UnmanagedToManagedStub => _unmanagedToManagedStub ??= CreateUnmanagedToManagedStub(); + + private GeneratedMethodContextBase CreateUnmanagedToManagedStub() + { + if (GenerationContext.VtableIndexData.Direction is not (MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional)) { - // DeclarationCopiedFromBaseDeclaration() => throw new UnreachableException("This method should not be reached"); - return MethodInfo.Syntax - .WithModifiers(TokenList()) - .WithAttributeLists(List()) - .WithExplicitInterfaceSpecifier(ExplicitInterfaceSpecifier( - ParseName(OriginalDeclaringInterface.Info.Type.FullTypeName))) - .WithExpressionBody(ArrowExpressionClause( - ThrowExpression( - ObjectCreationExpression( - ParseTypeName(TypeNames.UnreachableException)) - .WithArgumentList(ArgumentList())))); + return new SkippedStubContext(GenerationContext.OriginalDefiningType); } + var (methodStub, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(GenerationContext); + return new GeneratedStubCodeContext(GenerationContext.OriginalDefiningType, GenerationContext.ContainingSyntaxContext, new(methodStub), new(diagnostics)); + } + + private MethodDeclarationSyntax? _unreachableExceptionStub; + + public MethodDeclarationSyntax UnreachableExceptionStub => _unreachableExceptionStub ??= CreateUnreachableExceptionStub(); + + private MethodDeclarationSyntax CreateUnreachableExceptionStub() + { + // DeclarationCopiedFromBaseDeclaration() => throw new UnreachableException("This method should not be reached"); + return MethodInfo.Syntax + .WithModifiers(TokenList()) + .WithAttributeLists(List()) + .WithExplicitInterfaceSpecifier(ExplicitInterfaceSpecifier( + ParseName(OriginalDeclaringInterface.Info.Type.FullTypeName))) + .WithExpressionBody(ArrowExpressionClause( + ThrowExpression( + ObjectCreationExpression( + ParseTypeName(TypeNames.UnreachableException)) + .WithArgumentList(ArgumentList())))); + } + + private MethodDeclarationSyntax? _shadow; + + public MethodDeclarationSyntax Shadow => _shadow ??= GenerateShadow(); + + private MethodDeclarationSyntax GenerateShadow() + { + // DeclarationCopiedFromBaseDeclaration() + // { + // return (()this).(); + // } + var forwarder = new Forwarder(); + return MethodDeclaration(GenerationContext.SignatureContext.StubReturnType, MethodInfo.MethodName) + .WithModifiers(TokenList(Token(SyntaxKind.NewKeyword))) + .WithParameterList(ParameterList(SeparatedList(GenerationContext.SignatureContext.StubParameters))) + .WithExpressionBody( + ArrowExpressionClause( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + ParenthesizedExpression( + CastExpression(OriginalDeclaringInterface.Info.Type.Syntax, IdentifierName("this"))), + IdentifierName(MethodInfo.MethodName)), + ArgumentList( + SeparatedList(GenerationContext.SignatureContext.ManagedParameters.Select(p => forwarder.AsArgument(p, new ManagedStubCodeContext()))))))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); + } - public MethodDeclarationSyntax GenerateShadow() + /// + /// Returns a flat list of and its owning interface that represents all declared methods and inherited methods. + /// Guarantees the output will be sorted by order of interface input order, then by vtable order. + /// + public static List<(ComInterfaceContext OwningInterface, Builder Method)> CalculateAllMethods(IEnumerable<(ComInterfaceContext, SequenceEqualImmutableArray)> ifaceAndDeclaredMethods, CancellationToken _) + { + // Optimization : This step technically only needs a single interface inheritance hierarchy. + // We can calculate all inheritance chains in a previous step and only pass a single inheritance chain to this method. + // This way, when a single method changes, we would only need to recalculate this for the inheritance chain in which that method exists. + + var ifaceToDeclaredMethodsMap = ifaceAndDeclaredMethods.ToDictionary(static pair => pair.Item1, static pair => pair.Item2); + var allMethodsCache = new Dictionary>(); + var accumulator = new List<(ComInterfaceContext OwningInterface, Builder Method)>(); + foreach (var kvp in ifaceAndDeclaredMethods) { - // DeclarationCopiedFromBaseDeclaration() - // { - // return (()this).(); - // } - var forwarder = new Forwarder(); - return MethodDeclaration(GenerationContext.SignatureContext.StubReturnType, MethodInfo.MethodName) - .WithModifiers(TokenList(Token(SyntaxKind.NewKeyword))) - .WithParameterList(ParameterList(SeparatedList(GenerationContext.SignatureContext.StubParameters))) - .WithExpressionBody( - ArrowExpressionClause( - InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - ParenthesizedExpression( - CastExpression(OriginalDeclaringInterface.Info.Type.Syntax, IdentifierName("this"))), - IdentifierName(MethodInfo.MethodName)), - ArgumentList( - SeparatedList(GenerationContext.SignatureContext.ManagedParameters.Select(p => forwarder.AsArgument(p, new ManagedStubCodeContext()))))))) - .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); + var methods = AddMethods(kvp.Item1, kvp.Item2); + foreach (var method in methods) + { + accumulator.Add((kvp.Item1, method)); + } } + return accumulator; /// - /// Returns a flat list of and it's type key owner that represents all declared methods, and inherited methods. - /// Guarantees the output will be sorted by order of interface input order, then by vtable order. + /// Adds methods to a cache and returns inherited and declared methods for the interface in vtable order /// - public static List<(ComInterfaceContext TypeKeyOwner, Builder Method)> CalculateAllMethods(IEnumerable<(ComInterfaceContext, SequenceEqualImmutableArray)> ifaceAndDeclaredMethods, CancellationToken _) + ImmutableArray AddMethods(ComInterfaceContext iface, IEnumerable declaredMethods) { - // Optimization : This step technically only needs a single interface inheritance hierarchy. - // We can calculate all inheritance chains in a previous step and only pass a single inheritance chain to this method. - // This way, when a single method changes, we would only need to recalculate this for the inheritance chain in which that method exists. - - var ifaceToDeclaredMethodsMap = ifaceAndDeclaredMethods.ToDictionary(static pair => pair.Item1, static pair => pair.Item2); - var allMethodsCache = new Dictionary>(); - var accumulator = new List<(ComInterfaceContext TypeKeyOwner, Builder Method)>(); - foreach (var kvp in ifaceAndDeclaredMethods) + if (allMethodsCache.TryGetValue(iface, out var cachedValue)) { - var methods = AddMethods(kvp.Item1, kvp.Item2); - foreach (var method in methods) - { - accumulator.Add((kvp.Item1, method)); - } + return cachedValue; } - return accumulator; - /// - /// Adds methods to a cache and returns inherited and declared methods for the interface in vtable order - /// - ImmutableArray AddMethods(ComInterfaceContext iface, IEnumerable declaredMethods) + int startingIndex = 3; + List methods = new(); + // If we have a base interface, we should add the inherited methods to our list in vtable order + if (iface.Base is not null) { - if (allMethodsCache.TryGetValue(iface, out var cachedValue)) - { - return cachedValue; - } - - int startingIndex = 3; - List methods = new(); - // If we have a base interface, we should add the inherited methods to our list in vtable order - if (iface.Base is not null) + var baseComIface = iface.Base; + ImmutableArray baseMethods; + if (!allMethodsCache.TryGetValue(baseComIface, out var pair)) { - var baseComIface = iface.Base; - ImmutableArray baseMethods; - if (!allMethodsCache.TryGetValue(baseComIface, out var pair)) - { - baseMethods = AddMethods(baseComIface, ifaceToDeclaredMethodsMap[baseComIface]); - } - else - { - baseMethods = pair; - } - methods.AddRange(baseMethods); - startingIndex += baseMethods.Length; + baseMethods = AddMethods(baseComIface, ifaceToDeclaredMethodsMap[baseComIface]); } - // Then we append the declared methods in vtable order - foreach (var method in declaredMethods) + else { - methods.Add(new Builder(iface, method, startingIndex++)); + baseMethods = pair; } - // Cache so we don't recalculate if many interfaces inherit from the same one - var imm = methods.ToImmutableArray(); - allMethodsCache[iface] = imm; - return imm; + methods.AddRange(baseMethods); + startingIndex += baseMethods.Length; + } + // Then we append the declared methods in vtable order + foreach (var method in declaredMethods) + { + methods.Add(new Builder(iface, method, startingIndex++)); } + // Cache so we don't recalculate if many interfaces inherit from the same one + var imm = methods.ToImmutableArray(); + allMethodsCache[iface] = imm; + return imm; } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs index 38bf19a462f6a5..a2595736973333 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComMethodInfo.cs @@ -1,8 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; -using System.Collections.Generic; using System.Collections.Immutable; using System.Diagnostics; using System.Linq; @@ -13,108 +11,105 @@ namespace Microsoft.Interop { - public sealed partial class ComInterfaceGenerator + /// + /// Represents a method that has been determined to be a COM interface method. Only contains info immediately available from an IMethodSymbol and MethodDeclarationSyntax. + /// + internal sealed record ComMethodInfo( + MethodDeclarationSyntax Syntax, + string MethodName) { /// - /// Represents a method that has been determined to be a COM interface method. Only contains info immediately available from an IMethodSymbol and MethodDeclarationSyntax. + /// Returns a list of tuples of ComMethodInfo, IMethodSymbol, and Diagnostic. If ComMethodInfo is null, Diagnostic will not be null, and vice versa. /// - private sealed record ComMethodInfo( - MethodDeclarationSyntax Syntax, - string MethodName) + public static SequenceEqualImmutableArray<(ComMethodInfo? ComMethod, IMethodSymbol Symbol, Diagnostic? Diagnostic)> GetMethodsFromInterface((ComInterfaceInfo ifaceContext, INamedTypeSymbol ifaceSymbol) data, CancellationToken ct) { - /// - /// Returns a list of tuples of ComMethodInfo, IMethodSymbol, and Diagnostic. If ComMethodInfo is null, Diagnostic will not be null, and vice versa. - /// - public static SequenceEqualImmutableArray<(ComMethodInfo? ComMethod, IMethodSymbol Symbol, Diagnostic? Diagnostic)> GetMethodsFromInterface((ComInterfaceInfo ifaceContext, INamedTypeSymbol ifaceSymbol) data, CancellationToken ct) + var methods = ImmutableArray.CreateBuilder<(ComMethodInfo, IMethodSymbol, Diagnostic?)>(); + foreach (var member in data.ifaceSymbol.GetMembers()) { - var methods = ImmutableArray.CreateBuilder<(ComMethodInfo, IMethodSymbol, Diagnostic?)>(); - foreach (var member in data.ifaceSymbol.GetMembers()) + if (IsComMethodCandidate(member)) { - if (IsComMethodCandidate(member)) - { - methods.Add(CalculateMethodInfo(data.ifaceContext, (IMethodSymbol)member, ct)); - } + methods.Add(CalculateMethodInfo(data.ifaceContext, (IMethodSymbol)member, ct)); } - return methods.ToImmutable().ToSequenceEqual(); } + return methods.ToImmutable().ToSequenceEqual(); + } - private static Diagnostic? GetDiagnosticIfInvalidMethodForGeneration(MethodDeclarationSyntax comMethodDeclaringSyntax, IMethodSymbol method) + private static Diagnostic? GetDiagnosticIfInvalidMethodForGeneration(MethodDeclarationSyntax comMethodDeclaringSyntax, IMethodSymbol method) + { + // Verify the method has no generic types or defined implementation + // and is not marked static or sealed + if (comMethodDeclaringSyntax.TypeParameterList is not null + || comMethodDeclaringSyntax.Body is not null + || comMethodDeclaringSyntax.Modifiers.Any(SyntaxKind.SealedKeyword)) { - // Verify the method has no generic types or defined implementation - // and is not marked static or sealed - if (comMethodDeclaringSyntax.TypeParameterList is not null - || comMethodDeclaringSyntax.Body is not null - || comMethodDeclaringSyntax.Modifiers.Any(SyntaxKind.SealedKeyword)) - { - return Diagnostic.Create(GeneratorDiagnostics.InvalidAttributedMethodSignature, comMethodDeclaringSyntax.Identifier.GetLocation(), method.Name); - } - - // Verify the method does not have a ref return - if (method.ReturnsByRef || method.ReturnsByRefReadonly) - { - return Diagnostic.Create(GeneratorDiagnostics.ReturnConfigurationNotSupported, comMethodDeclaringSyntax.Identifier.GetLocation(), "ref return", method.ToDisplayString()); - } - - return null; + return Diagnostic.Create(GeneratorDiagnostics.InvalidAttributedMethodSignature, comMethodDeclaringSyntax.Identifier.GetLocation(), method.Name); } - private static bool IsComMethodCandidate(ISymbol member) + // Verify the method does not have a ref return + if (method.ReturnsByRef || method.ReturnsByRefReadonly) { - return member.Kind == SymbolKind.Method && !member.IsStatic; + return Diagnostic.Create(GeneratorDiagnostics.ReturnConfigurationNotSupported, comMethodDeclaringSyntax.Identifier.GetLocation(), "ref return", method.ToDisplayString()); } - private static (ComMethodInfo?, IMethodSymbol, Diagnostic?) CalculateMethodInfo(ComInterfaceInfo ifaceContext, IMethodSymbol method, CancellationToken ct) - { - ct.ThrowIfCancellationRequested(); - Debug.Assert(IsComMethodCandidate(method)); + return null; + } - // We only support methods that are defined in the same partial interface definition as the - // [GeneratedComInterface] attribute. - // This restriction not only makes finding the syntax for a given method cheaper, - // but it also enables us to ensure that we can determine vtable method order easily. - Location interfaceLocation = ifaceContext.Declaration.GetLocation(); - Location? methodLocationInAttributedInterfaceDeclaration = null; - foreach (var methodLocation in method.Locations) - { - if (methodLocation.SourceTree == interfaceLocation.SourceTree - && interfaceLocation.SourceSpan.Contains(methodLocation.SourceSpan)) - { - methodLocationInAttributedInterfaceDeclaration = methodLocation; - break; - } - } - // TODO: this should cause a diagnostic - if (methodLocationInAttributedInterfaceDeclaration is null) - { - return (null, method, Diagnostic.Create(GeneratorDiagnostics.CannotAnalyzeMethodPattern, method.Locations.FirstOrDefault(), method.ToDisplayString())); - } + private static bool IsComMethodCandidate(ISymbol member) + { + return member.Kind == SymbolKind.Method && !member.IsStatic; + } + private static (ComMethodInfo?, IMethodSymbol, Diagnostic?) CalculateMethodInfo(ComInterfaceInfo ifaceContext, IMethodSymbol method, CancellationToken ct) + { + ct.ThrowIfCancellationRequested(); + Debug.Assert(IsComMethodCandidate(method)); - // Find the matching declaration syntax - MethodDeclarationSyntax? comMethodDeclaringSyntax = null; - foreach (var declaringSyntaxReference in method.DeclaringSyntaxReferences) - { - var declaringSyntax = declaringSyntaxReference.GetSyntax(ct); - Debug.Assert(declaringSyntax.IsKind(SyntaxKind.MethodDeclaration)); - if (declaringSyntax.GetLocation().SourceSpan.Contains(methodLocationInAttributedInterfaceDeclaration.SourceSpan)) - { - comMethodDeclaringSyntax = (MethodDeclarationSyntax)declaringSyntax; - break; - } - } - if (comMethodDeclaringSyntax is null) + // We only support methods that are defined in the same partial interface definition as the + // [GeneratedComInterface] attribute. + // This restriction not only makes finding the syntax for a given method cheaper, + // but it also enables us to ensure that we can determine vtable method order easily. + Location interfaceLocation = ifaceContext.Declaration.GetLocation(); + Location? methodLocationInAttributedInterfaceDeclaration = null; + foreach (var methodLocation in method.Locations) + { + if (methodLocation.SourceTree == interfaceLocation.SourceTree + && interfaceLocation.SourceSpan.Contains(methodLocation.SourceSpan)) { - return (null, method, Diagnostic.Create(GeneratorDiagnostics.CannotAnalyzeMethodPattern, method.Locations.FirstOrDefault(), method.ToDisplayString())); + methodLocationInAttributedInterfaceDeclaration = methodLocation; + break; } + } + + if (methodLocationInAttributedInterfaceDeclaration is null) + { + return (null, method, Diagnostic.Create(GeneratorDiagnostics.MethodNotDeclaredInAttributedInterface, method.Locations.FirstOrDefault(), method.ToDisplayString())); + } - var diag = GetDiagnosticIfInvalidMethodForGeneration(comMethodDeclaringSyntax, method); - if (diag is not null) + + // Find the matching declaration syntax + MethodDeclarationSyntax? comMethodDeclaringSyntax = null; + foreach (var declaringSyntaxReference in method.DeclaringSyntaxReferences) + { + var declaringSyntax = declaringSyntaxReference.GetSyntax(ct); + Debug.Assert(declaringSyntax.IsKind(SyntaxKind.MethodDeclaration)); + if (declaringSyntax.GetLocation().SourceSpan.Contains(methodLocationInAttributedInterfaceDeclaration.SourceSpan)) { - return (null, method, diag); + comMethodDeclaringSyntax = (MethodDeclarationSyntax)declaringSyntax; + break; } - var comMethodInfo = new ComMethodInfo(comMethodDeclaringSyntax, method.Name); - return (comMethodInfo, method, null); } + if (comMethodDeclaringSyntax is null) + { + return (null, method, Diagnostic.Create(GeneratorDiagnostics.CannotAnalyzeMethodPattern, method.Locations.FirstOrDefault(), method.ToDisplayString())); + } + + var diag = GetDiagnosticIfInvalidMethodForGeneration(comMethodDeclaringSyntax, method); + if (diag is not null) + { + return (null, method, diag); + } + var comMethodInfo = new ComMethodInfo(comMethodDeclaringSyntax, method.Name); + return (comMethodInfo, method, null); } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/GeneratedStubCodeContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/GeneratedStubCodeContext.cs new file mode 100644 index 00000000000000..6f0966ed776505 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/GeneratedStubCodeContext.cs @@ -0,0 +1,14 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace Microsoft.Interop +{ + internal sealed record GeneratedStubCodeContext( + ManagedTypeInfo OriginalDefiningType, + ContainingSyntaxContext ContainingSyntaxContext, + SyntaxEquivalentNode Stub, + SequenceEqualImmutableArray Diagnostics) : GeneratedMethodContextBase(OriginalDefiningType, Diagnostics); +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalMethodStubGenerationContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalMethodStubGenerationContext.cs index ff3d6f32d5f23a..fae4de0e1167c5 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalMethodStubGenerationContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalMethodStubGenerationContext.cs @@ -4,7 +4,6 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; using System; -using System.Collections.Immutable; namespace Microsoft.Interop { diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs index fb0dd80ca7f1a1..826aa4e01112e5 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs @@ -2,9 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Collections.Generic; using System.Collections.Immutable; -using System.Text; using Microsoft.CodeAnalysis; namespace Microsoft.Interop diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/InlinedTypes.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/InlinedTypes.cs deleted file mode 100644 index cb55e36b32cb95..00000000000000 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/InlinedTypes.cs +++ /dev/null @@ -1,124 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; - -namespace Microsoft.Interop -{ - internal static class InlinedTypes - { - /// - /// Returns the ClassDeclarationSyntax for: - /// - /// public sealed unsafe class ComWrappersUnwrapper : IUnmanagedObjectUnwrapper - /// { - /// public static object GetObjectForUnmanagedWrapper(void* ptr) - /// { - /// return ComWrappers.ComInterfaceDispatch.GetInstance((ComWrappers.ComInterfaceDispatch*)ptr); - /// } - /// } - /// - /// - public static ClassDeclarationSyntax ComWrappersUnwrapper { get; } = GetComWrappersUnwrapper(); - - public static ClassDeclarationSyntax GetComWrappersUnwrapper() - { - return ClassDeclaration("ComWrappersUnwrapper") - .AddModifiers(Token(SyntaxKind.SealedKeyword), - Token(SyntaxKind.UnsafeKeyword), - Token(SyntaxKind.StaticKeyword), - Token(SyntaxKind.FileKeyword)) - .AddMembers( - MethodDeclaration( - PredefinedType(Token(SyntaxKind.ObjectKeyword)), - Identifier("GetComObjectForUnmanagedWrapper")) - .AddModifiers(Token(SyntaxKind.PublicKeyword), - Token(SyntaxKind.StaticKeyword)) - .AddParameterListParameters( - Parameter(Identifier("ptr")) - .WithType(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword))))) - .WithBody(body: Body())); - - static BlockSyntax Body() - { - var invocation = InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("ComWrappers"), - IdentifierName("ComInterfaceDispatch")), - GenericName( - Identifier("GetInstance"), - TypeArgumentList( - SeparatedList( - new[] { PredefinedType(Token(SyntaxKind.ObjectKeyword)) }))))) - .AddArgumentListArguments( - Argument( - null, - Token(SyntaxKind.None), - CastExpression( - PointerType( - QualifiedName( - IdentifierName("ComWrappers"), - IdentifierName("ComInterfaceDispatch"))), - IdentifierName("ptr")))); - - return Block(ReturnStatement(invocation)); - } - } - - /// - /// - /// file static class UnmanagedObjectUnwrapper - /// { - /// public static object GetObjectForUnmanagedWrapper(void* ptr) where T : IUnmanagedObjectUnwrapper - /// { - /// return T.GetObjectForUnmanagedWrapper(ptr); - /// } - /// } - /// - /// - public static ClassDeclarationSyntax UnmanagedObjectUnwrapper { get; } = GetUnmanagedObjectUnwrapper(); - - private static ClassDeclarationSyntax GetUnmanagedObjectUnwrapper() - { - const string tUnwrapper = "TUnwrapper"; - return ClassDeclaration("UnmanagedObjectUnwrapper") - .AddModifiers(Token(SyntaxKind.FileKeyword), - Token(SyntaxKind.StaticKeyword)) - .AddMembers( - MethodDeclaration( - PredefinedType(Token(SyntaxKind.ObjectKeyword)), - Identifier("GetObjectForUnmanagedWrapper")) - .AddModifiers(Token(SyntaxKind.PublicKeyword), - Token(SyntaxKind.StaticKeyword)) - .AddTypeParameterListParameters( - TypeParameter(Identifier(tUnwrapper))) - .AddParameterListParameters( - Parameter(Identifier("ptr")) - .WithType(PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword))))) - .AddConstraintClauses(TypeParameterConstraintClause(IdentifierName(tUnwrapper)) - .AddConstraints(TypeConstraint(ParseTypeName(TypeNames.IUnmanagedObjectUnwrapper)))) - .WithBody(body: Body())); - - static BlockSyntax Body() - { - var invocation = InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("T"), - IdentifierName("GetObjectForUnmanagedWrapper"))) - .AddArgumentListArguments( - Argument( - null, - Token(SyntaxKind.None), - IdentifierName("ptr"))); - - return Block(ReturnStatement(invocation)); - } - - } - } -} diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ManagedToNativeVTableMethodGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ManagedToNativeVTableMethodGenerator.cs index d6c6564cb2afe3..c0fe6e63540025 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ManagedToNativeVTableMethodGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ManagedToNativeVTableMethodGenerator.cs @@ -4,7 +4,6 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; -using System.Diagnostics; using System.Linq; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Marshallers/ComInterfaceDispatchMarshallerFactory.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Marshallers/ComInterfaceDispatchMarshallerFactory.cs index dc8dbe2de5fda2..f44a442d1d5040 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Marshallers/ComInterfaceDispatchMarshallerFactory.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/Marshallers/ComInterfaceDispatchMarshallerFactory.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; -using System.Text; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/SkippedStubContext.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/SkippedStubContext.cs new file mode 100644 index 00000000000000..aa9ef047ba0490 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/SkippedStubContext.cs @@ -0,0 +1,10 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Interop +{ + internal sealed record SkippedStubContext(ManagedTypeInfo OriginalDefiningType) : GeneratedMethodContextBase(OriginalDefiningType, new(ImmutableArray.Empty)); +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/UnmanagedToManagedStubGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/UnmanagedToManagedStubGenerator.cs index 4d62b5d098c96a..7acc249f96eff7 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/UnmanagedToManagedStubGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/UnmanagedToManagedStubGenerator.cs @@ -4,8 +4,6 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; -using System.Diagnostics; -using System.Linq; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/UnreachableException.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/UnreachableException.cs index 203657801ccc8a..9a3052235cad92 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/UnreachableException.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/UnreachableException.cs @@ -2,8 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Collections.Generic; -using System.Text; namespace Microsoft.Interop { diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VirtualMethodPointerStubGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VirtualMethodPointerStubGenerator.cs index 35f610f66e3ef5..e813c8fb906e6e 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VirtualMethodPointerStubGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VirtualMethodPointerStubGenerator.cs @@ -5,7 +5,6 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; -using System.Text; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.CSharp; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs index 1d161ff1d5f2be..43600ca57e37e5 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGenerator.cs @@ -2,12 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Collections.Generic; using System.Collections.Immutable; using System.Diagnostics; using System.Linq; using System.Threading; -using System.Xml.Linq; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGeneratorHelpers.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGeneratorHelpers.cs index a421cd26e7e15f..204c63fc7ba9ca 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGeneratorHelpers.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/VtableIndexStubGeneratorHelpers.cs @@ -2,9 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Collections.Generic; using System.Linq; -using System.Text; using Microsoft.CodeAnalysis; namespace Microsoft.Interop