Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up some code after the interface inheritance work #86347

Merged
merged 1 commit into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,23 @@

namespace Microsoft.Interop
{
public sealed partial class ComInterfaceGenerator
/// <summary>
/// Represents an interface and all of the methods that need to be generated for it (methods declared on the interface and methods inherited from base interfaces).
/// </summary>
internal sealed record ComInterfaceAndMethodsContext(ComInterfaceContext Interface, SequenceEqualImmutableArray<ComMethodContext> Methods)
{
// Change Calc all methods to return an ordered list of all the methods and the data in comInterfaceandMethodsContext
// Have a step that runs CalculateMethodStub on each of them.
// Call GroupMethodsByInterfaceForGeneration

/// <summary>
/// Represents an interface and all of the methods that need to be generated for it (methods declared on the interface and methods inherited from base interfaces).
/// COM methods that are declared on the attributed interface declaration.
/// </summary>
private sealed record ComInterfaceAndMethodsContext(ComInterfaceContext Interface, SequenceEqualImmutableArray<ComMethodContext> Methods)
{
// Change Calc all methods to return an ordered list of all the methods and the data in comInterfaceandMethodsContext
// Have a step that runs CalculateMethodStub on each of them.
// Call GroupMethodsByInterfaceForGeneration

/// <summary>
/// COM methods that are declared on the attributed interface declaration.
/// </summary>
public IEnumerable<ComMethodContext> DeclaredMethods => Methods.Where(m => !m.IsInheritedMethod);
public IEnumerable<ComMethodContext> DeclaredMethods => Methods.Where(m => !m.IsInheritedMethod);

/// <summary>
/// COM methods that are declared on an interface the interface inherits from.
/// </summary>
public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod);
}
/// <summary>
/// COM methods that are declared on an interface the interface inherits from.
/// </summary>
public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,51 +7,48 @@

namespace Microsoft.Interop
{
public sealed partial class ComInterfaceGenerator
internal sealed record ComInterfaceContext(ComInterfaceInfo Info, ComInterfaceContext? Base)
{
private sealed record ComInterfaceContext(ComInterfaceInfo Info, ComInterfaceContext? Base)
/// <summary>
/// Takes a list of ComInterfaceInfo, and creates a list of ComInterfaceContext.
/// </summary>
public static ImmutableArray<ComInterfaceContext> GetContexts(ImmutableArray<ComInterfaceInfo> data, CancellationToken _)
{
/// <summary>
/// Takes a list of ComInterfaceInfo, and creates a list of ComInterfaceContext.
/// </summary>
public static ImmutableArray<ComInterfaceContext> GetContexts(ImmutableArray<ComInterfaceInfo> data, CancellationToken _)
Dictionary<string, ComInterfaceInfo> symbolToInterfaceInfoMap = new();
var accumulator = ImmutableArray.CreateBuilder<ComInterfaceContext>(data.Length);
foreach (var iface in data)
{
Dictionary<string, ComInterfaceInfo> symbolToInterfaceInfoMap = new();
var accumulator = ImmutableArray.CreateBuilder<ComInterfaceContext>(data.Length);
foreach (var iface in data)
symbolToInterfaceInfoMap.Add(iface.ThisInterfaceKey, iface);
}
Dictionary<string, ComInterfaceContext> symbolToContextMap = new();

foreach (var iface in data)
{
accumulator.Add(AddContext(iface));
}
return accumulator.MoveToImmutable();

ComInterfaceContext AddContext(ComInterfaceInfo iface)
{
if (symbolToContextMap.TryGetValue(iface.ThisInterfaceKey, out var cachedValue))
{
symbolToInterfaceInfoMap.Add(iface.ThisInterfaceKey, iface);
return cachedValue;
}
Dictionary<string, ComInterfaceContext> symbolToContextMap = new();

foreach (var iface in data)
if (iface.BaseInterfaceKey is null)
{
accumulator.Add(AddContext(iface));
var baselessCtx = new ComInterfaceContext(iface, null);
symbolToContextMap[iface.ThisInterfaceKey] = baselessCtx;
return baselessCtx;
}
return accumulator.MoveToImmutable();

ComInterfaceContext AddContext(ComInterfaceInfo iface)
if (!symbolToContextMap.TryGetValue(iface.BaseInterfaceKey, out var baseContext))
{
if (symbolToContextMap.TryGetValue(iface.ThisInterfaceKey, out var cachedValue))
{
return cachedValue;
}

if (iface.BaseInterfaceKey is null)
{
var baselessCtx = new ComInterfaceContext(iface, null);
symbolToContextMap[iface.ThisInterfaceKey] = baselessCtx;
return baselessCtx;
}

if (!symbolToContextMap.TryGetValue(iface.BaseInterfaceKey, out var baseContext))
{
baseContext = AddContext(symbolToInterfaceInfoMap[iface.BaseInterfaceKey]);
}
var ctx = new ComInterfaceContext(iface, baseContext);
symbolToContextMap[iface.ThisInterfaceKey] = ctx;
return ctx;
baseContext = AddContext(symbolToInterfaceInfoMap[iface.BaseInterfaceKey]);
}
var ctx = new ComInterfaceContext(iface, baseContext);
symbolToContextMap[iface.ThisInterfaceKey] = ctx;
return ctx;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Collections.Specialized;
using System.IO;
using System.Linq;
using System.Reflection;
Expand All @@ -19,14 +17,6 @@ namespace Microsoft.Interop
[Generator]
public sealed partial class ComInterfaceGenerator : IIncrementalGenerator
{
private sealed record class GeneratedStubCodeContext(
ManagedTypeInfo OriginalDefiningType,
ContainingSyntaxContext ContainingSyntaxContext,
SyntaxEquivalentNode<MethodDeclarationSyntax> Stub,
SequenceEqualImmutableArray<Diagnostic> Diagnostics) : GeneratedMethodContextBase(OriginalDefiningType, Diagnostics);

private sealed record SkippedStubContext(ManagedTypeInfo OriginalDefiningType) : GeneratedMethodContextBase(OriginalDefiningType, new(ImmutableArray<Diagnostic>.Empty));

public static class StepNames
{
public const string CalculateStubInformation = nameof(CalculateStubInformation);
Expand Down Expand Up @@ -103,11 +93,9 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
{
var ((data, symbolMap), env) = param;
return new ComMethodContext(
data.Method.OriginalDeclaringInterface,
data.TypeKeyOwner,
data.Method.MethodInfo,
data.Method.Index,
CalculateStubInformation(data.Method.MethodInfo.Syntax, symbolMap[data.Method.MethodInfo], data.Method.Index, env, data.TypeKeyOwner.Info.Type, ct));
data.Method,
data.OwningInterface,
CalculateStubInformation(data.Method.MethodInfo.Syntax, symbolMap[data.Method.MethodInfo], data.Method.Index, env, data.OwningInterface.Info.Type, ct));
}).WithTrackingName(StepNames.CalculateStubInformation);

var interfaceAndMethodsContexts = comMethodContexts
Expand All @@ -117,7 +105,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)

// Generate the code for the managed-to-unmanaged stubs and the diagnostics from code-generation.
context.RegisterDiagnostics(interfaceAndMethodsContexts
.SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.GetManagedToUnmanagedStub().Diagnostics)));
.SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.ManagedToUnmanagedStub.Diagnostics)));
var managedToNativeInterfaceImplementations = interfaceAndMethodsContexts
.Select(GenerateImplementationInterface)
.WithTrackingName(StepNames.GenerateManagedToNativeInterfaceImplementation)
Expand All @@ -126,7 +114,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)

// Generate the code for the unmanaged-to-managed stubs and the diagnostics from code-generation.
context.RegisterDiagnostics(interfaceAndMethodsContexts
.SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.GetNativeToManagedStub().Diagnostics)));
.SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.UnmanagedToManagedStub.Diagnostics)));
var nativeToManagedVtableMethods = interfaceAndMethodsContexts
.Select(GenerateImplementationVTableMethods)
.WithTrackingName(StepNames.GenerateNativeToManagedVTableMethods)
Expand All @@ -145,11 +133,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.Select((data, ct) =>
{
var context = data.Interface.Info;
var methods = data.ShadowingMethods.Select(m => (MemberDeclarationSyntax)m.GenerateShadow());
var methods = data.ShadowingMethods.Select(m => m.Shadow);
var typeDecl = TypeDeclaration(context.ContainingSyntax.TypeKind, context.ContainingSyntax.Identifier)
.WithModifiers(context.ContainingSyntax.Modifiers)
.WithTypeParameterList(context.ContainingSyntax.TypeParameters)
.WithMembers(List(methods));
.WithMembers(List<MemberDeclarationSyntax>(methods));
return data.Interface.Info.TypeDefinitionContext.WrapMemberInContainingSyntaxWithUnsafeModifier(typeDecl);
})
.WithTrackingName(StepNames.GenerateShadowingMethods)
Expand Down Expand Up @@ -210,33 +198,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
});
}

private static string GenerateMarkerInterfaceSource(ComInterfaceInfo iface) => $$"""
file unsafe class InterfaceInformation : global::System.Runtime.InteropServices.Marshalling.IIUnknownInterfaceType
{
public static global::System.Guid Iid => new(new global::System.ReadOnlySpan<byte>(new byte[] { {{string.Join(",", iface.InterfaceId.ToByteArray())}} }));

private static void** m_vtable;

public static void** ManagedVirtualMethodTable
{
get
{
if (m_vtable == null)
{
nint* vtable = (nint*)global::System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof({{iface.Type.FullTypeName}}), sizeof(nint) * 3);
global::System.Runtime.InteropServices.ComWrappers.GetIUnknownImpl(out vtable[0], out vtable[1], out vtable[2]);
m_vtable = (void**)vtable;
}
return m_vtable;
}
}
}

[global::System.Runtime.InteropServices.DynamicInterfaceCastableImplementation]
file interface InterfaceImplementation : {{iface.Type.FullTypeName}}
{}
""";

private static readonly AttributeSyntax s_iUnknownDerivedAttributeTemplate =
Attribute(
GenericName(TypeNames.IUnknownDerivedAttribute)
Expand All @@ -251,8 +212,7 @@ private static MemberDeclarationSyntax GenerateIUnknownDerivedAttributeApplicati
.WithTypeParameterList(context.ContainingSyntax.TypeParameters)
.AddAttributeLists(AttributeList(SingletonSeparatedList(s_iUnknownDerivedAttributeTemplate))));

// Todo: extract info needed from the IMethodSymbol into MethodInfo and only pass a MethodInfo here
private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ManagedTypeInfo typeKeyOwner, CancellationToken ct)
private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ManagedTypeInfo owningInterface, CancellationToken ct)
{
ct.ThrowIfCancellationRequested();
INamedTypeSymbol? lcidConversionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.LCIDConversionAttribute);
Expand Down Expand Up @@ -365,7 +325,7 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
new ComExceptionMarshalling(),
ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.ManagedToUnmanaged),
ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.UnmanagedToManaged),
typeKeyOwner,
owningInterface,
declaringType,
generatorDiagnostics.Diagnostics.ToSequenceEqualImmutableArray(),
ComInterfaceDispatchMarshallingInfo.Instance);
Expand Down Expand Up @@ -412,31 +372,32 @@ private static ImmutableArray<ComInterfaceAndMethodsContext> GroupComContextsFor
private static InterfaceDeclarationSyntax GenerateImplementationInterface(ComInterfaceAndMethodsContext interfaceGroup, CancellationToken _)
{
var definingType = interfaceGroup.Interface.Info.Type;
var shadowImplementations = interfaceGroup.ShadowingMethods.Select(m => (Method: m, ManagedToUnmanagedStub: m.GetManagedToUnmanagedStub()))
var shadowImplementations = interfaceGroup.ShadowingMethods.Select(m => (Method: m, ManagedToUnmanagedStub: m.ManagedToUnmanagedStub))
.Where(p => p.ManagedToUnmanagedStub is GeneratedStubCodeContext)
.Select(ctx => ((GeneratedStubCodeContext)ctx.ManagedToUnmanagedStub).Stub.Node
.WithExplicitInterfaceSpecifier(
ExplicitInterfaceSpecifier(ParseName(definingType.FullTypeName))));
var inheritedStubs = interfaceGroup.ShadowingMethods.Select(m => m.GenerateUnreachableExceptionStub());
var inheritedStubs = interfaceGroup.ShadowingMethods.Select(m => m.UnreachableExceptionStub);
return ImplementationInterfaceTemplate
.AddBaseListTypes(SimpleBaseType(definingType.Syntax))
.WithMembers(
List<MemberDeclarationSyntax>(
interfaceGroup.DeclaredMethods
.Select(m => m.GetManagedToUnmanagedStub())
.Select(m => m.ManagedToUnmanagedStub)
.OfType<GeneratedStubCodeContext>()
.Select(ctx => ctx.Stub.Node)
.Concat(shadowImplementations)
.Concat(inheritedStubs)))
.AddAttributeLists(AttributeList(SingletonSeparatedList(Attribute(ParseName(TypeNames.System_Runtime_InteropServices_DynamicInterfaceCastableImplementationAttribute)))));
}

private static InterfaceDeclarationSyntax GenerateImplementationVTableMethods(ComInterfaceAndMethodsContext comInterfaceAndMethods, CancellationToken _)
{
return ImplementationInterfaceTemplate
.WithMembers(
List<MemberDeclarationSyntax>(
comInterfaceAndMethods.DeclaredMethods
.Select(m => m.GetNativeToManagedStub())
.Select(m => m.UnmanagedToManagedStub)
.OfType<GeneratedStubCodeContext>()
.Select(context => context.Stub.Node)));
}
Expand All @@ -447,6 +408,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTableMethods(Co

private static readonly MethodDeclarationSyntax CreateManagedVirtualFunctionTableMethodTemplate = MethodDeclaration(VoidStarStarSyntax, CreateManagedVirtualFunctionTableMethodName)
.AddModifiers(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.StaticKeyword));

private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterfaceAndMethodsContext interfaceMethods, CancellationToken _)
{
const string vtableLocalName = "vtable";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;

namespace Microsoft.Interop
Expand Down
Loading