Skip to content

Commit

Permalink
Failed attempt JetBrains#1 at modifying EA and DFG
Browse files Browse the repository at this point in the history
  • Loading branch information
KitsuneAlex committed Dec 22, 2024
1 parent 93f58b9 commit 62d5d40
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 47 deletions.
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@

/*
* Copyright 2010-2023 JetBrains s.r.o. and Kotlin Programming Language contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
@file:OptIn(ExperimentalForeignApi::class)
package kotlinx.cinterop

import kotlin.native.*
import kotlin.internal.InlineOnly
import kotlin.native.internal.GCUnsafeCall
import kotlin.native.internal.Intrinsic
import kotlin.native.internal.TypedIntrinsic
import kotlin.native.internal.IntrinsicType

Expand All @@ -19,6 +19,14 @@ internal inline val pointerSize: Int
@TypedIntrinsic(IntrinsicType.INTEROP_GET_POINTER_SIZE)
internal external fun getPointerSize(): Int

@PublishedApi
@TypedIntrinsic(IntrinsicType.INTEROP_ALLOCA)
internal external fun alloca(size: Long, align: Int): NativePointed

@PublishedApi
@InlineOnly
internal inline fun <reified T : CVariable> alloca(): T = alloca(sizeOf<T>(), alignOf<T>()).reinterpret()

// TODO: do not use singleton because it leads to init-check on any access.
@PublishedApi
internal object nativeMemUtils {
Expand Down Expand Up @@ -46,8 +54,6 @@ internal object nativeMemUtils {
@TypedIntrinsic(IntrinsicType.INTEROP_READ_PRIMITIVE) external fun getVector(mem: NativePointed): Vector128
@TypedIntrinsic(IntrinsicType.INTEROP_WRITE_PRIMITIVE) external fun putVector(mem: NativePointed, value: Vector128)

@TypedIntrinsic(IntrinsicType.INTEROP_ALLOCA) external fun alloca(size: Long, align: Int): NativePointed

// TODO: optimize
fun getByteArray(source: NativePointed, dest: ByteArray, length: Int) {
val sourceArray = source.reinterpret<ByteVar>().ptr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,19 @@ internal data class EscapeAnalysisInput(
get() = irModule
}

internal val EscapeAnalysisPhase = createSimpleNamedCompilerPhase<NativeGenerationState, EscapeAnalysisInput, Map<IrElement, Lifetime>>(
internal data class EscapeAnalysisOutput(
val lifetimes: Map<IrElement, Lifetime> = emptyMap(),
val interopLifetimes: Map<IrElement, Lifetime> = emptyMap()
)

internal val EscapeAnalysisPhase = createSimpleNamedCompilerPhase<NativeGenerationState, EscapeAnalysisInput, EscapeAnalysisOutput>(
name = "EscapeAnalysis",
outputIfNotEnabled = { _, _, _, _ -> emptyMap() },
outputIfNotEnabled = { _, _, _, _ -> EscapeAnalysisOutput() },
preactions = getDefaultIrActions(),
postactions = getDefaultIrActions(),
op = { generationState, input ->
val lifetimes = mutableMapOf<IrElement, Lifetime>()
val lifetimes = HashMap<IrElement, Lifetime>()
val interopLifetimes = HashMap<IrElement, Lifetime>()
val context = generationState.context
val entryPoint = context.ir.symbols.entryPoint?.owner
val nonDevirtualizedCallSitesUnfoldFactor =
Expand All @@ -157,8 +163,8 @@ internal val EscapeAnalysisPhase = createSimpleNamedCompilerPhase<NativeGenerati
DevirtualizationUnfoldFactors.DFG_DEVIRTUALIZED_CALL,
nonDevirtualizedCallSitesUnfoldFactor
).build()
EscapeAnalysis.computeLifetimes(context, generationState, input.moduleDFG, callGraph, lifetimes)
lifetimes
EscapeAnalysis.computeLifetimes(context, generationState, input.moduleDFG, callGraph, lifetimes, interopLifetimes)
EscapeAnalysisOutput(lifetimes, interopLifetimes)
}
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -500,9 +500,9 @@ private fun PhaseEngine<NativeGenerationState>.runCodegen(module: IrModuleFragme
runPhase(CreateLLVMDeclarationsPhase, module)
runPhase(GHAPhase, module, disable = !optimize)
runPhase(RTTIPhase, RTTIInput(module, dceResult))
val lifetimes = runPhase(EscapeAnalysisPhase, EscapeAnalysisInput(module, moduleDFG), disable = !optimize)
runPhase(InteropAllocationsTransformPhase, InteropAllocationsTransformInput(module, lifetimes))
runPhase(CodegenPhase, CodegenInput(module, irBuiltIns, lifetimes))
val eaOutput = runPhase(EscapeAnalysisPhase, EscapeAnalysisInput(module, moduleDFG), disable = !optimize)
runPhase(InteropAllocationsTransformPhase, InteropAllocationsTransformInput(module, eaOutput.interopLifetimes))
runPhase(CodegenPhase, CodegenInput(module, irBuiltIns, eaOutput.lifetimes))
}

private fun PhaseEngine<NativeGenerationState>.findDependenciesToCompile(): List<IrModuleFragment> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ internal enum class IntrinsicType {
INTEROP_NARROW,
INTEROP_STATIC_C_FUNCTION,
INTEROP_FUNPTR_INVOKE,
INTEROP_ALLOCA,
// Worker
WORKER_EXECUTE,
// Atomics
Expand Down Expand Up @@ -250,6 +251,7 @@ internal class IntrinsicGenerator(private val environment: IntrinsicGeneratorEnv
IntrinsicType.INTEROP_NATIVE_PTR_TO_LONG -> emitNativePtrToLong(callSite, args)
IntrinsicType.INTEROP_NATIVE_PTR_PLUS_LONG -> emitNativePtrPlusLong(args)
IntrinsicType.INTEROP_GET_NATIVE_NULL_PTR -> emitGetNativeNullPtr()
IntrinsicType.INTEROP_ALLOCA -> reportNonLoweredIntrinsic(intrinsicType) // TODO: ...
IntrinsicType.IDENTITY -> emitIdentity(args)
IntrinsicType.THE_UNIT_INSTANCE -> theUnitInstanceRef.llvm
IntrinsicType.ATOMIC_GET_FIELD -> reportNonLoweredIntrinsic(intrinsicType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ internal class FunctionDFGBuilder(private val generationState: NativeGenerationS
private val executeImplProducerInvoke = executeImplProducerClass.simpleFunctions()
.single { it.name == OperatorNameConventions.INVOKE }
private val reinterpret = symbols.reinterpret
private val interopGetPtr = symbols.interopGetPtr
private val saveCoroutineState = symbols.saveCoroutineState
private val restoreCoroutineState = symbols.restoreCoroutineState
private val objCObjectRawValueGetter = symbols.interopObjCObjectRawValueGetter
Expand Down Expand Up @@ -651,7 +652,7 @@ internal class FunctionDFGBuilder(private val generationState: NativeGenerationS
// like a fixed-size object.
DataFlowIR.Node.AllocInstance(symbolTable.mapType(createEmptyStringSymbol.owner.returnType), value)

reinterpret -> getNode(value.extensionReceiver!!).value
reinterpret, interopGetPtr -> getNode(value.extensionReceiver!!).value

initInstanceSymbol -> error("Should've been lowered: ${value.render()}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ internal object EscapeAnalysis {
val callGraph: CallGraph,
val moduleDFG: ModuleDFG,
val lifetimes: MutableMap<IrElement, Lifetime>,
val interopLifetimes: MutableMap<IrElement, Lifetime>,
val propagateExiledToHeapObjects: Boolean
) {

Expand Down Expand Up @@ -534,10 +535,15 @@ internal object EscapeAnalysis {
++stats.totalStackAllocsCount
if (!isFilteredOut)
++stats.filteredStackAllocsCount

}

lifetimes[it] = lifetime
}

if (node is DataFlowIR.Node.Call && isInteropAllocCallee(node.callee.irFunction)) {
interopLifetimes[it] = lifetime
}
}
}
}
Expand Down Expand Up @@ -1768,13 +1774,14 @@ internal object EscapeAnalysis {
generationState: NativeGenerationState,
moduleDFG: ModuleDFG,
callGraph: CallGraph,
lifetimes: MutableMap<IrElement, Lifetime>
lifetimes: MutableMap<IrElement, Lifetime>,
interopLifetimes: MutableMap<IrElement, Lifetime>,
) {
assert(lifetimes.isEmpty())

try {
InterproceduralAnalysis(context, generationState, callGraph,
moduleDFG, lifetimes,
moduleDFG, lifetimes, interopLifetimes,
// The GC must be careful not to scan exiled objects, that have already became dead,
// as they may reference other already destroyed stack-allocated objects.
// TODO somehow tag these object, so that GC could handle them properly.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,12 @@ import org.jetbrains.kotlin.ir.declarations.IrModuleFragment
import org.jetbrains.kotlin.ir.declarations.IrSimpleFunction
import org.jetbrains.kotlin.ir.expressions.IrCall
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.impl.IrCallImpl
import org.jetbrains.kotlin.ir.util.copyValueArgumentsFrom
import org.jetbrains.kotlin.ir.util.kotlinFqName
import org.jetbrains.kotlin.ir.util.parentClassOrNull
import org.jetbrains.kotlin.ir.symbols.IrSimpleFunctionSymbol
import org.jetbrains.kotlin.ir.util.IdSignature
import org.jetbrains.kotlin.ir.visitors.IrElementTransformerVoid
import org.jetbrains.kotlin.ir.visitors.transformChildrenVoid
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.utils.findIsInstanceAnd

/**
* A simple transformation pass after escape analysis which transforms
Expand All @@ -35,22 +31,38 @@ import org.jetbrains.kotlin.utils.findIsInstanceAnd
*/

private val interopPackageName: FqName = FqName("kotlinx.cinterop")
private val nativePlacementName: FqName = ClassId(interopPackageName, Name.identifier("NativePlacement")).asSingleFqName()
private val arenaBaseName: FqName = ClassId(interopPackageName, Name.identifier("ArenaBase")).asSingleFqName()
private val allocFunctionName: Name = Name.identifier("alloc")
private val nativeMemUtilsName: FqName = ClassId(interopPackageName, Name.identifier("nativeMemUtils")).asSingleFqName()
private val allocaFunctionName: Name = Name.identifier("alloca")

private val allocaFunctionSignature: IdSignature = IdSignature.CommonSignature(
interopPackageName.asString(),
"${interopPackageName.asString()}.alloca",
null, 0L, null
)

internal fun isInteropAllocCallee(callee: IrSimpleFunction?): Boolean {
if (callee == null) return false
val signature = callee.symbol.signature ?: return false
return signature.packageFqName() == interopPackageName && allocFunctionName.asString() in callee.name.asString()
}

private class InteropAllocationsTransformer(
generationState: NativeGenerationState,
private val lifetimes: Map<IrElement, Lifetime>
) : IrElementTransformerVoid() {
private val allocaFunction = generationState.context.irModules.values
.flatMap { it.files }
.flatMap { it.declarations }
.findIsInstanceAnd<IrSimpleFunction> {
it.parentClassOrNull?.kotlinFqName == nativeMemUtilsName && it.name == allocaFunctionName
} ?: error("Could not find alloca intrinsic")
private val allocaFunction: IrSimpleFunctionSymbol by lazy {
generationState.context.symbolTable.referenceSimpleFunction(allocaFunctionSignature)
}

private fun computeLifetime(expression: IrExpression): Lifetime {
return lifetimes.getOrDefault(expression, Lifetime.GLOBAL) // TODO: move default to IRRELEVANT once globally changed
}

private fun isApplicableCallSite(callSite: IrCall): Boolean {
val lifetime = computeLifetime(callSite)
return lifetime == Lifetime.LOCAL
|| lifetime == Lifetime.STACK
|| lifetime == Lifetime.ARGUMENT
}

override fun visitElement(element: IrElement): IrElement {
element.transformChildrenVoid(this)
Expand All @@ -60,21 +72,8 @@ private class InteropAllocationsTransformer(
override fun visitCall(expression: IrCall): IrExpression {
super.visitCall(expression)
val callee = expression.symbol.owner
val parentName = callee.parentClassOrNull?.kotlinFqName ?: return expression
// We need to assume that calls may be devirtualized into direct calls to ArenaBase's alloc overload
if ((parentName != nativePlacementName && parentName != arenaBaseName) || callee.name != allocFunctionName) return expression
val lifetime = lifetimes[expression]
// We only care about things that never leave the scope or are exclusively used to pass down
if (lifetime != Lifetime.LOCAL && lifetime != Lifetime.STACK && lifetime != Lifetime.ARGUMENT) return expression
// Transform alloc -> alloca call and copy over value arguments
return IrCallImpl(
expression.startOffset,
expression.endOffset,
expression.type,
allocaFunction.symbol
).apply {
copyValueArgumentsFrom(expression, callee)
}
if (!isApplicableCallSite(expression)) return expression
return expression
}
}

Expand All @@ -91,6 +90,6 @@ internal val InteropAllocationsTransformPhase = createSimpleNamedCompilerPhase<N
preactions = getDefaultIrActions(),
postactions = getDefaultIrActions(),
op = { generationState, input ->
input.irModule.transformChildrenVoid(InteropAllocationsTransformer(generationState, input.lifetimes))
input.irModule.transform(InteropAllocationsTransformer(generationState, input.lifetimes), null)
}
)

0 comments on commit 62d5d40

Please sign in to comment.