Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions sycl/source/detail/graph/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ graph_impl::graph_impl(const sycl::context &SyclContext,

graph_impl::~graph_impl() {
try {
clearQueues();
clearQueues(false /*Needs lock*/);
for (auto &MemObj : MMemObjs) {
MemObj->markNoLongerBeingUsedInGraph();
}
Expand Down Expand Up @@ -564,17 +564,21 @@ void graph_impl::removeQueue(sycl::detail::queue_impl &RecordingQueue) {
MRecordingQueues.erase(RecordingQueue.weak_from_this());
}

bool graph_impl::clearQueues() {
bool AnyQueuesCleared = false;
for (auto &Queue : MRecordingQueues) {
void graph_impl::clearQueues(bool NeedsLock) {
graph_impl::RecQueuesStorage SwappedQueues;
{
graph_impl::WriteLock Guard(MMutex, std::defer_lock);
if (NeedsLock) {
Guard.lock();
}
std::swap(MRecordingQueues, SwappedQueues);
}

for (auto &Queue : SwappedQueues) {
if (auto ValidQueue = Queue.lock(); ValidQueue) {
ValidQueue->setCommandGraph(nullptr);
AnyQueuesCleared = true;
}
}
MRecordingQueues.clear();

return AnyQueuesCleared;
}

bool graph_impl::checkForCycles() {
Expand Down Expand Up @@ -1970,8 +1974,7 @@ void modifiable_command_graph::begin_recording(
}

void modifiable_command_graph::end_recording() {
graph_impl::WriteLock Lock(impl->MMutex);
impl->clearQueues();
impl->clearQueues(true /*Needs lock*/);
}

void modifiable_command_graph::end_recording(queue &RecordingQueue) {
Expand Down
12 changes: 6 additions & 6 deletions sycl/source/detail/graph/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {

/// Remove all queues which are recording to this graph, also sets all queues
/// cleared back to the executing state.
///
/// @return True if any queues were removed.
bool clearQueues();
void clearQueues(bool NeedsLock);

/// Associate a sycl event with a node in the graph.
/// @param EventImpl Event to associate with a node in map.
Expand Down Expand Up @@ -561,10 +559,12 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
/// Device associated with this graph. All graph nodes will execute on this
/// device.
sycl::device MDevice;

using RecQueuesStorage =
std::set<std::weak_ptr<sycl::detail::queue_impl>,
std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>;
/// Unique set of queues which are currently recording to this graph.
std::set<std::weak_ptr<sycl::detail::queue_impl>,
std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
MRecordingQueues;
RecQueuesStorage MRecordingQueues;
/// Map of events to their associated recorded nodes.
std::unordered_map<std::shared_ptr<sycl::detail::event_impl>, node_impl *>
MEventsMap;
Expand Down
6 changes: 4 additions & 2 deletions sycl/source/detail/queue_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,8 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
bool CallerNeedsEvent);

void setCommandGraphUnlocked(
std::shared_ptr<ext::oneapi::experimental::detail::graph_impl> Graph) {
const std::shared_ptr<ext::oneapi::experimental::detail::graph_impl>
&Graph) {
MGraph = Graph;
MExtGraphDeps.reset();

Expand All @@ -614,7 +615,8 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
}

void setCommandGraph(
std::shared_ptr<ext::oneapi::experimental::detail::graph_impl> Graph) {
const std::shared_ptr<ext::oneapi::experimental::detail::graph_impl>
&Graph) {
std::lock_guard<std::mutex> Lock(MMutex);
setCommandGraphUnlocked(Graph);
}
Expand Down
Loading