Lux + Enzyme and Zygote + NeuralODE, segmentation fault

Hi all,

I’m using Julia v1.11.2 and installed the packages on 12/6 (last Friday), so I assume I have the most recent package versions.

I’ve been following along with the SciML tutorials on neural ordinary differential equations (NODEs), which run fine for me.

For my use case, u[1] and u[2] are the quantities of the ODE that I want to solve for; they change with time. I also want them to be learnable.

There are two other parameters that I want to pass to the NN. In my MWE (see below) I keep them in vs, since they affect the dynamics of the original ODE I am trying to learn. They do not change with time, but I do want to train the same NODE across different values of vs. Ultimately, I plan to add an outer training loop that iterates over various values of vs. I include the two parameters as u[3] and u[4] so that I can pass them to the NODE! function.

The following code segfaults on the Zygote.gradient line and also on the Enzyme autodiff check for isolating potential gradient issues (taken from the SciMLSensitivity.jl FAQ). Note that the Enzyme test is commented out, but it also segfaults if I uncomment it and comment out the Zygote test.

println("Load libraries")
using OrdinaryDiffEq, DifferentialEquations, ComponentArrays
using Enzyme, Lux
using Zygote, SciMLSensitivity
using StableRNGs

rng = StableRNG(1111)

println("Setup NODE")

const sc = Lux.Chain(Lux.Dense(4,5,tanh), Lux.Dense(5,5,tanh), Lux.Dense(5,2))
# Get the initial parameters and state variables of the model
p_nn, st = Lux.setup(rng, sc)
const _st = st

function NODE!(du,u,p,t)
    NN = sc(u,p,_st)[1]
    du[1] = u[2] + NN[1]
    du[2] = NN[2]
end

println("Test NODE")
u0_test = [1.0,2.0]
vs = [3.,4.]

theta = ComponentArray(p=p_nn, u0=u0_test)
sc([theta.u0;vs], theta.p, _st)

println("Check if ODE solves")

ts = [0,.1,.2,.3,.4,.5]

prob = ODEProblem(NODE!, [theta.u0; vs], (0.,1.))
pred = solve(prob, Tsit5(), u0=[theta.u0; vs], p=theta.p, saveat=ts)[1,:]

R = rand(length(ts))

function predict_neuralode(theta)
    Array(solve(prob, Tsit5(), u0 = [theta.u0; vs], p = theta.p, saveat = ts))
end

function loss_neuralode(theta)
    pred = predict_neuralode(theta)[1,:]
    loss = sum(abs2, R .- pred)
    return loss
end

loss_neuralode(theta)


println("Test zygote")

Zygote.gradient(loss_neuralode, theta)


# println("Check if works with Enzyme autodiff")

# Testing Enzyme autodiff based on: https://docs.sciml.ai/SciMLSensitivity/dev/faq/

# prob = ODEProblem(NODE!, u_test, (0.,1.), theta)
# u0 = prob.u0
# p = prob.p
# tmp2 = Enzyme.make_zero(p)
# t = prob.tspan[1]
# du = zero(u0)

# if DiffEqBase.isinplace(prob)
#     _f = prob.f
# else
#     _f = (du, u, p, t) -> (du .= prob.f(u, p, t); nothing)
# end

# _tmp6 = Enzyme.make_zero(_f)
# tmp3 = zero(u0)
# tmp4 = zero(u0)
# ytmp = zero(u0)
# tmp1 = zero(u0)

# Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(_f, _tmp6),
#     Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
#     Enzyme.Duplicated(ytmp, tmp1),
#     Enzyme.Duplicated(p, tmp2),
#     Enzyme.Const(t))

# Test gradient
# gradient(Reverse, loss_neuralode, theta)

I appreciate any help!
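
The augmented-state layout used in the MWE above, where vs rides along as u[3] and u[4], can be illustrated in plain Julia. This is only a sketch under stated assumptions: toy_nn is a hypothetical stand-in for the Lux network, and the fixed-step Euler loop stands in for the ODE solver. Unlike the MWE, which leaves du[3] and du[4] unassigned, the sketch sets them to zero explicitly so that the constant entries stay fixed during integration. Whether that difference is related to the crash is not established here.

```julia
# Hypothetical stand-in for the network: maps 4 inputs to 2 outputs.
toy_nn(x, W) = W * x

# In-place right-hand side on the augmented state u = [u1, u2, v1, v2].
function augmented_rhs!(du, u, W, t)
    nn = toy_nn(u, W)
    du[1] = u[2] + nn[1]
    du[2] = nn[2]
    du[3] = 0.0   # v1 does not change with time
    du[4] = 0.0   # v2 does not change with time
    return nothing
end

# Fixed-step explicit Euler, used only to keep the sketch dependency-free.
function euler(rhs!, u0, p, tspan, nsteps)
    u = copy(u0)
    du = similar(u)
    dt = (tspan[2] - tspan[1]) / nsteps
    t = tspan[1]
    for _ in 1:nsteps
        rhs!(du, u, p, t)
        u .+= dt .* du
        t += dt
    end
    return u
end

W = [0.1 0.0 0.2 0.0; 0.0 0.1 0.0 0.3]   # arbitrary 2x4 weights
u_end = euler(augmented_rhs!, [1.0, 2.0, 3.0, 4.0], W, (0.0, 1.0), 100)
```

After integration, u_end[3] and u_end[4] still hold the original vs values, because their derivatives are zero at every step.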

I got past the Zygote segmentation fault by no longer passing vs to the NODE via u[3] and u[4] and instead using vs directly inside NODE! (i.e. sc([u;vs],p,_st) instead of sc(u,p,_st) as before); see the code below for reference. So now I can get a Zygote gradient. Enzyme still fails.

Now I see the odd behavior that Zygote.gradient works if I set layer_size=5, but Julia crashes with "Internal error: during type inference" if layer_size is larger (for example, 10 or 50).

println("Load libraries")
using OrdinaryDiffEq, DifferentialEquations, ComponentArrays
using Enzyme, Lux
using Zygote, SciMLSensitivity
using StableRNGs

using Plots

rng = StableRNG(1111)

println("Setup NODE")

in_size = 4
# This script works when layer_size=5, but julia crashes on type inference when layer_size=10 when calling Zygote.gradient()
layer_size = 10

const sc = Lux.Chain(Lux.Dense(in_size,layer_size,tanh), 
                     Lux.Dense(layer_size,layer_size,tanh), 
                     Lux.Dense(layer_size,2))
# Get the initial parameters and state variables of the model
p_nn, st = Lux.setup(rng, sc)
const _st = st

function NODE!(du,u,p,t)
    NN = sc([u;vs],p,_st)[1]
    du[1] = u[2] + NN[1]
    du[2] = NN[2]
end

println("Test NODE")
u0_test = [1.0,2.0]
vs = [3.,4.]

theta = ComponentArray(p=p_nn, u0=u0_test)
sc([theta.u0;vs], theta.p, _st)

println("Check if ODE solves")

ts = [0,.1,.2,.3,.4,.5]

prob = ODEProblem(NODE!, theta.u0, (0.,1.))
pred = solve(prob, Tsit5(), u0=theta.u0, p=theta.p, saveat=ts)[1,:]

R = rand(length(ts))

function predict_neuralode(theta)
    Array(solve(prob, Tsit5(), u0 = theta.u0, p = theta.p, saveat = ts))
end

function loss_neuralode(theta)
    pred = predict_neuralode(theta)[1,:]
    loss = sum(abs2, R .- pred)
    return loss
end

loss_neuralode(theta)

println("Test zygote")

Zygote.gradient(loss_neuralode, theta)

Further questions:

  1. Does anyone know why this fixes the segmentation fault error?
  2. What is the recommended way to handle vs for what I am trying to do? Is there a good way to pass vs to NODE! as an input rather than as a global variable? This would be helpful, as I plan to iterate over various values of vs, each with its corresponding R, during training.

Can you share the Enzyme error message? If you run it in a terminal, it should print it.

Chris,

Here’s the error message for Enzyme:

ERROR: LoadError: "Error cannot store inactive but differentiable variable [0.0, 0.1, 0.2, 0.3, 0.4, 0.5] into active tuple"
Stacktrace:
  [1] runtime_newstruct_augfwd(activity::Type{Val{(true, true, false)}}, width::Val{1}, ModifiedBetween::Val{(true, true, true, true)}, ::Type{@NamedTuple{u0::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, p::ComponentVector{Float64, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{Axis{(layer_1 = ViewAxis(1:50, Axis(weight = ViewAxis(1:40, ShapedAxis((10, 4))), bias = 41:50)), layer_2 = ViewAxis(51:160, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = 101:110)), layer_3 = ViewAxis(161:182, Axis(weight = ViewAxis(1:20, ShapedAxis((2, 10))), bias = 21:22)))}}}, saveat::Vector{Float64}}}, RT::Val{Any}, primal_1::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, shadow_1_1::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, primal_2::ComponentVector{Float64, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{Axis{(layer_1 = ViewAxis(1:50, Axis(weight = ViewAxis(1:40, ShapedAxis((10, 4))), bias = 41:50)), layer_2 = ViewAxis(51:160, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = 101:110)), layer_3 = ViewAxis(161:182, Axis(weight = ViewAxis(1:20, ShapedAxis((2, 10))), bias = 21:22)))}}}, shadow_2_1::ComponentVector{Float64, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{Axis{(layer_1 = ViewAxis(1:50, Axis(weight = ViewAxis(1:40, ShapedAxis((10, 4))), bias = 41:50)), layer_2 = ViewAxis(51:160, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = 101:110)), layer_3 = ViewAxis(161:182, Axis(weight = ViewAxis(1:20, ShapedAxis((2, 10))), bias = 21:22)))}}}, primal_3::Vector{Float64}, shadow_3_1::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/aViNX/src/rules/typeunstablerules.jl:33
  [2] NamedTuple
    @ ./boot.jl:727 [inlined]
  [3] predict_neuralode
    @ ~/DropBounce/src/MWE.jl:52
  [4] loss_neuralode
    @ ~/DropBounce/src/MWE.jl:56 [inlined]
  [5] loss_neuralode
    @ ~/DropBounce/src/MWE.jl:0 [inlined]
  [6] augmented_julia_loss_neuralode_17527_inner_1wrap
    @ ~/DropBounce/src/MWE.jl:0
  [7] macro expansion
    @ ~/.julia/packages/Enzyme/aViNX/src/compiler.jl:5201 [inlined]
  [8] enzyme_call
    @ ~/.julia/packages/Enzyme/aViNX/src/compiler.jl:4747 [inlined]
  [9] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/aViNX/src/compiler.jl:4683 [inlined]
 [10] autodiff
    @ ~/.julia/packages/Enzyme/aViNX/src/Enzyme.jl:396 [inlined]
 [11] autodiff
    @ ~/.julia/packages/Enzyme/aViNX/src/Enzyme.jl:524 [inlined]
 [12] macro expansion
    @ ~/.julia/packages/Enzyme/aViNX/src/sugar.jl:319 [inlined]
 [13] gradient(::ReverseMode{false, false, FFIABI, false, false}, ::typeof(loss_neuralode), ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(p = ViewAxis(1:182, Axis(layer_1 = ViewAxis(1:50, Axis(weight = ViewAxis(1:40, ShapedAxis((10, 4))), bias = 41:50)), layer_2 = ViewAxis(51:160, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = 101:110)), layer_3 = ViewAxis(161:182, Axis(weight = ViewAxis(1:20, ShapedAxis((2, 10))), bias = 21:22)))), u0 = 183:184)}}})
    @ Enzyme ~/.julia/packages/Enzyme/aViNX/src/sugar.jl:258
 [14] top-level scope
    @ ~/DropBounce/src/MWE.jl:115
in expression starting at /home/vleon/DropBounce/src/MWE.jl:115

And here is the current version of the code I’m running:

println("Load libraries")
using OrdinaryDiffEq, DifferentialEquations, ComponentArrays
using Enzyme, Lux
using Zygote, SciMLSensitivity
using StableRNGs

using Optimization, OptimizationOptimisers

using Plots

rng = StableRNG(1111)

println("Setup NODE")

in_size = 4
# This script works with Zygote.gradient when layer_size=5, but julia crashes on type inference when layer_size=10 when calling Zygote.gradient()
layer_size = 10

const sc = Lux.Chain(Lux.Dense(in_size,layer_size,tanh), 
                     Lux.Dense(layer_size,layer_size,tanh), 
                     Lux.Dense(layer_size,2))
# Get the initial parameters and state variables of the model
p_nn, st = Lux.setup(rng, sc)
const _st = st

function NODE!(du,u,p,t)
    NN = sc([u;vs],p,_st)[1]
    du[1] = u[2] + NN[1]
    du[2] = NN[2]
end

println("Test NODE")
u0_test = [1.0,2.0]
vs = [3.,4.]

theta = ComponentArray(p=p_nn, u0=u0_test)
sc([theta.u0;vs], theta.p, _st)

println("Check if ODE solves")

ts = [0,.1,.2,.3,.4,.5]

prob = ODEProblem(NODE!, theta.u0, (0.,1.))
pred = solve(prob, Tsit5(), u0=theta.u0, p=theta.p, saveat=ts)[1,:]

R = rand(length(ts))

function predict_neuralode(theta)
    Array(solve(prob, Tsit5(), u0 = theta.u0, p = theta.p, saveat = ts))
end

function loss_neuralode(theta)
    pred = predict_neuralode(theta)[1,:]
    loss = sum(abs2, R .- pred)
    return loss
end

loss_neuralode(theta)

println("Check if works with Enzyme autodiff")

Enzyme.gradient(Reverse, loss_neuralode, theta)

# Testing Enzyme autodiff based on: https://docs.sciml.ai/SciMLSensitivity/dev/faq/

prob = ODEProblem(NODE!, theta.u0, (0.,1.), theta.p)
u0 = prob.u0
p = prob.p
tmp2 = Enzyme.make_zero(p)
t = prob.tspan[1]
du = zero(u0)

if DiffEqBase.isinplace(prob)
    _f = prob.f
else
    _f = (du, u, p, t) -> (du .= prob.f(u, p, t); nothing)
end

_tmp6 = Enzyme.make_zero(_f)
tmp3 = zero(u0)
tmp4 = zero(u0)
ytmp = zero(u0)
tmp1 = zero(u0)

# Error here
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(_f, _tmp6),
    Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
    Enzyme.Duplicated(ytmp, tmp1),
    Enzyme.Duplicated(p, tmp2),
    Enzyme.Const(t))

It would be nice if Zygote.gradient also worked with a larger theta, since I would like to use Optimization.jl as follows. Instead I get the "Internal error: during type inference" error, which seems to lead to a segmentation fault (see below).

println("Test zygote")

Zygote.gradient(loss_neuralode, theta)

println("Test optimization")

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, theta) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, theta)

losses = Float64[]

callback = function (state, l)
    push!(losses, l)
    if length(losses) % 50 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end

iter1 = 1000

res1 = Optimization.solve(
    optprob, OptimizationOptimisers.Adam(), callback = callback, maxiters = iter1)
println("Training loss after $(length(losses)) iterations: $(losses[end])")

With Zygote this produces the error below (I cut it off because the message is so long). At the bottom it says segmentation fault.

┌ Warning: Potential performance improvement omitted. EnzymeVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/l0eoK/src/concrete_solve.jl:24

Internal error: during type inference of
overdub(Cassette.Context{FunctionProperties.var"##HasBranchingCtx#Name", Base.Dict{Symbol, Bool}, Nothing, FunctionProperties.var"##PassType#230", Nothing, Cassette.DisableHooks}, Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{Lux.Dense{typeof(Base.tanh), Int64, Int64, Nothing, Nothing, Static.True}, Lux.Dense{typeof(Base.tanh), Int64, Int64, Nothing, Nothing, Static.True}, Lux.Dense{typeof(Base.identity), Int64, Int64, Nothing, Nothing, Static.True}}}, Nothing}, Array{Float64, 1}, ComponentArrays.ComponentArray{Float64, 1, Array{Float64, 1}, Tuple{ComponentArrays.Axis{(layer_1=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=1, stop=50), (weight=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=1, stop=40), nothing, ComponentArrays.ShapedAxis{(10, 4)}}(ax=ComponentArrays.ShapedAxis{(10, 4)}()), bias=Base.UnitRange{Int64}(start=41, stop=50)), ComponentArrays.Axis{(weight=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=1, stop=40), nothing, ComponentArrays.ShapedAxis{(10, 4)}}(ax=ComponentArrays.ShapedAxis{(10, 4)}()), bias=Base.UnitRange{Int64}(start=41, stop=50))}}(ax=ComponentArrays.Axis{(weight=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=1, stop=40), nothing, ComponentArrays.ShapedAxis{(10, 4)}}(ax=ComponentArrays.ShapedAxis{(10, 4)}()), bias=Base.UnitRange{Int64}(start=41, stop=50))}()), layer_2=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=51, stop=160), (weight=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=1, stop=100), nothing, ComponentArrays.ShapedAxis{(10, 10)}}(ax=ComponentArrays.ShapedAxis{(10, 10)}()), bias=Base.UnitRange{Int64}(start=101, stop=110)), ComponentArrays.Axis{(weight=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=1, stop=100), nothing, ComponentArrays.ShapedAxis{(10, 10)}}(ax=ComponentArrays.ShapedAxis{(10, 10)}()), bias=Base.UnitRange{Int64}(start=101, stop=110))}}(ax=ComponentArrays.Axis{(weight=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=1, stop=100), nothing, 
ComponentArrays.ShapedAxis{(10, 10)}}(ax=ComponentArrays.ShapedAxis{(10, 10)}()), bias=Base.UnitRange{Int64}(start=101, stop=110))}()), layer_3=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=161, stop=182), (weight=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=1, stop=20), nothing, ComponentArrays.ShapedAxis{(2, 10)}}(ax=ComponentArrays.ShapedAxis{(2, 10)}()), bias=Base.UnitRange{Int64}(start=21, stop=22)), ComponentArrays.Axis{(weight=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=1, stop=20), nothing, ComponentArrays.ShapedAxis{(2, 10)}}(ax=ComponentArrays.ShapedAxis{(2, 10)}()), bias=Base.UnitRange{Int64}(start=21, stop=22))}}(ax=ComponentArrays.Axis{(weight=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=1, stop=20), nothing, ComponentArrays.ShapedAxis{(2, 10)}}(ax=ComponentArrays.ShapedAxis{(2, 10)}()), bias=Base.UnitRange{Int64}(start=21, stop=22))}()))}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}})
Encountered unexpected error in runtime:
MethodError(f=Base.inferencebarrier, args=(typeof(Base.string)(),), world=0x0000000000001735)
jl_method_error_bare at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/gf.c:2254
jl_method_error at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/gf.c:2272
jl_lookup_generic_ at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/gf.c:3106 [inlined]
ijl_apply_generic at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/gf.c:3121
macro expansion at ./error.jl:235 [inlined]
renumber_ssa at ./compiler/ssair/slot2ssa.jl:56
#472 at ./compiler/ssair/slot2ssa.jl:62 [inlined]
ssamap at ./compiler/utilities.jl:360
renumber_ssa! at ./compiler/ssair/slot2ssa.jl:62 [inlined]
renumber_ssa! at ./compiler/ssair/slot2ssa.jl:61 [inlined]
construct_ssa! at ./compiler/ssair/slot2ssa.jl:902
slot2reg at ./compiler/optimize.jl:1219 [inlined]
run_passes_ipo_safe at ./compiler/optimize.jl:994
run_passes_ipo_safe at ./compiler/optimize.jl:1009 [inlined]
optimize at ./compiler/optimize.jl:983
jfptr_optimize_42650.1 at /home/vleon/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
finish_nocycle at ./compiler/typeinfer.jl:265
_typeinf at ./compiler/typeinfer.jl:249
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:923
abstract_call_method at ./compiler/abstractinterpretation.jl:660
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:102
abstract_call_known at ./compiler/abstractinterpretation.jl:2200
abstract_call at ./compiler/abstractinterpretation.jl:2282
abstract_apply at ./compiler/abstractinterpretation.jl:1690
abstract_call_known at ./compiler/abstractinterpretation.jl:2102
abstract_call at ./compiler/abstractinterpretation.jl:2282
abstract_call at ./compiler/abstractinterpretation.jl:2275
abstract_call at ./compiler/abstractinterpretation.jl:2420
abstract_eval_call at ./compiler/abstractinterpretation.jl:2435
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2451
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2749
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:3065
typeinf_local at ./compiler/abstractinterpretation.jl:3319
typeinf_nocycle at ./compiler/abstractinterpretation.jl:3401
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:923
abstract_call_method at ./compiler/abstractinterpretation.jl:660
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:102
abstract_call_known at ./compiler/abstractinterpretation.jl:2200
abstract_call at ./compiler/abstractinterpretation.jl:2282
abstract_call at ./compiler/abstractinterpretation.jl:2275
abstract_call at ./compiler/abstractinterpretation.jl:2420
abstract_eval_call at ./compiler/abstractinterpretation.jl:2435
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2451
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2749
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:3065
typeinf_local at ./compiler/abstractinterpretation.jl:3319
typeinf_nocycle at ./compiler/abstractinterpretation.jl:3401
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:923
abstract_call_method at ./compiler/abstractinterpretation.jl:660
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:102
abstract_call_known at ./compiler/abstractinterpretation.jl:2200
abstract_call at ./compiler/abstractinterpretation.jl:2282
abstract_apply at ./compiler/abstractinterpretation.jl:1690
abstract_call_known at ./compiler/abstractinterpretation.jl:2102
abstract_call at ./compiler/abstractinterpretation.jl:2282
abstract_call at ./compiler/abstractinterpretation.jl:2275
abstract_call at ./compiler/abstractinterpretation.jl:2420
abstract_eval_call at ./compiler/abstractinterpretation.jl:2435
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2451
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2749
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:3065
typeinf_local at ./compiler/abstractinterpretation.jl:3319
typeinf_nocycle at ./compiler/abstractinterpretation.jl:3401
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:923
abstract_call_method at ./compiler/abstractinterpretation.jl:660
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:102
abstract_call_known at ./compiler/abstractinterpretation.jl:2200
abstract_call at ./compiler/abstractinterpretation.jl:2282
abstract_call at ./compiler/abstractinterpretation.jl:2275
abstract_call at ./compiler/abstractinterpretation.jl:2420
abstract_eval_call at ./compiler/abstractinterpretation.jl:2435
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2451
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2749
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:3065
typeinf_local at ./compiler/abstractinterpretation.jl:3319
typeinf_nocycle at ./compiler/abstractinterpretation.jl:3401
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:923
abstract_call_method at ./compiler/abstractinterpretation.jl:660
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:102
abstract_call_known at ./compiler/abstractinterpretation.jl:2200
abstract_call at ./compiler/abstractinterpretation.jl:2282
abstract_apply at ./compiler/abstractinterpretation.jl:1690
abstract_call_known at ./compiler/abstractinterpretation.jl:2102
abstract_call at ./compiler/abstractinterpretation.jl:2282
abstract_call at ./compiler/abstractinterpretation.jl:2275
abstract_call at ./compiler/abstractinterpretation.jl:2420
abstract_eval_call at ./compiler/abstractinterpretation.jl:2435
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2451
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2749
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:3065
typeinf_local at ./compiler/abstractinterpretation.jl:3319
typeinf_nocycle at ./compiler/abstractinterpretation.jl:3401
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:923
abstract_call_method at ./compiler/abstractinterpretation.jl:660
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:102
abstract_call_known at ./compiler/abstractinterpretation.jl:2200
abstract_call at ./compiler/abstractinterpretation.jl:2282
abstract_call at ./compiler/abstractinterpretation.jl:2275
abstract_call at ./compiler/abstractinterpretation.jl:2420
abstract_eval_call at ./compiler/abstractinterpretation.jl:2435
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2451
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2749
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:3065
typeinf_local at ./compiler/abstractinterpretation.jl:3319
typeinf_nocycle at ./compiler/abstractinterpretation.jl:3401

Could you make an issue on Enzyme for that MWE?

For the Zygote one, I think that’s the Cassette v1.11 update issue. @aviatesk or @Oscar_Smith were you looking into that?

Chris, sounds good. I opened an issue on Enzyme for the MWE: "Cannot store inactive but differentiable variable into active tuple", Issue #2212 in EnzymeAD/Enzyme.jl on GitHub.

I’m interested in a fix for the Zygote issue for larger layer sizes.

Also, I’m curious whether there is a better way to pass vs into my NODE! function without it being differentiated. In the full training loop, I will be training the sc NN with varying vs. Right now I am using a global variable vs, which I feel may not be the best approach.
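
On the question of avoiding the global: one common Julia pattern is to build the right-hand side as a closure that captures vs. The sketch below only illustrates that pattern in plain Julia. The names toy_nn, make_rhs, and eval_rhs are hypothetical, toy_nn stands in for the Lux chain call, and none of this has been tested against the Zygote or Enzyme failures reported above.

```julia
# Hypothetical stand-in for the network; in the thread this would be the Lux chain call.
toy_nn(x, W) = W * x

# Build a right-hand side that captures vs, so no global is read inside it.
function make_rhs(vs)
    function rhs!(du, u, W, t)
        nn = toy_nn(vcat(u, vs), W)   # vs comes from the enclosing scope
        du[1] = u[2] + nn[1]
        du[2] = nn[2]
        return nothing
    end
    return rhs!
end

# Evaluate the right-hand side once for a given vs (illustration only).
function eval_rhs(vs, W)
    rhs! = make_rhs(vs)
    du = zeros(2)
    rhs!(du, [1.0, 2.0], W, 0.0)
    return du
end

W = [0.1 0.0 0.2 0.0; 0.0 0.1 0.0 0.3]   # arbitrary 2x4 weights

# One closure per vs value, as an outer training loop might create them.
du_a = eval_rhs([3.0, 4.0], W)
du_b = eval_rhs([5.0, 6.0], W)
```

Since the closure has the same (du, u, p, t) signature as NODE!, the result of make_rhs(vs) could in principle be passed to ODEProblem in its place, with a new problem built for each vs in the outer loop. A callable struct holding vs as a field would be an alternative with the same effect.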