Skip to content

Commit

Permalink
[ORC] Implement basic reoptimization.
Browse files Browse the repository at this point in the history
  • Loading branch information
sunho committed Dec 19, 2023
1 parent 6059562 commit 044695d
Show file tree
Hide file tree
Showing 8 changed files with 619 additions and 10 deletions.
6 changes: 3 additions & 3 deletions compiler-rt/lib/orc/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

/// This macro should be used to define tags that will be associated with
/// handlers in the JIT process, and call can be used to define tags f
#define ORC_RT_JIT_DISPATCH_TAG(X) \
extern "C" char X; \
char X = 0;
#define ORC_RT_JIT_DISPATCH_TAG(X) \
ORC_RT_INTERFACE char X; \
char X = 0;

/// Opaque struct for external symbols.
struct __orc_rt_Opaque {};
Expand Down
1 change: 1 addition & 0 deletions compiler-rt/lib/orc/elfnix_platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ using namespace __orc_rt;
using namespace __orc_rt::elfnix;

// Declare function tags for functions in the JIT process.
ORC_RT_JIT_DISPATCH_TAG(__orc_rt_reoptimize_tag)
ORC_RT_JIT_DISPATCH_TAG(__orc_rt_elfnix_get_initializers_tag)
ORC_RT_JIT_DISPATCH_TAG(__orc_rt_elfnix_get_deinitializers_tag)
ORC_RT_JIT_DISPATCH_TAG(__orc_rt_elfnix_symbol_lookup_tag)
Expand Down
179 changes: 179 additions & 0 deletions llvm/include/llvm/ExecutionEngine/Orc/ReOptimizeLayer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
//===- ReOptimizeLayer.h - Re-optimization layer interface ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Re-optimization layer interface.
//
//===----------------------------------------------------------------------===//
#ifndef LLVM_EXECUTIONENGINE_ORC_REOPTLAYER_H
#define LLVM_EXECUTIONENGINE_ORC_REOPTLAYER_H

#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/ExecutionEngine/Orc/Layer.h"
#include "llvm/ExecutionEngine/Orc/RedirectionManager.h"
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"

namespace llvm {
namespace orc {

using ReOptMaterializationUnitID = uint64_t;

class ReOptimizeLayer : public IRLayer, public ResourceManager {
public:
/// AddProfilerFunc will be called when ReOptimizeLayer emits the first
/// version of a materialization unit in order to inject profiling code and
/// reoptimization request code.
using AddProfilerFunc = unique_function<Error(
ReOptimizeLayer &Parent, ReOptMaterializationUnitID MUID,
unsigned CurVersion, ThreadSafeModule &TSM)>;

/// ReOptimizeFunc will be called when ReOptimizeLayer reoptimization of a
/// materialization unit was requested in order to reoptimize the IR module
/// based on profile data. OldRT is the ResourceTracker that tracks the old
/// function definitions. The OldRT must be kept alive until it can be
/// guaranteed that every invocation of the old function definitions has been
/// terminated.
using ReOptimizeFunc = unique_function<Error(
ReOptimizeLayer &Parent, ReOptMaterializationUnitID MUID,
unsigned CurVersion, ResourceTrackerSP OldRT, ThreadSafeModule &TSM)>;

ReOptimizeLayer(ExecutionSession &ES, IRLayer &BaseLayer,
RedirectableSymbolManager &RM)
: IRLayer(ES, BaseLayer.getManglingOptions()), ES(ES),
BaseLayer(BaseLayer), RSManager(RM), ReOptFunc(identity),
ProfilerFunc(reoptimizeIfCallFrequent) {}

void setReoptimizeFunc(ReOptimizeFunc ReOptFunc) {
this->ReOptFunc = std::move(ReOptFunc);
}

void setAddProfilerFunc(AddProfilerFunc ProfilerFunc) {
this->ProfilerFunc = std::move(ProfilerFunc);
}

/// Registers reoptimize runtime dispatch handlers to given PlatformJD. The
/// reoptimization request will not be handled if dispatch handler is not
/// registered by using this function.
Error reigsterRuntimeFunctions(JITDylib &PlatformJD);

/// Emits the given module. This should not be called by clients: it will be
/// called by the JIT when a definition added via the add method is requested.
void emit(std::unique_ptr<MaterializationResponsibility> R,
ThreadSafeModule TSM) override;

static const uint64_t CallCountThreshold = 10;

/// Basic AddProfilerFunc that reoptimizes the function when the call count
/// exceeds CallCountThreshold.
static Error reoptimizeIfCallFrequent(ReOptimizeLayer &Parent,
ReOptMaterializationUnitID MUID,
unsigned CurVersion,
ThreadSafeModule &TSM);

static Error identity(ReOptimizeLayer &Parent,
ReOptMaterializationUnitID MUID, unsigned CurVersion,
ResourceTrackerSP OldRT, ThreadSafeModule &TSM) {
return Error::success();
}

// Create IR reoptimize request fucntion call.
static void createReoptimizeCall(Module &M, Instruction &IP,
GlobalVariable *ArgBuffer);

Error handleRemoveResources(JITDylib &JD, ResourceKey K) override;
void handleTransferResources(JITDylib &JD, ResourceKey DstK,
ResourceKey SrcK) override;

private:
class ReOptMaterializationUnitState {
public:
ReOptMaterializationUnitState() = default;
ReOptMaterializationUnitState(ReOptMaterializationUnitID ID,
ThreadSafeModule TSM)
: ID(ID), TSM(std::move(TSM)) {}
ReOptMaterializationUnitState(ReOptMaterializationUnitState &&Other)
: ID(Other.ID), TSM(std::move(Other.TSM)), RT(std::move(Other.RT)),
Reoptimizing(std::move(Other.Reoptimizing)),
CurVersion(Other.CurVersion) {}

ReOptMaterializationUnitID getID() { return ID; }

const ThreadSafeModule &getThreadSafeModule() { return TSM; }

ResourceTrackerSP getResourceTracker() {
std::unique_lock<std::mutex> Lock(Mutex);
return RT;
}

void setResourceTracker(ResourceTrackerSP RT) {
std::unique_lock<std::mutex> Lock(Mutex);
this->RT = RT;
}

uint32_t getCurVersion() {
std::unique_lock<std::mutex> Lock(Mutex);
return CurVersion;
}

bool tryStartReoptimize();
void reoptimizeSucceeded();
void reoptimizeFailed();

private:
std::mutex Mutex;
ReOptMaterializationUnitID ID;
ThreadSafeModule TSM;
ResourceTrackerSP RT;
bool Reoptimizing = false;
uint32_t CurVersion = 0;
};

using SPSReoptimizeArgList =
shared::SPSArgList<ReOptMaterializationUnitID, uint32_t>;
using SendErrorFn = unique_function<void(Error)>;

Expected<SymbolMap> emitMUImplSymbols(ReOptMaterializationUnitState &MUState,
uint32_t Version, JITDylib &JD,
ThreadSafeModule TSM);

void rt_reoptimize(SendErrorFn SendResult, ReOptMaterializationUnitID MUID,
uint32_t CurVersion);

static Expected<Constant *>
createReoptimizeArgBuffer(Module &M, ReOptMaterializationUnitID MUID,
uint32_t CurVersion);

ReOptMaterializationUnitState &
createMaterializationUnitState(const ThreadSafeModule &TSM);

void
registerMaterializationUnitResource(ResourceKey Key,
ReOptMaterializationUnitState &State);

ReOptMaterializationUnitState &
getMaterializationUnitState(ReOptMaterializationUnitID MUID);

ExecutionSession &ES;
IRLayer &BaseLayer;
RedirectableSymbolManager &RSManager;

ReOptimizeFunc ReOptFunc;
AddProfilerFunc ProfilerFunc;

std::mutex Mutex;
std::map<ReOptMaterializationUnitID, ReOptMaterializationUnitState> MUStates;
DenseMap<ResourceKey, DenseSet<ReOptMaterializationUnitID>> MUResources;
ReOptMaterializationUnitID NextID = 1;
};

} // namespace orc
} // namespace llvm

#endif
1 change: 1 addition & 0 deletions llvm/lib/ExecutionEngine/Orc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ add_llvm_component_library(LLVMOrcJIT
ThreadSafeModule.cpp
RedirectionManager.cpp
JITLinkRedirectableSymbolManager.cpp
ReOptimizeLayer.cpp
ADDITIONAL_HEADER_DIRS
${LLVM_MAIN_INCLUDE_DIR}/llvm/ExecutionEngine/Orc

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,21 @@ void JITLinkRedirectableSymbolManager::emitRedirectableSymbols(
R->failMaterialization();
return;
}
dbgs() << *K << "\n";
SymbolToStubs[&TargetJD][K] = StubID;
NewSymbolDefs[K] = JumpStubs[StubID];
NewSymbolDefs[K].setFlags(V.getFlags());
Symbols.push_back(K);
AvailableStubs.pop_back();
}

if (auto Err = R->replace(absoluteSymbols(NewSymbolDefs))) {
// FIXME: when this fails we can return stubs to the pool
if (auto Err = redirectInner(TargetJD, InitialDests)) {
ES.reportError(std::move(Err));
R->failMaterialization();
return;
}

if (auto Err = redirectInner(TargetJD, InitialDests)) {
if (auto Err = R->replace(absoluteSymbols(NewSymbolDefs))) {
ES.reportError(std::move(Err));
R->failMaterialization();
return;
Expand Down Expand Up @@ -85,10 +85,8 @@ Error JITLinkRedirectableSymbolManager::redirectInner(
StubHandle StubID = SymbolToStubs[&TargetJD].at(K);
PtrWrites.push_back({StubPointers[StubID].getAddress(), V.getAddress()});
}
if (auto Err = ES.getExecutorProcessControl().getMemoryAccess().writePointers(
PtrWrites))
return Err;
return Error::success();
return ES.getExecutorProcessControl().getMemoryAccess().writePointers(
PtrWrites);
}

Error JITLinkRedirectableSymbolManager::grow(unsigned Need) {
Expand All @@ -113,6 +111,7 @@ Error JITLinkRedirectableSymbolManager::grow(unsigned Need) {
auto &StubsSection =
G->createSection(JumpStubTableName, MemProt::Exec | MemProt::Read);

// FIXME: We can batch the stubs into one block and use address to access them
for (size_t I = OldSize; I < NewSize; I++) {
auto Pointer = AnonymousPtrCreator(*G, PointerSection, nullptr, 0);
if (auto Err = Pointer.takeError())
Expand Down
Loading

0 comments on commit 044695d

Please sign in to comment.