Skip to content

Commit

Permalink
Don't inspect the active state during pool free.
Browse files Browse the repository at this point in the history
Avoids querying the current context, which may cause a task switch.
  • Loading branch information
maleadt committed Feb 17, 2022
1 parent 92b5fe2 commit ea43c67
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
2 changes: 2 additions & 0 deletions lib/cudadrv/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ mutable struct CuContext
valid::Bool

function new_unique(handle)
# XXX: this makes it dangerous to call this function from finalizers.
# can we do without the lock?
Base.@lock context_lock get!(valid_contexts, handle) do
new(handle, true)
end
Expand Down
23 changes: 19 additions & 4 deletions src/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,12 +308,27 @@ Releases a buffer `buf` to the memory pool.
return
end
@inline function _free(buf::Mem.DeviceBuffer; stream::Union{Nothing,CuStream})
state = active_state()
if stream_ordered(state.device)
# NOTE: this function is often called from finalizers, from which we can't switch tasks,
# so we need to take care not to call managed functions (i.e. functions that may
# initialize the CUDA context) because querying the active context using
# `current_context()` takes a lock

# verify that the caller has called `context!` already, which eagerly activates the
# context (i.e. doesn't only set it in the state, but configures the CUDA APIs).
handle_ref = Ref{CUcontext}()
cuCtxGetCurrent(handle_ref)
if buf.ctx.handle != handle_ref[]
error("Trying to free $buf from a different context than the one it was allocated from ($(handle_ref[]))")
end

dev = current_device()
if stream_ordered(dev)
# mark the pool as active
pool_mark(state.device)
pool_mark(dev)

actual_free(buf; stream=something(stream, state.stream))
# for safety, we default to the default stream and force this operation to be ordered
# against all other streams. to opt out of this, pass a specific stream instead.
actual_free(buf; stream=something(stream, default_stream()))
else
actual_free(buf)
end
Expand Down

0 comments on commit ea43c67

Please sign in to comment.