diff --git a/src/coreclr/vm/ceeload.cpp b/src/coreclr/vm/ceeload.cpp index 5e3b27daf6d456..547db8d05971fe 100644 --- a/src/coreclr/vm/ceeload.cpp +++ b/src/coreclr/vm/ceeload.cpp @@ -4659,6 +4659,122 @@ PTR_VOID ReflectionModule::GetRvaField(RVA field) // virtual // VASigCookies // =========================================================================== +static bool TypeSignatureContainsGenericVariables(SigParser& sp); +static bool MethodSignatureContainsGenericVariables(SigParser& sp); + +static bool TypeSignatureContainsGenericVariables(SigParser& sp) +{ + STANDARD_VM_CONTRACT; + + CorElementType et = ELEMENT_TYPE_END; + IfFailThrow(sp.GetElemType(&et)); + + if (CorIsPrimitiveType(et)) + return false; + + switch (et) + { + case ELEMENT_TYPE_OBJECT: + case ELEMENT_TYPE_STRING: + case ELEMENT_TYPE_TYPEDBYREF: + return false; + + case ELEMENT_TYPE_BYREF: + case ELEMENT_TYPE_PTR: + case ELEMENT_TYPE_SZARRAY: + return TypeSignatureContainsGenericVariables(sp); + + case ELEMENT_TYPE_VALUETYPE: + case ELEMENT_TYPE_CLASS: + IfFailThrow(sp.GetToken(NULL)); // Skip RID + return false; + + case ELEMENT_TYPE_FNPTR: + return MethodSignatureContainsGenericVariables(sp); + + case ELEMENT_TYPE_ARRAY: + { + if (TypeSignatureContainsGenericVariables(sp)) + return true; + + uint32_t rank; + IfFailThrow(sp.GetData(&rank)); // Get rank + if (rank) + { + uint32_t nsizes; + IfFailThrow(sp.GetData(&nsizes)); // Get # of sizes + while (nsizes--) + { + IfFailThrow(sp.GetData(NULL)); // Skip size + } + + uint32_t nlbounds; + IfFailThrow(sp.GetData(&nlbounds)); // Get # of lower bounds + while (nlbounds--) + { + IfFailThrow(sp.GetData(NULL)); // Skip lower bounds + } + } + } + return false; + + case ELEMENT_TYPE_GENERICINST: + { + if (TypeSignatureContainsGenericVariables(sp)) + return true; + + uint32_t argCnt; + IfFailThrow(sp.GetData(&argCnt)); // Get number of parameters + while (argCnt--) + { + if (TypeSignatureContainsGenericVariables(sp)) + return true; + } + } + return false; + + case ELEMENT_TYPE_INTERNAL: + IfFailThrow(sp.GetPointer(NULL)); + return false; + + case ELEMENT_TYPE_VAR: + case ELEMENT_TYPE_MVAR: + return true; + + default: + // Return conservative answer for unhandled elements + _ASSERTE(!"Unexpected element type."); + return true; + } +} + +static bool MethodSignatureContainsGenericVariables(SigParser& sp) +{ + STANDARD_VM_CONTRACT; + + uint32_t callConv = 0; + IfFailThrow(sp.GetCallingConvInfo(&callConv)); + + if (callConv & IMAGE_CEE_CS_CALLCONV_GENERIC) + { + // Generic signatures should never show up here, return conservative answer. + _ASSERTE(!"Unexpected generic signature."); + return true; + } + + uint32_t numArgs = 0; + IfFailThrow(sp.GetData(&numArgs)); + + // iterate over the return type and parameters + for (uint32_t i = 0; i <= numArgs; i++) + { + if (TypeSignatureContainsGenericVariables(sp)) + return true; + } + + return false; +} + //========================================================================== // Enregisters a VASig. //========================================================================== @@ -4667,15 +4783,39 @@ VASigCookie *Module::GetVASigCookie(Signature vaSignature, const SigTypeContext* CONTRACT(VASigCookie*) { INSTANCE_CHECK; - THROWS; - GC_TRIGGERS; - MODE_ANY; + STANDARD_VM_CHECK; POSTCONDITION(CheckPointer(RETVAL)); INJECT_FAULT(COMPlusThrowOM()); } CONTRACT_END; - Module* pLoaderModule = ClassLoader::ComputeLoaderModuleWorker(this, mdTokenNil, typeContext->m_classInst, typeContext->m_methodInst); + SigTypeContext emptyContext; + + Module* pLoaderModule = this; + if (!typeContext->IsEmpty()) + { + // Strip the generic context if it is not actually used by the signature. It is nececessary for both: + // - Performance: allow more sharing of vasig cookies + // - Functionality: built-in runtime marshalling is disallowed for generic signatures + SigParser sigParser = vaSignature.CreateSigParser(); + if (MethodSignatureContainsGenericVariables(sigParser)) + { + pLoaderModule = ClassLoader::ComputeLoaderModuleWorker(this, mdTokenNil, typeContext->m_classInst, typeContext->m_methodInst); + } + else + { + typeContext = &emptyContext; + } + } + else + { +#ifdef _DEBUG + // The method signature should not contain any generic variables if the generic context is not provided. + SigParser sigParser = vaSignature.CreateSigParser(); + _ASSERTE(!MethodSignatureContainsGenericVariables(sigParser)); +#endif + } + VASigCookie *pCookie = GetVASigCookieWorker(this, pLoaderModule, vaSignature, typeContext); RETURN pCookie; @@ -4685,9 +4825,7 @@ VASigCookie *Module::GetVASigCookieWorker(Module* pDefiningModule, Module* pLoad { CONTRACT(VASigCookie*) { - THROWS; - GC_TRIGGERS; - MODE_ANY; + STANDARD_VM_CHECK; POSTCONDITION(CheckPointer(RETVAL)); INJECT_FAULT(COMPlusThrowOM()); } diff --git a/src/tests/Interop/MarshalAPI/FunctionPointer/GenericFunctionPointer.cs b/src/tests/Interop/MarshalAPI/FunctionPointer/GenericFunctionPointer.cs index fe6d37f01a98fb..da2fc75d913883 100644 --- a/src/tests/Interop/MarshalAPI/FunctionPointer/GenericFunctionPointer.cs +++ b/src/tests/Interop/MarshalAPI/FunctionPointer/GenericFunctionPointer.cs @@ -29,6 +29,12 @@ static BlittableGeneric UnmanagedExportedFunctionBlittableGenericString( return new() { X = Convert.ToInt32(arg) }; } + [UnmanagedCallersOnly] + static unsafe void UnmanagedExportedFunctionRefInt(int* pval, float arg) + { + *pval = Convert.ToInt32(arg); + } + class GenericCaller { internal static unsafe T GenericCalli(void* fnptr, U arg) @@ -40,6 +46,11 @@ internal static unsafe BlittableGeneric WrappedGenericCalli(void* fnptr, U { return ((delegate* unmanaged>)fnptr)(arg); } + + internal static unsafe void NonGenericCalli(void* fnptr, ref int val, float arg) + { + ((delegate* unmanaged)fnptr)(ref val, arg); + } } struct BlittableGeneric @@ -81,6 +92,14 @@ public static void RunGenericFunctionPointerTest(float inVal) outVar = GenericCaller.WrappedGenericCalli((delegate* unmanaged>)&UnmanagedExportedFunctionBlittableGenericString, inVal).X; } Assert.Equal(expectedValue, outVar); + + outVar = 0; + Console.WriteLine("Testing non-GenericCalli with non-blittable argument in a generic caller"); + unsafe + { + GenericCaller.NonGenericCalli((delegate* unmanaged)&UnmanagedExportedFunctionRefInt, ref outVar, inVal); + } + Assert.Equal(expectedValue, outVar); } [ConditionalFact(nameof(CanRunInvalidGenericFunctionPointerTest))]