Skip to content

Commit

Permalink
Improve code location and replace shared ptr aliases (#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
reble authored Mar 3, 2023
1 parent 62d6b15 commit 66d1b6b
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 87 deletions.
10 changes: 6 additions & 4 deletions sycl/include/sycl/detail/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,18 +104,20 @@ struct code_location {

#ifndef DISABLE_SYCL_INSTRUMENTATION_METADATA
#define _CODELOCONLYPARAM(a) \
const detail::code_location a = detail::code_location::current()
const ::sycl::detail::code_location a = \
::sycl::detail::code_location::current()
#define _CODELOCPARAM(a) \
, const detail::code_location a = detail::code_location::current()
#define _CODELOCPARAMDEF(a) , const detail::code_location a
, const ::sycl::detail::code_location a = \
::sycl::detail::code_location::current()
#define _CODELOCPARAMDEF(a) , const ::sycl::detail::code_location a

#define _CODELOCARG(a)
#define _CODELOCFW(a) , a
#else
#define _CODELOCONLYPARAM(a)
#define _CODELOCPARAM(a)

#define _CODELOCARG(a) const detail::code_location a = {}
#define _CODELOCARG(a) const ::sycl::detail::code_location a = {}
#define _CODELOCFW(a)
#endif

Expand Down
17 changes: 8 additions & 9 deletions sycl/include/sycl/ext/oneapi/experimental/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ namespace detail {
struct node_impl;
struct graph_impl;

using node_ptr = std::shared_ptr<node_impl>;
using graph_ptr = std::shared_ptr<graph_impl>;
} // namespace detail

enum class graph_state {
Expand All @@ -39,16 +37,16 @@ enum class graph_state {

class __SYCL_EXPORT node {
private:
node(detail::node_ptr Impl) : impl(Impl) {}
node(const std::shared_ptr<detail::node_impl> &Impl) : impl(Impl) {}

template <class Obj>
friend decltype(Obj::impl)
sycl::detail::getSyclObjImpl(const Obj &SyclObject);
template <class T>
friend T sycl::detail::createSyclObjFromImpl(decltype(T::impl) ImplObj);

detail::node_ptr impl;
detail::graph_ptr MGraph;
std::shared_ptr<detail::node_impl> impl;
std::shared_ptr<detail::graph_impl> MGraph;
};

template <graph_state State = graph_state::modifiable>
Expand Down Expand Up @@ -106,7 +104,7 @@ class __SYCL_EXPORT command_graph {
bool end_recording(const std::vector<queue> &recordingQueues);

private:
command_graph(detail::graph_ptr Impl) : impl(Impl) {}
command_graph(const std::shared_ptr<detail::graph_impl> &Impl) : impl(Impl) {}

// Template-less implementation of add()
node add_impl(std::function<void(handler &)> cgf,
Expand All @@ -118,14 +116,15 @@ class __SYCL_EXPORT command_graph {
template <class T>
friend T sycl::detail::createSyclObjFromImpl(decltype(T::impl) ImplObj);

detail::graph_ptr impl;
std::shared_ptr<detail::graph_impl> impl;
};

template <> class __SYCL_EXPORT command_graph<graph_state::executable> {
public:
command_graph() = delete;

command_graph(detail::graph_ptr g, const sycl::context &ctx)
command_graph(const std::shared_ptr<detail::graph_impl> &g,
const sycl::context &ctx)
: MTag(rand()), MCtx(ctx), impl(g) {}

private:
Expand All @@ -135,7 +134,7 @@ template <> class __SYCL_EXPORT command_graph<graph_state::executable> {

int MTag;
const sycl::context &MCtx;
detail::graph_ptr impl;
std::shared_ptr<detail::graph_impl> impl;
};
} // namespace experimental
} // namespace oneapi
Expand Down
30 changes: 13 additions & 17 deletions sycl/include/sycl/queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include <sycl/property_list.hpp>
#include <sycl/stl.hpp>


// Explicitly request format macros
#ifndef __STDC_FORMAT_MACROS
#define __STDC_FORMAT_MACROS 1
Expand Down Expand Up @@ -1075,9 +1074,9 @@ class __SYCL_EXPORT queue {
/// \return an event representing the graph execution operation.
event ext_oneapi_graph(ext::oneapi::experimental::command_graph<
ext::oneapi::experimental::graph_state::executable>
Graph) {
const detail::code_location CodeLoc = {};
return submit([&](handler &CGH) { CGH.ext_oneapi_graph(Graph); }, CodeLoc);
Graph _CODELOCPARAM(&CodeLoc)) {
return submit(
[&](handler &CGH) { CGH.ext_oneapi_graph(Graph); } _CODELOCFW(CodeLoc));
}

/// Shortcut for executing a graph of commands.
Expand All @@ -1090,11 +1089,10 @@ class __SYCL_EXPORT queue {
ext::oneapi::experimental::graph_state::executable>
Graph,
event DepEvent _CODELOCPARAM(&CodeLoc)) {
return submit(
[&](handler &CGH) {
CGH.depends_on(DepEvent);
CGH.ext_oneapi_graph(Graph);
} _CODELOCFW(CodeLoc));
return submit([&](handler &CGH) {
CGH.depends_on(DepEvent);
CGH.ext_oneapi_graph(Graph);
} _CODELOCFW(CodeLoc));
}

/// Shortcut for executing a graph of commands.
Expand All @@ -1106,14 +1104,12 @@ class __SYCL_EXPORT queue {
event ext_oneapi_graph(ext::oneapi::experimental::command_graph<
ext::oneapi::experimental::graph_state::executable>
Graph,
const std::vector<event> &DepEvents) {
const detail::code_location CodeLoc = {};
return submit(
[&](handler &CGH) {
CGH.depends_on(DepEvents);
CGH.ext_oneapi_graph(Graph);
},
CodeLoc);
const std::vector<event> &DepEvents
_CODELOCPARAM(&CodeLoc)) {
return submit([&](handler &CGH) {
CGH.depends_on(DepEvents);
CGH.ext_oneapi_graph(Graph);
} _CODELOCFW(CodeLoc));
}

/// Returns whether the queue is in order or OoO
Expand Down
53 changes: 28 additions & 25 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,12 @@
namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {

namespace detail {
struct queue_impl;
using queue_ptr = std::shared_ptr<queue_impl>;
} // namespace detail

namespace ext {
namespace oneapi {
namespace experimental {
namespace detail {

void graph_impl::exec(const sycl::detail::queue_ptr &q) {
void graph_impl::exec(const std::shared_ptr<sycl::detail::queue_impl> &q) {
if (MSchedule.empty()) {
for (auto n : MRoots) {
n->topology_sort(MSchedule);
Expand All @@ -33,7 +28,8 @@ void graph_impl::exec(const sycl::detail::queue_ptr &q) {
n->exec(q);
}

void graph_impl::exec_and_wait(const sycl::detail::queue_ptr &q) {
void graph_impl::exec_and_wait(
const std::shared_ptr<sycl::detail::queue_impl> &q) {
bool isSubGraph = q->getIsGraphSubmitting();
if (!isSubGraph) {
q->setIsGraphSubmitting(true);
Expand All @@ -48,14 +44,14 @@ void graph_impl::exec_and_wait(const sycl::detail::queue_ptr &q) {
}
}

void graph_impl::add_root(node_ptr n) {
void graph_impl::add_root(const std::shared_ptr<node_impl> &n) {
MRoots.insert(n);
for (auto n : MSchedule)
n->MScheduled = false;
MSchedule.clear();
}

void graph_impl::remove_root(node_ptr n) {
void graph_impl::remove_root(const std::shared_ptr<node_impl> &n) {
MRoots.erase(n);
for (auto n : MSchedule)
n->MScheduled = false;
Expand All @@ -74,8 +70,10 @@ void graph_impl::remove_root(node_ptr n) {
//
// @returns True if a dependency was added in this node or any of its
// successors.
bool check_for_arg(const sycl::detail::ArgDesc &arg, node_ptr currentNode,
std::set<node_ptr> &deps, bool dereferencePtr = false) {
bool check_for_arg(const sycl::detail::ArgDesc &arg,
const std::shared_ptr<node_impl> &currentNode,
std::set<std::shared_ptr<node_impl>> &deps,
bool dereferencePtr = false) {
bool successorAddedDep = false;
for (auto &successor : currentNode->MSuccessors) {
successorAddedDep |= check_for_arg(arg, successor, deps, dereferencePtr);
Expand All @@ -90,14 +88,16 @@ bool check_for_arg(const sycl::detail::ArgDesc &arg, node_ptr currentNode,
}

template <typename T>
node_ptr graph_impl::add(graph_ptr impl, T cgf,
const std::vector<sycl::detail::ArgDesc> &args,
const std::vector<node_ptr> &dep) {
node_ptr nodeImpl = std::make_shared<node_impl>(impl, cgf, args);
std::shared_ptr<node_impl>
graph_impl::add(const std::shared_ptr<graph_impl> &impl, T cgf,
const std::vector<sycl::detail::ArgDesc> &args,
const std::vector<std::shared_ptr<node_impl>> &dep) {
std::shared_ptr<node_impl> nodeImpl =
std::make_shared<node_impl>(impl, cgf, args);
// Copy deps so we can modify them
auto deps = dep;
// A unique set of dependencies obtained by checking kernel arguments
std::set<node_ptr> uniqueDeps;
std::set<std::shared_ptr<node_impl>> uniqueDeps;
for (auto &arg : args) {
if (arg.MType != sycl::detail::kernel_param_kind_t::kind_pointer) {
continue;
Expand Down Expand Up @@ -133,13 +133,13 @@ bool graph_impl::clear_queues() {
return anyQueuesCleared;
}

void node_impl::exec(sycl::detail::queue_ptr q) {
void node_impl::exec(const std::shared_ptr<sycl::detail::queue_impl> &q
_CODELOCPARAMDEF(&CodeLoc)) {
std::vector<sycl::event> deps;
for (auto i : MPredecessors)
deps.push_back(i->get_event());

const sycl::detail::code_location CodeLoc;
MEvent = q->submit(wrapper{MBody, deps}, q, CodeLoc);
MEvent = q->submit(wrapper{MBody, deps}, q _CODELOCFW(CodeLoc));
}
} // namespace detail

Expand All @@ -151,23 +151,26 @@ command_graph<graph_state::modifiable>::command_graph(
template <>
node command_graph<graph_state::modifiable>::add_impl(
std::function<void(handler &)> cgf, const std::vector<node> &dep) {
std::vector<detail::node_ptr> depImpls;
std::vector<std::shared_ptr<detail::node_impl>> depImpls;
for (auto &d : dep) {
depImpls.push_back(sycl::detail::getSyclObjImpl(d));
}

auto nodeImpl = impl->add(impl, cgf, {}, depImpls);
std::shared_ptr<detail::node_impl> nodeImpl =
impl->add(impl, cgf, {}, depImpls);
return sycl::detail::createSyclObjFromImpl<node>(nodeImpl);
}

template <>
void command_graph<graph_state::modifiable>::make_edge(node sender,
node receiver) {
auto sender_impl = sycl::detail::getSyclObjImpl(sender);
auto receiver_impl = sycl::detail::getSyclObjImpl(receiver);
std::shared_ptr<detail::node_impl> senderImpl =
sycl::detail::getSyclObjImpl(sender);
std::shared_ptr<detail::node_impl> receiverImpl =
sycl::detail::getSyclObjImpl(receiver);

sender_impl->register_successor(receiver_impl); // register successor
impl->remove_root(receiver_impl); // remove receiver from root node list
senderImpl->register_successor(receiverImpl); // register successor
impl->remove_root(receiverImpl); // remove receiver from root node list
}

template <>
Expand Down
Loading

0 comments on commit 66d1b6b

Please sign in to comment.