
Segfault when computing Enzyme gradients for backprop through time on Flux custom RNN #2302

Open
m-laprise opened this issue Feb 9, 2025 · 1 comment


@m-laprise

I am running into a Segmentation fault (core dumped) crash when trying to compute gradients with Enzyme for a custom RNN model in Flux. The forward pass runs without issues; the crash seems to occur while compiling a specific gradient computation.

Sorry the MWE is a bit long: the RNN is stateful, and the segfault only occurs when I add the option to let the model run for multiple time steps before generating the answer (timemovement! in the MWE).

I can rewrite the code differently to avoid the segfault, so this is solvable on my end (and possibly caused by a mistaken approach in my code, though Mooncake still provided gradients for this version), but I am posting it here in case the segfault itself needs to be fixed. A sketch of the rewrite is included after the MWE below.

I was able to reproduce it on my Apple M1 machine and on a Linux HPC cluster; the output below is from Linux. Tested with Flux 0.16.3 and Enzyme 0.13.30.

Version info:

Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 112 × Intel(R) Xeon(R) Platinum 8480CL
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, sapphirerapids)
Threads: 1 default, 0 interactive, 1 GC (on 112 virtual cores)
Environment:
  JULIA_DEBUG = all

Crash output:

[2327665] signal (11.1): Segmentation fault
in expression starting at none:0
jl_object_id__cold at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/builtins.c:455
type_hash at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/jltypes.c:1584
typekey_hash at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/jltypes.c:1614
jl_precompute_memoized_dt at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/jltypes.c:1694
inst_datatype_inner at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/jltypes.c:2141
jl_inst_arg_tuple_type at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/jltypes.c:2243
arg_type_tuple at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:2232 [inlined]
jl_lookup_generic_ at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:3021 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:3073
print_response at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:281
do_respond at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:911
jfptr_do_respond_91941.1 at /usr/licensed/julia/1.10.5/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_apply at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
jl_f__call_latest at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/builtins.c:812
#invokelatest#2 at ./essentials.jl:892 [inlined]
invokelatest at ./essentials.jl:889 [inlined]
run_interface at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/LineEdit.jl:2656
jfptr_run_interface_90695.1 at /usr/licensed/julia/1.10.5/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:3077
run_frontend at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:1312
#62 at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:386
jfptr_YY.62_91844.1 at /usr/licensed/julia/1.10.5/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_apply at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
start_task at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/task.c:1238
Allocations: 136137670 (Pool: 135967501; Big: 170169); GC: 253
Segmentation fault (core dumped)

MWE:

using Flux
using Enzyme
import NNlib
using LinearAlgebra
using Statistics: mean  # used in the loss below

#====CREATE CUSTOM RNN====#

# CREATE STATEFUL, MATRIX-VALUED RNN LAYER
mutable struct MatrixRnnCell{A<:Matrix{Float32}}
    Whh::A # Recurrent weights
    bh::A  # Recurrent bias
    h::A   # Current state
    const init::A # Store initial state h for reset
end

function matrnn_constructor(n::Int, k::Int)::MatrixRnnCell
    Whh = randn(Float32, k, k) / sqrt(Float32(k))
    bh = randn(Float32, k, n) / sqrt(Float32(k))
    h = randn(Float32, k, n) * 0.01f0
    return MatrixRnnCell(Whh, bh, h, h)
end

function(m::MatrixRnnCell)(state::Matrix{Float32}, 
                           I::AbstractArray{Float32, 2}; 
                           selfreset::Bool=false)::Matrix{Float32}
    if selfreset
        h = m.init
    else
        h = state
    end
    h_new = NNlib.tanh_fast.(m.Whh * h .+ m.bh .+ I) 
    m.h = h_new
    return h_new 
end
state(m::MatrixRnnCell) = m.h
reset!(m::MatrixRnnCell) = (m.h = m.init)

# CREATE SECOND LAYER
struct WeightedMeanLayer{V<:Vector{Float32}}
    weight::V
end
function WeightedMeanLayer(num::Int;
                           init = ones)
    weight_init = init(Float32, num)
    WeightedMeanLayer(weight_init)
end
(a::WeightedMeanLayer)(X::Array{Float32}) = X' * a.weight

# CHAIN RNN AND SECOND LAYER
struct MatrixRNN{M<:MatrixRnnCell, D<:WeightedMeanLayer}
    rnn::M
    dec::D
end
state(m::MatrixRNN) = state(m.rnn)
reset!(m::MatrixRNN) = reset!(m.rnn)
(m::MatrixRNN)(x::AbstractArray{Float32,2}; 
               selfreset::Bool = false)::Vector{Float32} = (m.dec ∘ m.rnn)(
                state(m), x; selfreset = selfreset)
            
Flux.@layer MatrixRnnCell trainable=(Whh, bh)
Flux.@layer WeightedMeanLayer
Flux.@layer :expand MatrixRNN trainable=(rnn, dec)
Flux.@non_differentiable reset!(m::MatrixRnnCell)
Flux.@non_differentiable reset!(m::MatrixRNN)
EnzymeRules.inactive(::typeof(reset!), args...) = nothing

#====DEFINE LOSS FUNCTIONS====#
# Helper function for prediction loops
# To avoid the segfault, comment out the timemovement! function here, and the two calls to it below.
function timemovement!(m, x, turns)::Nothing
    for _ in 1:turns
        m(x; selfreset = false)
    end
end
function populatepreds!(preds, m, xs::Array{Float32, 3}, turns)::Nothing
    for i in axes(xs, 3)
        reset!(m)
        example = @view xs[:,:,i]
        timemovement!(m, example, turns)
        preds[:,i] .= m(example; selfreset = false)
    end
end

# Predict - Single datum
function modelpredict(m, 
                      x::Matrix{Float32}, 
                      turns::Int)::Matrix{Float32}
    if m.rnn.h != m.rnn.init
        reset!(m)
    end
    timemovement!(m, x, turns)
    preds = m(x; selfreset = false)
    return reshape(preds, :, 1)
end

# Predict - Many data points in an array
function modelpredict(m, 
                      xs::Array{Float32, 3}, 
                      turns::Int)::Matrix{Float32}
    output_size = size(xs, 2)
    nb_examples = size(xs, 3)
    preds = Array{Float32}(undef, output_size, nb_examples)
    populatepreds!(preds, m, xs, turns)
    return preds
end

# Wrapper for prediction and a trivial loss for MWE
function trainingloss(m, xs, ys, turns)
    ys_hat = modelpredict(m, xs, turns)
    return mean(abs2, vec(ys) .- vec(ys_hat))
end

#====CREATE TRIVIAL DATA====#
const K::Int = 400
const N::Int = 64
const DATASETSIZE::Int = 64
dataY = randn(Float32, N, N, DATASETSIZE) 
dataX = randn(Float32, K, N*N, DATASETSIZE) 

#====INITIALIZE MODEL====#
activemodel = MatrixRNN(
    matrnn_constructor(N^2, K), 
    WeightedMeanLayer(K)) 
# Test loss function, entire dataset 
using BenchmarkTools
@btime trainingloss($activemodel, $dataX, $dataY, 0)

#====GRADIENT COMPUTATION====#
# Training loss, one data point, no time step -- this works.
loss, grads = Flux.withgradient(trainingloss, 
                                Duplicated(activemodel), 
                                dataX[:,:,1], 
                                dataY[:,:,1], 0)

# Training loss, 2 (or more) data points, no time step -- segfault.
loss, grads = Flux.withgradient(trainingloss, 
                                Duplicated(activemodel), 
                                dataX[:,:,1:2], 
                                dataY[:,:,1:2], 0)

# Training loss, 32 data points, no time step.
loss, grads = Flux.withgradient(trainingloss, 
                                Duplicated(activemodel), 
                                dataX[:,:,1:32], 
                                dataY[:,:,1:32], 0) 
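
For reference, here is a rough sketch of the kind of rewrite I alluded to above that avoids the segfault on my end: it threads the hidden state through the warm-up loop explicitly instead of mutating m.h at every step. The helper names (rnn_step, modelpredict_functional) are illustrative and this is not a drop-in replacement for the MWE above:

# Sketch only: functional warm-up loop that avoids repeated mutation of m.h.
function rnn_step(m::MatrixRNN, h::Matrix{Float32}, x::AbstractMatrix{Float32})::Matrix{Float32}
    return NNlib.tanh_fast.(m.rnn.Whh * h .+ m.rnn.bh .+ x)
end

function modelpredict_functional(m::MatrixRNN, x::Matrix{Float32}, turns::Int)::Matrix{Float32}
    h = m.rnn.init
    for _ in 1:(turns + 1)   # `turns` warm-up steps plus the final step
        h = rnn_step(m, h, x)
    end
    preds = m.dec(h)         # weighted mean over the hidden dimension
    return reshape(preds, :, 1)
end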
@wsmoses
Member

wsmoses commented Feb 11, 2025

Will take a look shortly, though if you're doing a full neural network, I'd recommend also using Reactant.jl on the outside (which will optimize it and automatically rewrite it to run on CPU/GPU/TPU/etc.).

https://github.com/EnzymeAD/Reactant.jl
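
Roughly, the usage would look something like the following (an untested sketch; the in-place update of m.h and the exact calls may need adjusting):

using Reactant, Enzyme

# Untested sketch: move the model and a batch onto Reactant arrays, then
# compile the gradient computation once and reuse it.
model_ra = Reactant.to_rarray(activemodel)
xs_ra    = Reactant.to_rarray(dataX[:, :, 1:2])
ys_ra    = Reactant.to_rarray(dataY[:, :, 1:2])

lossgrad(m, xs, ys) = Enzyme.gradient(Reverse, trainingloss, m, Const(xs), Const(ys), Const(0))
compiled = @compile lossgrad(model_ra, xs_ra, ys_ra)
grads = compiled(model_ra, xs_ra, ys_ra)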
