
initializer taking a single type argument #29

Closed · a1ix2 opened this issue Jul 12, 2024 · 1 comment · Fixed by #30

Comments

@a1ix2 commented Jul 12, 2024

I want to initialize Lux layer weights using an initializer other than ones/zeros/rand/randn while still specifying the element type, such as Float16 or Float32.

The way it's set up, Lux layers like Dense expect init_weight to be a function with signature (rng, dims...). The only customization option is something like init_weight = glorot_uniform(gain=0.5), and the __partial_apply methods are set up so that when Lux initializes, it always ends up calling (line 324 in the initializer):

```julia
glorot_uniform(gain=0.5)(rng, dims...) = glorot_uniform(rng, Float32, dims...; gain=0.5)
```
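For illustration, a hypothetical partial application that would expose the type argument might look like this (the helper name is mine, not the package's actual API):

```julia
# Hypothetical sketch (not WeightInitializers' actual API): fix the element
# type and keyword arguments up front, returning a closure with the
# (rng, dims...) signature that Lux expects from init_weight.
with_eltype(init, ::Type{T}; kwargs...) where {T} =
    (rng, dims...) -> init(rng, T, dims...; kwargs...)

# e.g. init_weight = with_eltype(glorot_uniform, Float64; gain=0.5)
```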

I might be missing something obvious, but I don't see a __partial_apply that exposes the right signature so that I could do something like init_weight = glorot_uniform(Float64). If I try, it of course complains:

```julia
julia> ps, st = Lux.setup(Xoshiro(), nn)
ERROR: MethodError: no method matching glorot_uniform(::Xoshiro, ::Type{Float64}, ::Xoshiro, ::Int64, ::Int64)
```

So for now I use an anonymous function:

```julia
init_weight = (rng, dims...) -> glorot_uniform(rng, Float64, dims...; gain=0.5)
```
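For completeness, a minimal runnable sketch of this workaround (layer sizes and seed are illustrative):

```julia
using Lux, Random
using WeightInitializers: glorot_uniform

# Closure fixing the element type and gain while exposing the
# (rng, dims...) signature Lux expects from init_weight.
init_w = (rng, dims...) -> glorot_uniform(rng, Float64, dims...; gain=0.5)

nn = Dense(4 => 3; init_weight=init_w)
ps, st = Lux.setup(Xoshiro(0), nn)
eltype(ps.weight)  # Float64
```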
@avik-pal (Member) commented

Thanks for pointing it out; this seems like a bug. I will have a look.
