Skip to content

Commit

Permalink
Merge pull request #1383 from JuliaGPU/tb/unify_retry_reclaim
Browse files Browse the repository at this point in the history
Memory pool improvements
  • Loading branch information
maleadt authored Feb 17, 2022
2 parents 0d0c316 + ea43c67 commit 2656e15
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 129 deletions.
2 changes: 2 additions & 0 deletions lib/cudadrv/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ mutable struct CuContext
valid::Bool

# Construct — or reuse — the unique CuContext object for a given driver handle,
# so that two contexts with the same handle compare identical.
function new_unique(handle)
    # XXX: taking a lock makes it dangerous to call this function from
    #      finalizers (finalizers must not block). Can we do without it?
    Base.@lock context_lock get!(valid_contexts, handle) do
        new(handle, true)
    end
end
Expand Down
20 changes: 12 additions & 8 deletions lib/cudadrv/memory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,14 @@ Base.unsafe_convert(T::Type{<:Union{Ptr,CuPtr,CuArrayPtr}}, buf::AbstractBuffer)
A buffer of device memory residing on the GPU.
"""
struct DeviceBuffer <: AbstractBuffer
    ctx::CuContext     # context that owns the allocation; used to free it later
    ptr::CuPtr{Cvoid}  # raw device pointer
    bytesize::Int      # size of the allocation in bytes

    # whether the buffer was allocated asynchronously; affects how it is freed
    # (NOTE(review): presumably via the stream-ordered allocator — confirm in `alloc`)
    async::Bool
end

DeviceBuffer() = DeviceBuffer(CU_NULL, 0, false)
"""
    DeviceBuffer()

Return an empty device buffer: a null device pointer of zero size, associated
with the current context and marked as non-asynchronous.
"""
function DeviceBuffer()
    return DeviceBuffer(context(), CU_NULL, 0, false)
end

# The raw device pointer backing the buffer.
Base.pointer(buf::DeviceBuffer) = buf.ptr
# The size of the allocation in bytes.
Base.sizeof(buf::DeviceBuffer) = buf.bytesize
Expand Down Expand Up @@ -85,7 +86,7 @@ function alloc(::Type{DeviceBuffer}, bytesize::Integer;
CUDA.cuMemAlloc_v2(ptr_ref, bytesize)
end

return DeviceBuffer(reinterpret(CuPtr{Cvoid}, ptr_ref[]), bytesize, async)
return DeviceBuffer(context(), reinterpret(CuPtr{Cvoid}, ptr_ref[]), bytesize, async)
end

function free(buf::DeviceBuffer; stream::Union{Nothing,CuStream}=nothing)
Expand All @@ -109,11 +110,12 @@ end
A buffer of pinned memory on the CPU, possibly accessible on the GPU.
"""
struct HostBuffer <: AbstractBuffer
    ctx::CuContext   # context the memory was pinned/registered with
    ptr::Ptr{Cvoid}  # host (CPU) pointer to the pinned memory
    bytesize::Int    # size of the allocation in bytes
end

HostBuffer() = HostBuffer(C_NULL, 0)
"""
    HostBuffer()

Return an empty host buffer: a null host pointer of zero size, associated with
the current context.
"""
function HostBuffer()
    return HostBuffer(context(), C_NULL, 0)
end

# The raw host pointer backing the buffer.
Base.pointer(buf::HostBuffer) = buf.ptr
# The size of the allocation in bytes.
Base.sizeof(buf::HostBuffer) = buf.bytesize
Expand Down Expand Up @@ -157,7 +159,7 @@ function alloc(::Type{HostBuffer}, bytesize::Integer, flags=0)
ptr_ref = Ref{Ptr{Cvoid}}()
CUDA.cuMemHostAlloc(ptr_ref, bytesize, flags)

return HostBuffer(ptr_ref[], bytesize)
return HostBuffer(context(), ptr_ref[], bytesize)
end


Expand All @@ -179,7 +181,7 @@ function register(::Type{HostBuffer}, ptr::Ptr, bytesize::Integer, flags=0)

CUDA.cuMemHostRegister_v2(ptr, bytesize, flags)

return HostBuffer(ptr, bytesize)
return HostBuffer(context(), ptr, bytesize)
end

"""
Expand Down Expand Up @@ -208,11 +210,12 @@ end
A managed buffer that is accessible on both the CPU and GPU.
"""
struct UnifiedBuffer <: AbstractBuffer
    ctx::CuContext     # context that owns the allocation; used to free it later
    ptr::CuPtr{Cvoid}  # device-visible pointer to the managed memory
    bytesize::Int      # size of the allocation in bytes
end

UnifiedBuffer() = UnifiedBuffer(CU_NULL, 0)
"""
    UnifiedBuffer()

Return an empty unified buffer: a null pointer of zero size, associated with
the current context.
"""
function UnifiedBuffer()
    return UnifiedBuffer(context(), CU_NULL, 0)
end

# The raw (device-visible) pointer backing the buffer.
Base.pointer(buf::UnifiedBuffer) = buf.ptr
# The size of the allocation in bytes.
Base.sizeof(buf::UnifiedBuffer) = buf.bytesize
Expand Down Expand Up @@ -241,7 +244,7 @@ function alloc(::Type{UnifiedBuffer}, bytesize::Integer,
ptr_ref = Ref{CuPtr{Cvoid}}()
CUDA.cuMemAllocManaged(ptr_ref, bytesize, flags)

return UnifiedBuffer(ptr_ref[], bytesize)
return UnifiedBuffer(context(), ptr_ref[], bytesize)
end


Expand Down Expand Up @@ -281,6 +284,7 @@ end
## array buffer

mutable struct ArrayBuffer{T,N} <: AbstractBuffer
    ctx::CuContext      # context that owns the allocation; used to free it later
    ptr::CuArrayPtr{T}  # CUDA array handle, reinterpreted as an element-typed pointer
    dims::Dims{N}       # dimensions of the array
end
Expand Down Expand Up @@ -342,7 +346,7 @@ function alloc(::Type{<:ArrayBuffer{T}}, dims::Dims{N}) where {T,N}
CUDA.cuArray3DCreate_v2(handle_ref, allocateArray_ref)
ptr = reinterpret(CuArrayPtr{T}, handle_ref[])

return ArrayBuffer{T,N}(ptr, dims)
return ArrayBuffer{T,N}(context(), ptr, dims)
end

function free(buf::ArrayBuffer)
Expand Down
2 changes: 1 addition & 1 deletion lib/cudadrv/module/global.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ struct CuGlobal{T}
if nbytes_ref[] != sizeof(T)
throw(ArgumentError("size of global '$name' does not match type parameter type $T"))
end
buf = Mem.DeviceBuffer(ptr_ref[], nbytes_ref[], false)
buf = Mem.DeviceBuffer(context(), ptr_ref[], nbytes_ref[], false)

return new{T}(buf)
end
Expand Down
10 changes: 6 additions & 4 deletions lib/cufft/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@ abstract type CuFFTPlan{T<:cufftNumber, K, inplace} <: Plan{T} end
Base.convert(::Type{cufftHandle}, p::CuFFTPlan) = p.handle

function CUDA.unsafe_free!(plan::CuFFTPlan, stream::CuStream=stream())
context!(plan.ctx; skip_destroyed=true) do
cufftDestroy(plan)
end
cufftDestroy(plan)
unsafe_free!(plan.workarea, stream)
end

unsafe_finalize!(plan::CuFFTPlan) = unsafe_free!(plan, default_stream())
# Finalizer entry point for FFT plans: release the plan's resources on the
# default stream, after activating the context the plan belongs to.
# `skip_destroyed=true` tolerates the context having already been torn down
# (e.g. at process exit) — NOTE(review): presumably freeing is then a no-op;
# confirm against `context!`'s implementation.
function unsafe_finalize!(plan::CuFFTPlan)
    context!(plan.ctx; skip_destroyed=true) do
        unsafe_free!(plan, default_stream())
    end
end

mutable struct cCuFFTPlan{T<:cufftNumber,K,inplace,N} <: CuFFTPlan{T,K,inplace}
handle::cufftHandle
Expand Down
28 changes: 14 additions & 14 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,15 @@ export CuArray, CuVector, CuMatrix, CuVecOrMat, cu, is_unified
struct ArrayStorage{B}
buffer::B

ctx::CuContext

# the refcount also encodes the state of the array:
# < 0: unmanaged
# = 0: freed
# > 0: referenced
refcount::Threads.Atomic{Int}
end

ArrayStorage(buf::B, ctx, state::Int) where {B} =
ArrayStorage{B}(buf, ctx, Threads.Atomic{Int}(state))
ArrayStorage(buf::B, state::Int) where {B} =
ArrayStorage{B}(buf, Threads.Atomic{Int}(state))


## array type
Expand All @@ -42,7 +40,7 @@ mutable struct CuArray{T,N,B} <: AbstractGPUArray{T,N}
maxsize
end
buf = alloc(B, bufsize)
storage = ArrayStorage(buf, context(), 1)
storage = ArrayStorage(buf, 1)
obj = new{T,N,B}(storage, maxsize, 0, dims)
finalizer(unsafe_finalize!, obj)
end
Expand Down Expand Up @@ -77,7 +75,7 @@ function unsafe_free!(xs::CuArray, stream::CuStream=stream())

refcount = Threads.atomic_add!(xs.storage.refcount, -1)
if refcount == 1
context!(xs.storage.ctx; skip_destroyed=true) do
context!(context(xs); skip_destroyed=true) do
free(xs.storage.buffer; stream)
end
end
Expand All @@ -99,6 +97,8 @@ function unsafe_finalize!(xs::CuArray)
# streams involved, or by refcounting uses and decrementing that refcount after the
# operation using `cuLaunchHostFunc`. See CUDA.jl#778 and CUDA.jl#780 for details.
unsafe_free!(xs, default_stream())
# NOTE: we don't switch contexts here, but in unsafe_free!, as arrays are refcounted
# and we may not have to free the memory yet.
end


Expand Down Expand Up @@ -196,20 +196,20 @@ function Base.unsafe_wrap(::Union{Type{CuArray},Type{CuArray{T}},Type{CuArray{T,
buf = try
typ = memory_type(ptr)
if is_managed(ptr)
Mem.UnifiedBuffer(ptr, sz)
Mem.UnifiedBuffer(ctx, ptr, sz)
elseif typ == CU_MEMORYTYPE_DEVICE
# TODO: can we identify whether this pointer was allocated asynchronously?
Mem.DeviceBuffer(ptr, sz, false)
Mem.DeviceBuffer(ctx, ptr, sz, false)
elseif typ == CU_MEMORYTYPE_HOST
Mem.HostBuffer(host_pointer(ptr), sz)
Mem.HostBuffer(ctx, host_pointer(ptr), sz)
else
error("Unknown memory type; please file an issue.")
end
catch err
error("Could not identify the buffer type; are you passing a valid CUDA pointer to unsafe_wrap?")
end

storage = ArrayStorage(buf, ctx, own ? 1 : -1)
storage = ArrayStorage(buf, own ? 1 : -1)
CuArray{T, length(dims)}(storage, dims)
end

Expand All @@ -232,12 +232,12 @@ Base.sizeof(x::CuArray) = Base.elsize(x) * length(x)

function context(A::CuArray)
A.storage === nothing && throw(UndefRefError())
return A.storage.ctx
return A.storage.buffer.ctx
end

function device(A::CuArray)
A.storage === nothing && throw(UndefRefError())
return device(A.storage.ctx)
return device(A.storage.buffer.ctx)
end


Expand Down Expand Up @@ -826,14 +826,14 @@ function Base.resize!(A::CuVector{T}, n::Integer) where T
maxsize
end

new_storage = context!(A.storage.ctx) do
new_storage = context!(context(A)) do
buf = alloc(typeof(A.storage.buffer), bufsize)
ptr = convert(CuPtr{T}, buf)
m = min(length(A), n)
if m > 0
unsafe_copyto!(ptr, pointer(A), m)
end
ArrayStorage(buf, A.storage.ctx, 1)
ArrayStorage(buf, 1)
end

unsafe_free!(A)
Expand Down
Loading

0 comments on commit 2656e15

Please sign in to comment.