-
Notifications
You must be signed in to change notification settings - Fork 64
/
Copy pathiterator.jl
73 lines (56 loc) · 2.46 KB
/
iterator.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""
    DeviceIterator(dev::AbstractDevice, iterator)

Create a `DeviceIterator` that iterates through the provided `iterator` via `iterate`. Upon
each iteration, the current batch is copied to the device `dev`, and the batch from the
previous iteration is marked as freeable from GPU memory (via `unsafe_free!`); this is a
no-op for a CPU device.

The conversion follows the same semantics as `dev(<item from iterator>)`.

!!! tip "Similarity to `CUDA.CuIterator`"

    The design inspiration was taken from `CUDA.CuIterator` and was generalized to work with
    other backends and more complex iterators (using `Functors`).

!!! tip "`MLUtils.DataLoader`"

    Calling `dev(::MLUtils.DataLoader)` will automatically convert the dataloader to use the
    same semantics as `DeviceIterator`. This is generally preferred over looping over the
    dataloader directly and transferring the data to the device.

## Examples

The following was run on a computer with an NVIDIA GPU.

```julia-repl
julia> using MLDataDevices, MLUtils

julia> X = rand(Float64, 3, 33);

julia> dataloader = DataLoader(X; batchsize=13, shuffle=false);

julia> for (i, x) in enumerate(dataloader)
           @show i, summary(x)
       end
(i, summary(x)) = (1, "3×13 Matrix{Float64}")
(i, summary(x)) = (2, "3×13 Matrix{Float64}")
(i, summary(x)) = (3, "3×7 Matrix{Float64}")

julia> for (i, x) in enumerate(CUDADevice()(dataloader))
           @show i, summary(x)
       end
(i, summary(x)) = (1, "3×13 CuArray{Float32, 2, CUDA.DeviceMemory}")
(i, summary(x)) = (2, "3×13 CuArray{Float32, 2, CUDA.DeviceMemory}")
(i, summary(x)) = (3, "3×7 CuArray{Float32, 2, CUDA.DeviceMemory}")
```
"""
struct DeviceIterator{D<:Function,I}
    # Callable applied to every batch pulled from `iterator`; `Base.iterate` invokes it as
    # `c.dev(batch)`.
    dev::D
    # Wrapped iterable that supplies the batches; it is advanced with plain `iterate`.
    iterator::I
end
# First step of the iteration protocol: pull the first item from the wrapped iterator,
# pass it through `c.dev`, and return it. The returned state pairs the wrapped iterator's
# state with the converted batch so that the next step can release that batch.
function Base.iterate(c::DeviceIterator)
    step = iterate(c.iterator)
    if step === nothing
        # The wrapped iterator is empty, so there is nothing to convert.
        return nothing
    end
    host_batch, inner_state = step
    moved = c.dev(host_batch)
    return (moved, (inner_state, moved))
end
# Subsequent steps of the iteration protocol. The state is the pair
# `(state of the wrapped iterator, batch returned by the previous step)`.
#
# When the wrapped iterator is exhausted this returns `nothing` *before* touching
# `prev_batch`, so the last batch handed out is not passed to `unsafe_free!` here.
function Base.iterate(c::DeviceIterator, (state, prev_batch))
    step = iterate(c.iterator, state)
    if step === nothing
        return nothing
    end
    host_batch, inner_state = step
    # Release the batch returned by the previous step before converting the new one.
    Internal.unsafe_free!(prev_batch)
    moved = c.dev(host_batch)
    return (moved, (inner_state, moved))
end
# Size-related traits and queries are forwarded to the wrapped iterator: a `DeviceIterator`
# yields exactly one converted batch per item of `c.iterator`.
Base.IteratorSize(::Type{DeviceIterator{D, I}}) where {D, I} = Base.IteratorSize(I)
Base.length(c::DeviceIterator) = length(c.iterator)
# `IteratorSize` may be forwarded as `HasShape{N}`, for which the iteration interface
# expects `size(iter, [dim])` in addition to `length`; forward both forms as well.
Base.size(c::DeviceIterator) = size(c.iterator)
Base.size(c::DeviceIterator, d::Integer) = size(c.iterator, d)
Base.axes(c::DeviceIterator) = axes(c.iterator)
# The element type is reported as unknown rather than forwarded from `I`.
# NOTE(review): presumably because the type returned by `c.dev(batch)` cannot be derived
# from the wrapped iterator's `eltype` — confirm.
Base.IteratorEltype(::Type{DeviceIterator{D, I}}) where {D, I} = Base.EltypeUnknown()