Skip to content

Commit

Permalink
[SYCL] Record & Replay Implementation
Browse files Browse the repository at this point in the history
Implementation of Record & Replay API with tests

Co-authored-by: Ben Tracy <[email protected]>
  • Loading branch information
EwanC and Bensuo authored Feb 27, 2023
1 parent 06c588f commit d4c1ed3
Show file tree
Hide file tree
Showing 9 changed files with 603 additions and 24 deletions.
35 changes: 35 additions & 0 deletions sycl/include/sycl/ext/oneapi/experimental/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {

class handler;
class queue;
namespace ext {
namespace oneapi {
namespace experimental {
Expand Down Expand Up @@ -70,6 +71,40 @@ class __SYCL_EXPORT command_graph {
finalize(const sycl::context &syclContext,
const property_list &propList = {}) const;

/// Change the state of a queue to be recording and associate this graph with
/// it.
/// @param recordingQueue The queue to change state on and associate this
/// graph with.
/// @return True if the queue had its state changed from executing to
/// recording.
bool begin_recording(queue recordingQueue);

/// Change the state of multiple queues to be recording and associate this
/// graph with each of them.
/// @param recordingQueues The queues to change state on and associate this
/// graph with.
/// @return True if any queue had its state changed from executing to
/// recording.
bool begin_recording(const std::vector<queue> &recordingQueues);

/// Set all queues currently recording to this graph to the executing state.
/// @return True if any queue had its state changed from recording to
/// executing.
bool end_recording();

/// Set a queues currently recording to this graph to the executing state.
/// @param recordingQueue The queue to change state on.
/// @return True if the queue had its state changed from recording to
/// executing.
bool end_recording(queue recordingQueue);

/// Set multiple queues currently recording to this graph to the executing
/// state.
/// @param recordingQueue The queues to change state on.
/// @return True if any queue had its state changed from recording to
/// executing.
bool end_recording(const std::vector<queue> &recordingQueues);

private:
command_graph(detail::graph_ptr Impl) : impl(Impl) {}

Expand Down
10 changes: 10 additions & 0 deletions sycl/include/sycl/queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,16 @@ static event submitAssertCapture(queue &, event &, queue *,
#endif
} // namespace detail

namespace ext {
namespace oneapi {
namespace experimental {
// State of a queue with regards to graph recording,
// returned by info::queue::state
enum class queue_state { executing, recording };
} // namespace experimental
} // namespace oneapi
} // namespace ext

/// Encapsulates a single SYCL queue which schedules kernels on a SYCL device.
///
/// A SYCL queue can be used to submit command groups to be executed by the SYCL
Expand Down
123 changes: 119 additions & 4 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <detail/graph_impl.hpp>
#include <detail/queue_impl.hpp>
#include <sycl/queue.hpp>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
Expand Down Expand Up @@ -61,12 +62,56 @@ void graph_impl::remove_root(node_ptr n) {
MSchedule.clear();
}

// Recursive check if a graph node or its successors contains a given kernel
// argument.
//
// @param[in] arg The kernel argument to check for.
// @param[in] currentNode The current graph node being checked.
// @param[in,out] deps The unique list of dependencies which have been
// identified for this arg.
// @param[in] dereferencePtr if true arg comes direct from the handler in which
// case it will need to be deferenced to check actual value.
//
// @returns True if a dependency was added in this node of any of its
// successors.
bool check_for_arg(const sycl::detail::ArgDesc &arg, node_ptr currentNode,
std::set<node_ptr> &deps, bool dereferencePtr = false) {
bool successorAddedDep = false;
for (auto &successor : currentNode->MSuccessors) {
successorAddedDep |= check_for_arg(arg, successor, deps, dereferencePtr);
}

if (deps.find(currentNode) == deps.end() &&
currentNode->has_arg(arg, dereferencePtr) && !successorAddedDep) {
deps.insert(currentNode);
return true;
}
return successorAddedDep;
}

template <typename T>
node_ptr graph_impl::add(graph_ptr impl, T cgf,
const std::vector<sycl::detail::ArgDesc> &args,
const std::vector<node_ptr> &dep) {
node_ptr nodeImpl = std::make_shared<node_impl>(impl, cgf);
if (!dep.empty()) {
for (auto n : dep) {
node_ptr nodeImpl = std::make_shared<node_impl>(impl, cgf, args);
// Copy deps so we can modify them
auto deps = dep;
// A unique set of dependencies obtained by checking kernel arguments
std::set<node_ptr> uniqueDeps;
for (auto &arg : args) {
if (arg.MType != sycl::detail::kernel_param_kind_t::kind_pointer) {
continue;
}
// Look through the graph for nodes which share this argument
for (auto nodePtr : MRoots) {
check_for_arg(arg, nodePtr, uniqueDeps, true);
}
}

// Add any deps determined from arguments into the dependency list
deps.insert(deps.end(), uniqueDeps.begin(), uniqueDeps.end());
if (!deps.empty()) {
for (auto n : deps) {
n->register_successor(nodeImpl); // register successor
this->remove_root(nodeImpl); // remove receiver from root node
// list
Expand All @@ -77,6 +122,17 @@ node_ptr graph_impl::add(graph_ptr impl, T cgf,
return nodeImpl;
}

bool graph_impl::clear_queues() {
bool anyQueuesCleared = false;
for (auto &q : MRecordingQueues) {
q->setCommandGraph(nullptr);
anyQueuesCleared = true;
}
MRecordingQueues.clear();

return anyQueuesCleared;
}

void node_impl::exec(sycl::detail::queue_ptr q) {
std::vector<sycl::event> deps;
for (auto i : MPredecessors)
Expand All @@ -100,7 +156,7 @@ node command_graph<graph_state::modifiable>::add_impl(
depImpls.push_back(sycl::detail::getSyclObjImpl(d));
}

auto nodeImpl = impl->add(impl, cgf, depImpls);
auto nodeImpl = impl->add(impl, cgf, {}, depImpls);
return sycl::detail::createSyclObjFromImpl<node>(nodeImpl);
}

Expand All @@ -121,6 +177,65 @@ command_graph<graph_state::modifiable>::finalize(
return command_graph<graph_state::executable>{this->impl, ctx};
}

template <>
bool command_graph<graph_state::modifiable>::begin_recording(
queue recordingQueue) {
auto queueImpl = sycl::detail::getSyclObjImpl(recordingQueue);
if (queueImpl->getCommandGraph() == nullptr) {
queueImpl->setCommandGraph(impl);
impl->add_queue(queueImpl);
return true;
} else if (queueImpl->getCommandGraph() != impl) {
throw sycl::exception(make_error_code(errc::invalid),
"begin_recording called for a queue which is already "
"recording to a different graph.");
}

// Queue was already recording to this graph.
return false;
}

template <>
bool command_graph<graph_state::modifiable>::begin_recording(
const std::vector<queue> &recordingQueues) {
bool queueStateChanged = false;
for (auto &q : recordingQueues) {
queueStateChanged |= this->begin_recording(q);
}
return queueStateChanged;
}

template <> bool command_graph<graph_state::modifiable>::end_recording() {
return impl->clear_queues();
}

template <>
bool command_graph<graph_state::modifiable>::end_recording(
queue recordingQueue) {
auto queueImpl = sycl::detail::getSyclObjImpl(recordingQueue);
if (queueImpl->getCommandGraph() == impl) {
queueImpl->setCommandGraph(nullptr);
impl->remove_queue(queueImpl);
return true;
} else if (queueImpl->getCommandGraph() != nullptr) {
throw sycl::exception(make_error_code(errc::invalid),
"end_recording called for a queue which is recording "
"to a different graph.");
}

// Queue was not recording to a graph.
return false;
}
template <>
bool command_graph<graph_state::modifiable>::end_recording(
const std::vector<queue> &recordingQueues) {
bool queueStateChanged = false;
for (auto &q : recordingQueues) {
queueStateChanged |= this->end_recording(q);
}
return queueStateChanged;
}

} // namespace experimental
} // namespace oneapi
} // namespace ext
Expand Down
53 changes: 51 additions & 2 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#pragma once

#include <sycl/detail/cg_types.hpp>
#include <sycl/ext/oneapi/experimental/graph.hpp>
#include <sycl/handler.hpp>

Expand Down Expand Up @@ -53,6 +54,8 @@ struct node_impl {

std::function<void(sycl::handler &)> MBody;

std::vector<sycl::detail::ArgDesc> MArgs;

void exec(sycl::detail::queue_ptr q);

void register_successor(node_ptr n) {
Expand All @@ -65,7 +68,17 @@ struct node_impl {
sycl::event get_event(void) const { return MEvent; }

template <typename T>
node_impl(graph_ptr g, T cgf) : MScheduled(false), MGraph(g), MBody(cgf) {}
node_impl(graph_ptr g, T cgf, const std::vector<sycl::detail::ArgDesc> &args)
: MScheduled(false), MGraph(g), MBody(cgf), MArgs(args) {
for (size_t i = 0; i < MArgs.size(); i++) {
if (MArgs[i].MType == sycl::detail::kernel_param_kind_t::kind_pointer) {
// Make sure we are storing the actual USM pointer for comparison
// purposes, note we couldn't actually submit using these copies of the
// args if subsequent code expects a void**.
MArgs[i].MPtr = *(void **)(MArgs[i].MPtr);
}
}
}

// Recursively adding nodes to execution stack:
void topology_sort(std::list<node_ptr> &schedule) {
Expand All @@ -76,6 +89,20 @@ struct node_impl {
}
schedule.push_front(node_ptr(this));
}

bool has_arg(const sycl::detail::ArgDesc &arg, bool dereferencePtr = false) {
for (auto &nodeArg : MArgs) {
if (arg.MType == nodeArg.MType && arg.MSize == nodeArg.MSize) {
// Args coming directly from the handler will need to be dereferenced
// since they are actually void**
void *incomingPtr = dereferencePtr ? *(void **)arg.MPtr : arg.MPtr;
if (incomingPtr == nodeArg.MPtr) {
return true;
}
}
}
return false;
}
};

struct graph_impl {
Expand All @@ -93,9 +120,31 @@ struct graph_impl {
void remove_root(node_ptr n);

template <typename T>
node_ptr add(graph_ptr impl, T cgf, const std::vector<node_ptr> &dep = {});
node_ptr add(graph_ptr impl, T cgf,
const std::vector<sycl::detail::ArgDesc> &args,
const std::vector<node_ptr> &dep = {});

graph_impl() : MFirst(true) {}

/// Add a queue to the set of queues which are currently recording to this
/// graph.
void add_queue(sycl::detail::queue_ptr recordingQueue) {
MRecordingQueues.insert(recordingQueue);
}

/// Remove a queue from the set of queues which are currently recording to
/// this graph.
void remove_queue(sycl::detail::queue_ptr recordingQueue) {
MRecordingQueues.erase(recordingQueue);
}

/// Remove all queues which are recording to this graph, also sets all queues
/// cleared back to the executing state. \return True if any queues were
/// removed.
bool clear_queues();

private:
std::set<sycl::detail::queue_ptr> MRecordingQueues;
};

} // namespace detail
Expand Down
Loading

0 comments on commit d4c1ed3

Please sign in to comment.