
Type piracy breaks (dev::AbstractDevice)(d::DataLoader) #2592

Open
piever opened this issue Feb 21, 2025 · 3 comments

@piever
Contributor

piever commented Feb 21, 2025

Sorry for the blunt title; I didn't know how best to describe it.

So, here's an MWE:

julia> using MLUtils, MLDataDevices

julia> dev, data = cpu_device(), DataLoader(rand(10, 10));

julia> @which dev(data)
(D::MLDataDevices.AbstractDevice)(x)
     @ MLDataDevices ~/.julia/packages/MLDataDevices/uhCbD/src/public.jl:366

julia> dev(data)
DeviceIterator{CPUDevice, DataLoader{BatchView{Matrix{Float64}, ObsView{Matrix{Float64}, Vector{Int64}}, Val{nothing}}, Bool, :serial, Val{nothing}, Matrix{Float64}, Random.TaskLocalRNG}}(CPUDevice(), DataLoader(::Matrix{Float64}))

julia> using Flux

julia> @which dev(data)
(device::MLDataDevices.AbstractDevice)(d::DataLoader)
     @ Flux ~/.julia/packages/Flux/3711C/src/devices.jl:5

julia> dev(data)
ERROR: MethodError: no method matching DataLoader(::MLUtils.MappedData{…}, ::Int64, ::Bool, ::Bool, ::Bool, ::Bool, ::Val{…}, ::Random.TaskLocalRNG)
The type `DataLoader` exists, but no method is defined for this combination of argument types when trying to construct it.

Closest candidates are:
  DataLoader(::Any; buffer, parallel, shuffle, batchsize, partial, collate, rng)
   @ MLUtils ~/.julia/packages/MLUtils/EDvou/src/dataloader.jl:151

Stacktrace:
 [1] (::CPUDevice)(d::DataLoader{BatchView{Matrix{…}, ObsView{…}, Val{…}}, Bool, :serial, Val{nothing}, Matrix{Float64}, Random.TaskLocalRNG})
   @ Flux ~/.julia/packages/Flux/3711C/src/devices.jl:6
 [2] top-level scope
   @ REPL[9]:1
Some type information was truncated. Use `show(err)` to see complete types.

On Julia 1.11.3, in a clean environment:

(jl_2Vzl72) pkg> st
Status `/tmp/jl_2Vzl72/Project.toml`
  [587475ba] Flux v0.16.3
  [7e8f7934] MLDataDevices v1.6.10
  [f1d291b0] MLUtils v0.4.7

I suspect the following code is problematic: https://github.com/FluxML/Flux.jl/blob/master/src/devices.jl. In particular, the type piracy shouldn't be needed, as the fallback device(dataloader) seems to already do the right thing (see also this Slack thread).

@ToucheSir
Member

The situation is a little complicated, because DeviceIterator does not do the same thing as moving a DataLoader to a particular device. The former is only an iterator, while the latter path (which is what Flux had prior to moving to MLDataDevices) supports random access and everything else in the DataLoader API.

I think the fix here would be to move (device::AbstractDevice)(d::DataLoader) into an extension of MLUtils.jl. We could even make it an Adapt extension instead of an MLDataDevices one if that makes sense. It would also require some fixups to account for recent changes to MLUtils itself (the true reason for the error). CC @CarloLucibello for his thoughts on this.
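For context, a package-extension version of that method might look roughly like the sketch below. This is hypothetical, not the actual fix: the module name `FluxMLUtilsExt` and the DataLoader field names (`data`, `batchsize`, `shuffle`, `partial`, `rng`) are assumptions. The key point is to rebuild the DataLoader through its public keyword constructor rather than the positional one, since the positional constructor is what broke in the report above when MLUtils changed its internal fields.

```julia
# ext/FluxMLUtilsExt.jl -- hypothetical sketch only; module name and
# DataLoader field names are assumptions, not the actual fix.
module FluxMLUtilsExt

using Flux, MLUtils
using MLDataDevices: AbstractDevice

# Rebuild the DataLoader via its documented keyword constructor, so this
# method does not depend on MLUtils internals.
function (device::AbstractDevice)(d::MLUtils.DataLoader)
    MLUtils.DataLoader(device(d.data);
        batchsize = d.batchsize,
        shuffle   = d.shuffle,
        partial   = d.partial,
        rng       = d.rng)
end

end # module
```

Whether this lives in a Flux extension, an MLUtils extension, or gets deleted entirely (as discussed below) is the open question; the sketch only shows the shape such a method could take.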

@CarloLucibello
Member

CarloLucibello commented Feb 22, 2025

Basically, a DataLoader is just an iterable (it has no random access; DataLoader(rand(10), batchsize=2)[1] throws an error).

I think that DeviceIterator handles things nicely, by mapping each batch to the device while also trying to free memory during the iterations.
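To illustrate the idea (this is a toy sketch, not the real MLDataDevices implementation): a device iterator just wraps any iterable and applies the device to each batch lazily, as it is produced. The `ToyDeviceIterator` name and the use of a plain callable as the "device" are illustrative assumptions; the real DeviceIterator additionally tries to free the previous batch's memory on each step.

```julia
# Minimal sketch of a lazy device-mapping iterator, assuming `dev` is any
# callable that moves a batch to a device (here just a plain function).
struct ToyDeviceIterator{D,I}
    dev::D
    iter::I
end

function Base.iterate(di::ToyDeviceIterator, state...)
    result = iterate(di.iter, state...)
    result === nothing && return nothing
    batch, newstate = result
    # The real implementation would also eagerly free the previous
    # batch's device memory at this point.
    return di.dev(batch), newstate
end

Base.length(di::ToyDeviceIterator) = length(di.iter)

# Usage: a trivial "device" that just materializes each column.
it = ToyDeviceIterator(collect, eachcol(rand(3, 4)))
first(it)  # a length-3 Vector{Float64}, produced lazily
```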

We should just remove

function (device::MLDataDevices.AbstractDevice)(d::MLUtils.DataLoader)

which is also undocumented.

I would also remove the specialization gpu(d::DataLoader) here:

function gpu(d::MLUtils.DataLoader)

If removed, gpu(d) would create a DeviceIterator, as I think it should.
I wouldn't mark it as a breaking change, because we would still basically conform to the current docstring:

"""
    gpu(data::DataLoader)
    cpu(data::DataLoader)

Transforms a given `DataLoader` to apply `gpu` or `cpu` to each batch of data,
when iterated over. (If no GPU is available, this does nothing.)
"""

@ToucheSir
Member

You're right, I was mixing up the actual DataLoader struct with some other type (probably BatchView).
