I am running into a Segmentation fault (core dumped) crash when trying to compute gradients with Enzyme for a custom RNN model in Flux. The forward pass runs without issues; the crash seems to occur while compiling a specific gradient computation.
Sorry the MWE is a bit long -- the RNN is stateful, and the segfault only occurs when I add the option of letting the model run for several steps before producing the answer (timemovement! in the MWE).
I can rewrite the code differently to avoid the segfault, so this is solvable on my end (it may well be caused by a mistaken approach in my code, although Mooncake still produced gradients for this version). I am posting it here in case the segfault itself needs fixing; a rough sketch of the kind of rewrite I mean is included after the MWE.
I was able to reproduce the crash on my Apple M1 machine and on a Linux HPC cluster; the output below is from Linux. Tested with Flux 0.16.3 and Enzyme 0.13.30.
Version info
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 112 × Intel(R) Xeon(R) Platinum 8480CL
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, sapphirerapids)
Threads: 1 default, 0 interactive, 1 GC (on 112 virtual cores)
Environment:
JULIA_DEBUG = all
Crash output:
[2327665] signal (11.1): Segmentation fault
in expression starting at none:0
jl_object_id__cold at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/builtins.c:455
type_hash at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/jltypes.c:1584
typekey_hash at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/jltypes.c:1614
jl_precompute_memoized_dt at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/jltypes.c:1694
inst_datatype_inner at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/jltypes.c:2141
jl_inst_arg_tuple_type at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/jltypes.c:2243
arg_type_tuple at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:2232 [inlined]
jl_lookup_generic_ at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:3021 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:3073
print_response at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:281
do_respond at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:911
jfptr_do_respond_91941.1 at /usr/licensed/julia/1.10.5/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_apply at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
jl_f__call_latest at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/builtins.c:812
#invokelatest#2 at ./essentials.jl:892 [inlined]
invokelatest at ./essentials.jl:889 [inlined]
run_interface at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/LineEdit.jl:2656
jfptr_run_interface_90695.1 at /usr/licensed/julia/1.10.5/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:3077
run_frontend at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:1312
#62 at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:386
jfptr_YY.62_91844.1 at /usr/licensed/julia/1.10.5/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:2895 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/gf.c:3077
jl_apply at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
start_task at /cache/build/builder-amdci4-4/julialang/julia-release-1-dot-10/src/task.c:1238
Allocations: 136137670 (Pool: 135967501; Big: 170169); GC: 253
Segmentation fault (core dumped)
MWE:
using Flux
using Enzyme
import NNlib
using LinearAlgebra
using Statistics  # for `mean` in the loss below
#====CREATE CUSTOM RNN====#
# CREATE STATEFUL, MATRIX-VALUED RNN LAYER
mutable struct MatrixRnnCell{A<:Matrix{Float32}}
    Whh::A        # Recurrent weights
    bh::A         # Recurrent bias
    h::A          # Current state
    const init::A # Store initial state h for reset
end
function matrnn_constructor(n::Int, k::Int)::MatrixRnnCell
    Whh = randn(Float32, k, k) / sqrt(Float32(k))
    bh = randn(Float32, k, n) / sqrt(Float32(k))
    h = randn(Float32, k, n) * 0.01f0
    return MatrixRnnCell(Whh, bh, h, h)
end
function (m::MatrixRnnCell)(state::Matrix{Float32},
                            I::AbstractArray{Float32, 2};
                            selfreset::Bool=false)::Matrix{Float32}
    if selfreset
        h = m.init
    else
        h = state
    end
    h_new = NNlib.tanh_fast.(m.Whh * h .+ m.bh .+ I)
    m.h = h_new
    return h_new
end
state(m::MatrixRnnCell) = m.h
reset!(m::MatrixRnnCell) = (m.h = m.init)
# CREATE SECOND LAYER
struct WeightedMeanLayer{V<:Vector{Float32}}
    weight::V
end
function WeightedMeanLayer(num::Int;
                           init = ones)
    weight_init = init(Float32, num)
    WeightedMeanLayer(weight_init)
end
(a::WeightedMeanLayer)(X::Array{Float32}) = X' * a.weight
# CHAIN RNN AND SECOND LAYER
struct MatrixRNN{M<:MatrixRnnCell, D<:WeightedMeanLayer}
    rnn::M
    dec::D
end
state(m::MatrixRNN) = state(m.rnn)
reset!(m::MatrixRNN) = reset!(m.rnn)
(m::MatrixRNN)(x::AbstractArray{Float32,2};
               selfreset::Bool = false)::Vector{Float32} =
    (m.dec ∘ m.rnn)(state(m), x; selfreset = selfreset)
Flux.@layer MatrixRnnCell trainable=(Whh, bh)
Flux.@layer WeightedMeanLayer
Flux.@layer :expand MatrixRNN trainable=(rnn, dec)
Flux.@non_differentiable reset!(m::MatrixRnnCell)
Flux.@non_differentiable reset!(m::MatrixRNN)
EnzymeRules.inactive(::typeof(reset!), args...) = nothing
#====DEFINE LOSS FUNCTIONS====#
# Helper function for prediction loops
# To avoid the segfault, comment out the timemovement! function here, and the two calls to it below.
function timemovement!(m, x, turns)::Nothing
    for _ in 1:turns
        m(x; selfreset = false)
    end
end
function populatepreds!(preds, m, xs::Array{Float32, 3}, turns)::Nothing
    for i in axes(xs, 3)
        reset!(m)
        example = @view xs[:,:,i]
        timemovement!(m, example, turns)
        preds[:,i] .= m(example; selfreset = false)
    end
end
# Predict - Single datum
function modelpredict(m,
                      x::Matrix{Float32},
                      turns::Int)::Matrix{Float32}
    if m.rnn.h != m.rnn.init
        reset!(m)
    end
    timemovement!(m, x, turns)
    preds = m(x; selfreset = false)
    return reshape(preds, :, 1)
end
# Predict - Many data points in an array
function modelpredict(m,
                      xs::Array{Float32, 3},
                      turns::Int)::Matrix{Float32}
    output_size = size(xs, 2)
    nb_examples = size(xs, 3)
    preds = Array{Float32}(undef, output_size, nb_examples)
    populatepreds!(preds, m, xs, turns)
    return preds
end
# Wrapper for prediction and a trivial loss for MWE
function trainingloss(m, xs, ys, turns)
    ys_hat = modelpredict(m, xs, turns)
    return mean(abs2, vec(ys) .- vec(ys_hat))
end
#====CREATE TRIVIAL DATA====#
const K::Int = 400
const N::Int = 64
const DATASETSIZE::Int = 64
dataY = randn(Float32, N, N, DATASETSIZE)
dataX = randn(Float32, K, N*N, DATASETSIZE)
#====INITIALIZE MODEL====#
activemodel = MatrixRNN(
    matrnn_constructor(N^2, K),
    WeightedMeanLayer(K))
# Test loss function, entire dataset
using BenchmarkTools
@btime trainingloss($activemodel, $dataX, $dataY, 0)
#====GRADIENT COMPUTATION====#
# Training loss, one data point, no time step -- this works.
loss, grads = Flux.withgradient(trainingloss,
                                Duplicated(activemodel),
                                dataX[:,:,1],
                                dataY[:,:,1], 0)
# Training loss, 2 (or more) data points, no time step -- segfault.
loss, grads = Flux.withgradient(trainingloss,
                                Duplicated(activemodel),
                                dataX[:,:,1:2],
                                dataY[:,:,1:2], 0)
loss, grads = Flux.withgradient(trainingloss,
                                Duplicated(activemodel),
                                dataX[:,:,1:32],
                                dataY[:,:,1:32], 0)
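For reference, here is the rough sketch of the kind of rewrite mentioned at the top, assuming the problem is related to mutating m.h inside the differentiated loop: the hidden state is threaded through explicitly instead. The names applycell and modelpredict_pure are hypothetical (not part of the MWE above), and I have not verified that this exact version avoids the segfault.
# Hypothetical, mutation-free variant: the hidden state is an explicit local
# variable rather than a field that gets overwritten on every call.
function applycell(cell::MatrixRnnCell, h::Matrix{Float32}, x::AbstractMatrix{Float32})
    return NNlib.tanh_fast.(cell.Whh * h .+ cell.bh .+ x)
end

function modelpredict_pure(m::MatrixRNN, x::Matrix{Float32}, turns::Int)
    h = m.rnn.init              # always start from the stored initial state
    for _ in 1:turns            # "time movement" without touching m.rnn.h
        h = applycell(m.rnn, h, x)
    end
    h = applycell(m.rnn, h, x)  # final step that produces the prediction
    return reshape(m.dec(h), :, 1)
end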
Will take a look shortly, though if you're doing a full neural network, I'd recommend also using Reactant.jl on the outside (which will optimize it and automatically rewrite it to run on CPU/GPU/TPU/etc.).
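For reference, a hedged sketch of what "using Reactant.jl on the outside" might look like here, following the usual Reactant workflow of moving arrays with Reactant.to_rarray and compiling the call with @compile. This is an assumption about the intended usage, not a tested snippet, and whether this particular stateful model compiles under Reactant has not been checked.
using Reactant, Flux, Enzyme

model_ra = Reactant.to_rarray(activemodel)   # model parameters as Reactant arrays
x_ra = Reactant.to_rarray(dataX[:, :, 1:2])
y_ra = Reactant.to_rarray(dataY[:, :, 1:2])

# Compile the gradient call once, then invoke the compiled function
# with the same arguments.
grad_compiled = @compile Flux.gradient(trainingloss, Duplicated(model_ra), x_ra, y_ra, 0)
grads = grad_compiled(trainingloss, Duplicated(model_ra), x_ra, y_ra, 0)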