Chris,
Here’s the error message for Enzyme:
ERROR: LoadError: "Error cannot store inactive but differentiable variable [0.0, 0.1, 0.2, 0.3, 0.4, 0.5] into active tuple"
Stacktrace:
[1] runtime_newstruct_augfwd(activity::Type{Val{(true, true, false)}}, width::Val{1}, ModifiedBetween::Val{(true, true, true, true)}, ::Type{@NamedTuple{u0::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, p::ComponentVector{Float64, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{Axis{(layer_1 = ViewAxis(1:50, Axis(weight = ViewAxis(1:40, ShapedAxis((10, 4))), bias = 41:50)), layer_2 = ViewAxis(51:160, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = 101:110)), layer_3 = ViewAxis(161:182, Axis(weight = ViewAxis(1:20, ShapedAxis((2, 10))), bias = 21:22)))}}}, saveat::Vector{Float64}}}, RT::Val{Any}, primal_1::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, shadow_1_1::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, primal_2::ComponentVector{Float64, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{Axis{(layer_1 = ViewAxis(1:50, Axis(weight = ViewAxis(1:40, ShapedAxis((10, 4))), bias = 41:50)), layer_2 = ViewAxis(51:160, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = 101:110)), layer_3 = ViewAxis(161:182, Axis(weight = ViewAxis(1:20, ShapedAxis((2, 10))), bias = 21:22)))}}}, shadow_2_1::ComponentVector{Float64, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{Axis{(layer_1 = ViewAxis(1:50, Axis(weight = ViewAxis(1:40, ShapedAxis((10, 4))), bias = 41:50)), layer_2 = ViewAxis(51:160, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = 101:110)), layer_3 = ViewAxis(161:182, Axis(weight = ViewAxis(1:20, ShapedAxis((2, 10))), bias = 21:22)))}}}, primal_3::Vector{Float64}, shadow_3_1::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/aViNX/src/rules/typeunstablerules.jl:33
[2] NamedTuple
@ ./boot.jl:727 [inlined]
[3] predict_neuralode
@ ~/DropBounce/src/MWE.jl:52
[4] loss_neuralode
@ ~/DropBounce/src/MWE.jl:56 [inlined]
[5] loss_neuralode
@ ~/DropBounce/src/MWE.jl:0 [inlined]
[6] augmented_julia_loss_neuralode_17527_inner_1wrap
@ ~/DropBounce/src/MWE.jl:0
[7] macro expansion
@ ~/.julia/packages/Enzyme/aViNX/src/compiler.jl:5201 [inlined]
[8] enzyme_call
@ ~/.julia/packages/Enzyme/aViNX/src/compiler.jl:4747 [inlined]
[9] AugmentedForwardThunk
@ ~/.julia/packages/Enzyme/aViNX/src/compiler.jl:4683 [inlined]
[10] autodiff
@ ~/.julia/packages/Enzyme/aViNX/src/Enzyme.jl:396 [inlined]
[11] autodiff
@ ~/.julia/packages/Enzyme/aViNX/src/Enzyme.jl:524 [inlined]
[12] macro expansion
@ ~/.julia/packages/Enzyme/aViNX/src/sugar.jl:319 [inlined]
[13] gradient(::ReverseMode{false, false, FFIABI, false, false}, ::typeof(loss_neuralode), ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(p = ViewAxis(1:182, Axis(layer_1 = ViewAxis(1:50, Axis(weight = ViewAxis(1:40, ShapedAxis((10, 4))), bias = 41:50)), layer_2 = ViewAxis(51:160, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10))), bias = 101:110)), layer_3 = ViewAxis(161:182, Axis(weight = ViewAxis(1:20, ShapedAxis((2, 10))), bias = 21:22)))), u0 = 183:184)}}})
@ Enzyme ~/.julia/packages/Enzyme/aViNX/src/sugar.jl:258
[14] top-level scope
@ ~/DropBounce/src/MWE.jl:115
in expression starting at /home/vleon/DropBounce/src/MWE.jl:115
And here is the current version of the code I’m running:
println("Load libraries")
using OrdinaryDiffEq, DifferentialEquations, ComponentArrays
using Enzyme, Lux
using Zygote, SciMLSensitivity
using StableRNGs
using Optimization, OptimizationOptimisers
using Plots
# Seeded RNG handed to `Lux.setup` below.
rng = StableRNG(1111)

println("Setup NODE")

# Network dimensions.
in_size = 4
# NOTE: with Zygote.gradient this script works when layer_size = 5, but Julia crashes
# during type inference inside Zygote.gradient() when layer_size = 10.
layer_size = 10

# Three dense layers: two tanh layers of width `layer_size`, then a 2-output layer.
const sc = Lux.Chain(
    Lux.Dense(in_size, layer_size, tanh),
    Lux.Dense(layer_size, layer_size, tanh),
    Lux.Dense(layer_size, 2),
)

# Initial parameters and state variables of the model; the state is also bound
# to the constant `_st`.
p_nn, st = Lux.setup(rng, sc)
const _st = st
"""
    NODE!(du, u, p, t)

In-place right-hand side of the neural ODE.

Evaluates the network `sc` on the concatenation of `u` with the module-level
global `vs`, using the parameters `p` and the state `_st`, then writes
`du[1] = u[2] + NN[1]` and `du[2] = NN[2]`. The time `t` is not used.

Returns `nothing`; the result is delivered only through `du`.
"""
function NODE!(du, u, p, t)
    # Only the first element of what `sc` returns is used here.
    NN = sc([u; vs], p, _st)[1]
    du[1] = u[2] + NN[1]
    du[2] = NN[2]
    # Return `nothing` explicitly instead of leaking the value of the last
    # assignment, so the in-place function has a fixed `Nothing` return.
    return nothing
end
println("Test NODE")
u0_test = [1.0, 2.0]
# `vs` is read as a global inside `NODE!`. Binding it with `const` gives that
# function a concretely typed global instead of an untyped one.
const vs = [3.0, 4.0]
theta = ComponentArray(p = p_nn, u0 = u0_test)
sc([theta.u0; vs], theta.p, _st)

println("Check if ODE solves")
# Save times; read as a global inside `predict_neuralode`, hence `const`.
# NOTE(review): the Enzyme error reported for this script is raised while the
# keyword NamedTuple containing `saveat` is built; a typed `ts` may help there,
# but this has not been verified.
const ts = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
prob = ODEProblem(NODE!, theta.u0, (0.0, 1.0))
pred = solve(prob, Tsit5(), u0 = theta.u0, p = theta.p, saveat = ts)[1, :]
# Random target data, drawn from the script's seeded `rng` (rather than the
# global RNG) so the loss is reproducible; read as a global in `loss_neuralode`.
const R = rand(rng, length(ts))
"""
    predict_neuralode(theta)

Solve the global `prob` with `Tsit5()`, taking the initial condition from
`theta.u0` and the parameters from `theta.p`, saving at the global `ts`.
Return the solution converted with `Array`.
"""
function predict_neuralode(theta)
    sol = solve(prob, Tsit5(); u0 = theta.u0, p = theta.p, saveat = ts)
    return Array(sol)
end
"""
    loss_neuralode(theta)

Return the sum of squared differences between the global `R` and the first row
of `predict_neuralode(theta)`.
"""
function loss_neuralode(theta)
    first_row = predict_neuralode(theta)[1, :]
    return sum(abs2, R .- first_row)
end
# Evaluate the loss once before differentiating it.
loss_neuralode(theta)

println("Check if works with Enzyme autodiff")
Enzyme.gradient(Reverse, loss_neuralode, theta)

# Testing Enzyme autodiff based on: https://docs.sciml.ai/SciMLSensitivity/dev/faq/
prob = ODEProblem(NODE!, theta.u0, (0.0, 1.0), theta.p)
u0 = prob.u0
p = prob.p
tmp2 = Enzyme.make_zero(p)
t = prob.tspan[1]
du = zero(u0)

# Use the problem's function directly when it is in-place; otherwise wrap the
# out-of-place function so that it writes into `du` and returns `nothing`.
_f = if DiffEqBase.isinplace(prob)
    prob.f
else
    (du, u, p, t) -> (du .= prob.f(u, p, t); nothing)
end

# Zero-initialised shadows and work buffers for the reverse-mode call.
_tmp6 = Enzyme.make_zero(_f)
tmp3 = zero(u0)
tmp4 = zero(u0)
ytmp = zero(u0)
tmp1 = zero(u0)

# Error here
Enzyme.autodiff(
    Enzyme.Reverse,
    Enzyme.Duplicated(_f, _tmp6),
    Enzyme.Const,
    Enzyme.Duplicated(tmp3, tmp4),
    Enzyme.Duplicated(ytmp, tmp1),
    Enzyme.Duplicated(p, tmp2),
    Enzyme.Const(t),
)
It would be nice if this also worked with Zygote.gradient for a larger theta, since I would like to use Optimization.jl as shown below. However, I get an "Internal error: during type inference" message, which seems to lead to a segmentation fault (see below).
println("Test zygote")
Zygote.gradient(loss_neuralode, theta)

println("Test optimization")
adtype = Optimization.AutoZygote()
# The objective takes two arguments; only the first is forwarded to
# `loss_neuralode`, the second is ignored.
optf = Optimization.OptimizationFunction((x, _) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, theta)
# History of the loss values passed to `callback`.
losses = Float64[]

# Callback: records the loss `l`, prints a progress line on every 50th recorded
# value, and always returns `false`. `state` is not used.
callback = function (state, l)
    push!(losses, l)
    iszero(rem(length(losses), 50)) &&
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    return false
end
# Iteration limit passed to the optimizer as `maxiters`.
iter1 = 1000
res1 = Optimization.solve(
    optprob, OptimizationOptimisers.Adam(); callback, maxiters = iter1
)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
Running the Zygote code above produces the following error (I have cut it off because the message is very long). At the bottom it reports a segmentation fault.
┌ Warning: Potential performance improvement omitted. EnzymeVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/l0eoK/src/concrete_solve.jl:24
Internal error: during type inference of
overdub(Cassette.Context{FunctionProperties.var"##HasBranchingCtx#Name", Base.Dict{Symbol, Bool}, Nothing, FunctionProperties.var"##PassType#230", Nothing, Cassette.DisableHooks}, Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{Lux.Dense{typeof(Base.tanh), Int64, Int64, Nothing, Nothing, Static.True}, Lux.Dense{typeof(Base.tanh), Int64, Int64, Nothing, Nothing, Static.True}, Lux.Dense{typeof(Base.identity), Int64, Int64, Nothing, Nothing, Static.True}}}, Nothing}, Array{Float64, 1}, ComponentArrays.ComponentArray{Float64, 1, Array{Float64, 1}, Tuple{ComponentArrays.Axis{(layer_1=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=1, stop=50), (weight=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=1, stop=40), nothing, ComponentArrays.ShapedAxis{(10, 4)}}(ax=ComponentArrays.ShapedAxis{(10, 4)}()), bias=Base.UnitRange{Int64}(start=41, stop=50)), ComponentArrays.Axis{(weight=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=1, stop=40), nothing, ComponentArrays.ShapedAxis{(10, 4)}}(ax=ComponentArrays.ShapedAxis{(10, 4)}()), bias=Base.UnitRange{Int64}(start=41, stop=50))}}(ax=ComponentArrays.Axis{(weight=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=1, stop=40), nothing, ComponentArrays.ShapedAxis{(10, 4)}}(ax=ComponentArrays.ShapedAxis{(10, 4)}()), bias=Base.UnitRange{Int64}(start=41, stop=50))}()), layer_2=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=51, stop=160), (weight=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=1, stop=100), nothing, ComponentArrays.ShapedAxis{(10, 10)}}(ax=ComponentArrays.ShapedAxis{(10, 10)}()), bias=Base.UnitRange{Int64}(start=101, stop=110)), ComponentArrays.Axis{(weight=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=1, stop=100), nothing, ComponentArrays.ShapedAxis{(10, 10)}}(ax=ComponentArrays.ShapedAxis{(10, 10)}()), bias=Base.UnitRange{Int64}(start=101, stop=110))}}(ax=ComponentArrays.Axis{(weight=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=1, stop=100), nothing, 
ComponentArrays.ShapedAxis{(10, 10)}}(ax=ComponentArrays.ShapedAxis{(10, 10)}()), bias=Base.UnitRange{Int64}(start=101, stop=110))}()), layer_3=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=161, stop=182), (weight=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=1, stop=20), nothing, ComponentArrays.ShapedAxis{(2, 10)}}(ax=ComponentArrays.ShapedAxis{(2, 10)}()), bias=Base.UnitRange{Int64}(start=21, stop=22)), ComponentArrays.Axis{(weight=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=1, stop=20), nothing, ComponentArrays.ShapedAxis{(2, 10)}}(ax=ComponentArrays.ShapedAxis{(2, 10)}()), bias=Base.UnitRange{Int64}(start=21, stop=22))}}(ax=ComponentArrays.Axis{(weight=ComponentArrays.ViewAxis{Base.UnitRange{Int64}(start=1, stop=20), nothing, ComponentArrays.ShapedAxis{(2, 10)}}(ax=ComponentArrays.ShapedAxis{(2, 10)}()), bias=Base.UnitRange{Int64}(start=21, stop=22))}()))}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}})
Encountered unexpected error in runtime:
MethodError(f=Base.inferencebarrier, args=(typeof(Base.string)(),), world=0x0000000000001735)
jl_method_error_bare at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/gf.c:2254
jl_method_error at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/gf.c:2272
jl_lookup_generic_ at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/gf.c:3106 [inlined]
ijl_apply_generic at /cache/build/tester-amdci5-12/julialang/julia-release-1-dot-11/src/gf.c:3121
macro expansion at ./error.jl:235 [inlined]
renumber_ssa at ./compiler/ssair/slot2ssa.jl:56
#472 at ./compiler/ssair/slot2ssa.jl:62 [inlined]
ssamap at ./compiler/utilities.jl:360
renumber_ssa! at ./compiler/ssair/slot2ssa.jl:62 [inlined]
renumber_ssa! at ./compiler/ssair/slot2ssa.jl:61 [inlined]
construct_ssa! at ./compiler/ssair/slot2ssa.jl:902
slot2reg at ./compiler/optimize.jl:1219 [inlined]
run_passes_ipo_safe at ./compiler/optimize.jl:994
run_passes_ipo_safe at ./compiler/optimize.jl:1009 [inlined]
optimize at ./compiler/optimize.jl:983
jfptr_optimize_42650.1 at /home/vleon/.julia/juliaup/julia-1.11.2+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
finish_nocycle at ./compiler/typeinfer.jl:265
_typeinf at ./compiler/typeinfer.jl:249
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:923
abstract_call_method at ./compiler/abstractinterpretation.jl:660
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:102
abstract_call_known at ./compiler/abstractinterpretation.jl:2200
abstract_call at ./compiler/abstractinterpretation.jl:2282
abstract_apply at ./compiler/abstractinterpretation.jl:1690
abstract_call_known at ./compiler/abstractinterpretation.jl:2102
abstract_call at ./compiler/abstractinterpretation.jl:2282
abstract_call at ./compiler/abstractinterpretation.jl:2275
abstract_call at ./compiler/abstractinterpretation.jl:2420
abstract_eval_call at ./compiler/abstractinterpretation.jl:2435
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2451
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2749
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:3065
typeinf_local at ./compiler/abstractinterpretation.jl:3319
typeinf_nocycle at ./compiler/abstractinterpretation.jl:3401
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:923
abstract_call_method at ./compiler/abstractinterpretation.jl:660
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:102
abstract_call_known at ./compiler/abstractinterpretation.jl:2200
abstract_call at ./compiler/abstractinterpretation.jl:2282
abstract_call at ./compiler/abstractinterpretation.jl:2275
abstract_call at ./compiler/abstractinterpretation.jl:2420
abstract_eval_call at ./compiler/abstractinterpretation.jl:2435
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2451
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2749
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:3065
typeinf_local at ./compiler/abstractinterpretation.jl:3319
typeinf_nocycle at ./compiler/abstractinterpretation.jl:3401
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:923
abstract_call_method at ./compiler/abstractinterpretation.jl:660
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:102
abstract_call_known at ./compiler/abstractinterpretation.jl:2200
abstract_call at ./compiler/abstractinterpretation.jl:2282
abstract_apply at ./compiler/abstractinterpretation.jl:1690
abstract_call_known at ./compiler/abstractinterpretation.jl:2102
abstract_call at ./compiler/abstractinterpretation.jl:2282
abstract_call at ./compiler/abstractinterpretation.jl:2275
abstract_call at ./compiler/abstractinterpretation.jl:2420
abstract_eval_call at ./compiler/abstractinterpretation.jl:2435
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2451
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2749
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:3065
typeinf_local at ./compiler/abstractinterpretation.jl:3319
typeinf_nocycle at ./compiler/abstractinterpretation.jl:3401
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:923
abstract_call_method at ./compiler/abstractinterpretation.jl:660
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:102
abstract_call_known at ./compiler/abstractinterpretation.jl:2200
abstract_call at ./compiler/abstractinterpretation.jl:2282
abstract_call at ./compiler/abstractinterpretation.jl:2275
abstract_call at ./compiler/abstractinterpretation.jl:2420
abstract_eval_call at ./compiler/abstractinterpretation.jl:2435
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2451
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2749
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:3065
typeinf_local at ./compiler/abstractinterpretation.jl:3319
typeinf_nocycle at ./compiler/abstractinterpretation.jl:3401
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:923
abstract_call_method at ./compiler/abstractinterpretation.jl:660
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:102
abstract_call_known at ./compiler/abstractinterpretation.jl:2200
abstract_call at ./compiler/abstractinterpretation.jl:2282
abstract_apply at ./compiler/abstractinterpretation.jl:1690
abstract_call_known at ./compiler/abstractinterpretation.jl:2102
abstract_call at ./compiler/abstractinterpretation.jl:2282
abstract_call at ./compiler/abstractinterpretation.jl:2275
abstract_call at ./compiler/abstractinterpretation.jl:2420
abstract_eval_call at ./compiler/abstractinterpretation.jl:2435
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2451
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2749
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:3065
typeinf_local at ./compiler/abstractinterpretation.jl:3319
typeinf_nocycle at ./compiler/abstractinterpretation.jl:3401
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:923
abstract_call_method at ./compiler/abstractinterpretation.jl:660
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:102
abstract_call_known at ./compiler/abstractinterpretation.jl:2200
abstract_call at ./compiler/abstractinterpretation.jl:2282
abstract_call at ./compiler/abstractinterpretation.jl:2275
abstract_call at ./compiler/abstractinterpretation.jl:2420
abstract_eval_call at ./compiler/abstractinterpretation.jl:2435
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2451
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2749
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:3065
typeinf_local at ./compiler/abstractinterpretation.jl:3319
typeinf_nocycle at ./compiler/abstractinterpretation.jl:3401