Hi,
I’m trying to fit a circuit model using a universal differential equation (UDE), following this example from the SciML docs.
# Fully connected network: 3 inputs -> three hidden layers of 32 ReLU units -> 2 outputs.
# (It is evaluated on `[Q, I, t]` and its two outputs correct dQ/dt and dI/dt.)
U = Lux.Chain(
    Lux.Dense(3, 32, Lux.relu),
    Lux.Dense(32, 32, Lux.relu),
    Lux.Dense(32, 32, Lux.relu),
    Lux.Dense(32, 2),
)
# Initial (untrained) parameters and layer states; `rng` is defined elsewhere.
p_nn, st_nn = Lux.setup(rng, U)
"""
    ude_dynamics!(du, u, p_nn, t, p_ode)

In-place right-hand side of the hybrid (UDE) circuit model with state `u = [Q, I]`.

The Lux network `U` (with layer states `st_nn`) is evaluated on `[Q, I, t]` using the
parameters `p_nn`. Its two outputs are added as corrections to `du[1]` and `du[2]`.
`p_ode` is unused. The constants `L_mod`, `n_mod`, `L̃`, `C_mod`, `R_mod` and `R_line`
are read from global scope.

Performance note: `U`, `st_nn` and the constants are non-`const` globals. Reading them
here makes the function type-unstable, which likely contributes to slow training.
Consider passing them in via a closure or a parameter object.
"""
function ude_dynamics!(du, u, p_nn, t, p_ode)
    Q, I = u[1], u[2]
    # Lux models return `(output, state)`; only the output is needed here.
    Û = U([Q, I, t], p_nn, st_nn)[1]
    du[1] = -I + Û[1]
    # NOTE(review): this parses as `(1 / L_mod) - n_mod * L̃`. If the intent was
    # `1 / (L_mod - n_mod * L̃)`, the parentheses are misplaced. Confirm against the model.
    du[2] = (1 / L_mod - n_mod * L̃) *
            (Q / C_mod - (R_mod + n_mod * R_line) * I) + Û[2]
    # In-place ODE functions communicate only through `du`; return nothing explicitly.
    return nothing
end
# ODE right-hand side in the `f!(du, u, p, t)` form that `ODEProblem` expects.
# The physical constants live in global scope, so `nothing` is forwarded as `p_ode`.
function nn_dynamics!(du, u, p, t)
    return ude_dynamics!(du, u, p, t, nothing)
end

# Initial state `[Q, I]`: Q = n_mod * V₀ * C_mod, I = 0.
u0 = [n_mod * V₀ * C_mod, 0.0]
# The problem spans the training time points and starts from the initial NN parameters.
prob_ude = ODEProblem(nn_dynamics!, u0, (trg_ode[1], trg_ode[end]), p_nn)
# Sanity-check solve with the initial (untrained) network parameters.
sol_ude = solve(prob_ude, Tsit5(); reltol = 1e-6, abstol = 1e-6)
"""
    predict(θ, u0=u0, trg=trg_ode; alg=Tsit5(), sensealg=nothing)

Solve the UDE with network parameters `θ` from initial state `u0` over the time points
`trg`, and return the solution saved at `trg` as an array (one column per time point).

# Keywords
- `alg`: ODE solver (default `Tsit5()`).
- `sensealg`: sensitivity algorithm used when the solve is differentiated. The default,
  `nothing`, keeps SciMLSensitivity's automatic choice, which tries EnzymeVJP first.
  Passing an explicit choice, e.g. `InterpolatingAdjoint(autojacvec = ZygoteVJP())`,
  skips that attempt. This may avoid the Enzyme crash/warning seen during training.

If the solver terminates early, the returned array has fewer columns than `length(trg)`.
"""
function predict(θ, u0=u0, trg=trg_ode; alg=Tsit5(), sensealg=nothing)
    _prob = remake(prob_ude; u0 = u0, tspan = (trg[1], trg[end]), p = θ)
    # Save at `trg` (not the global `trg_ode`) so output times match the requested span.
    sol = solve(_prob, alg;
                saveat = trg, reltol = 1e-6, abstol = 1e-6, sensealg = sensealg)
    return Array(sol)
end
# Reference current at the training time points, used as the fitting target.
true_I = load_true_I(trg_ode)

"""
    loss(θ)

Mean squared difference between the predicted current (row 2 of `predict(θ)`) and
`true_I`.
"""
function loss(θ)
    predicted_I = predict(θ)[2, :]
    residual = predicted_I - true_I
    return mean(abs2, residual)
end
# Loss history: one entry is appended per optimizer iteration.
losses = Float64[]

# Optimization.jl callback: record the loss, report it every 50 iterations, and
# return `false` so the optimization continues (`true` would stop it).
callback = function (p, l)
    push!(losses, l)
    if iszero(rem(length(losses), 50))
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end
# Differentiate the loss with Zygote.
adtype = Optimization.AutoZygote()
# Optimization passes (u, p); the problem-parameter argument is not used by `loss`.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
# Optimize over the NN parameters flattened into a ComponentVector.
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p_nn))
res1 = Optimization.solve(optprob, ADAM(); callback = callback, maxiters = 5000)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
I’m having difficulties running this code. Most of the time I get a segfault:
julia> res1 = Optimization.solve(optprob, ADAM(), callback = callback, maxiters = 10)
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/cy24l/src/utils.jl:56
GC error (probable corruption) :
Allocations: 288657556 (Pool: 288579160; Big: 78396); GC: 414
!!! ERROR in jl_ -- ABORTING !!!
0x165000000: Queued root: 0x2f9b18920 :: 0x120d301c0 (bits: 3)
of type Core.MethodInstance
0x165000018: Queued root: 0x172dd98d0 :: 0x120d2de10 (bits: 3)
of type Array{Any, 1}
0x165000030: Queued root: 0x2bab6d990 :: 0x120d2de10 (bits: 3)
of type Array{Any, 1}
0x165000048: Queued root: 0x2a628e410 :: 0x120d2de10 (bits: 7)
of type Array{Any, 1}
0x165000060: Queued root: 0x10ea39b30 :: 0x120d308d0 (bits: 3)
of type Core.SimpleVector
0x165000078: Queued root: 0x10e129c90 :: 0x120d2de10 (bits: 3)
of type Array{Any, 1}
0x165000090: Queued root: 0x1733ccc90 :: 0x120d308d0 (bits: 3)
of type Core.SimpleVector
0x1650000a8: Queued root: 0x2f1c8cec0 :: 0x120d301c0 (bits: 3)
of type Core.MethodInstance
0x1650000c0: Queued root: 0x2e6b00a30 :: 0x120d2de10 (bits: 3)
of type Array{Any, 1}
0x1650000d8: Queued root: 0x10c020b80 :: 0x120d2de10 (bits: 3)
of type Array{Any, 1}
0x1650000f0: Queued root: 0x2f0edb460 :: 0x120d2de10 (bits: 3)
of type Array{Any, 1}
0x165000108: Queued root: 0x2f0ed9c00 :: 0x120d2de10 (bits: 3)
of type Array{Any, 1}
0x165000120: Queued root: 0x2b9ec8cd0 :: 0x120d2de10 (bits: 3)
of type Array{Any, 1}
0x165000138: Queued root: 0x17240d560 :: 0x120d308d0 (bits: 3)
of type Core.SimpleVector
0x165000150: Queued root: 0x111833250 :: 0x120d2de10 (bits: 7)
of type Array{Any, 1}
0x165000168: Queued root: 0x111832fc0 :: 0x120d2de10 (bits: 7)
of type Array{Any, 1}
0x165000180: Queued root: 0x10c52f450 :: 0x120d2de10 (bits: 3)
of type Array{Any, 1}
[...]
0x165001ed8: Queued root: 0x10b32d2d0 :: 0x120d2e000 (bits: 3)
of type Task
0x165001ef0: r-- Stack frame 0x16d020700 -- 66 of 134 (direct)
0x165001f18: `- Object (16bit) 0x2bbfa6710 :: 0x2e7b9d6d1 -- [2, 96)
of type Enzyme.Compiler.Tape{NamedTuple{(Symbol("1"), Symbol("2")), Tuple{NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4"), Symbol("5"), Symbol("6"), Symbol("7"), Symbol("8"), Symbol("9"), Symbol("10"), Symbol("11")), Tuple{NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4"), Symbol("5"), Symbol("6"), Symbol("7"), Symbol("8"), Symbol("9"), Symbol("10"), Symbol("11"), Symbol("12"), Symbol("13"), Symbol("14"), Symbol("15"), Symbol("16"), Symbol("17"), Symbol("18"), Symbol("19"), Symbol("20"), Symbol("21"), Symbol("22"), Symbol("23"), Symbol("24"), Symbol("25")), Tuple{Any, Any, NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4"), Symbol("5"), Symbol("6"), Symbol("7"), Symbol("8"), Symbol("9"), Symbol("10")), Tuple{NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4"), Symbol("5"), Symbol("6"), Symbol("7"), Symbol("8")), Tuple{Core.LLVMPtr{UInt8, 0}, Core.LLVMPtr{Float64, 0}, Core.LLVMPtr{Float64, 0}, UInt64, Core.LLVMPtr{UInt8, 0}, UInt64, Core.LLVMPtr{Float64, 0}, Core.LLVMPtr{Float64, 0}}}, Any, Any, Any, Any, Any, Any, Any, Any, UInt64}}, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, UInt64, Bool, Bool, Bool, UInt64, Bool, Bool, Core.LLVMPtr{Bool, 0}, Core.LLVMPtr{Bool, 0}}}, NamedTuple{(Symbol("1"),), Tuple{Any}}, NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4"), Symbol("5"), Symbol("6"), Symbol("7"), Symbol("8"), Symbol("9"), Symbol("10"), Symbol("11"), Symbol("12"), Symbol("13"), Symbol("14"), Symbol("15"), Symbol("16"), Symbol("17"), Symbol("18"), Symbol("19"), Symbol("20"), Symbol("21"), Symbol("22"), Symbol("23"), Symbol("24"), Symbol("25")), Tuple{Any, Any, NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4"), Symbol("5"), Symbol("6"), Symbol("7"), Symbol("8"), Symbol("9"), Symbol("10")), Tuple{NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4"), Symbol("5"), Symbol("6"), Symbol("7"), Symbol("8")), Tuple{Core.LLVMPtr{UInt8, 0}, Core.LLVMPtr{Float64, 0}, 
Core.LLVMPtr{Float64, 0}, UInt64, Core.LLVMPtr{UInt8, 0}, UInt64, Core.LLVMPtr{Float64, 0}, Core.LLVMPtr{Float64, 0}}}, Any, Any, Any, Any, Any, Any, Any, Any, UInt64}}, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, UInt64, Bool, Bool, Bool, UInt64, Bool, Bool, Core.LLVMPtr{Bool, 0}, Core.LLVMPtr{Bool, 0}}}, NamedTuple{(Symbol("1"),), Tuple{Any}}, NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4"), Symbol("5"), Symbol("6"), Symbol("7"), Symbol("8"), Symbol("9"), Symbol("10"), Symbol("11"), Symbol("12"), Symbol("13"), Symbol("14"), Symbol("15"), Symbol("16"), Symbol("17"), Symbol("18"), Symbol("19"), Symbol("20"), Symbol("21"), Symbol("22"), Symbol("23"), Symbol("24"), Symbol("25")), Tuple{Any, Any, NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4"), Symbol("5"), Symbol("6"), Symbol("7"), Symbol("8"), Symbol("9"), Symbol("10")), Tuple{NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4"), Symbol("5"), Symbol("6"), Symbol("7"), Symbol("8")), Tuple{Core.LLVMPtr{UInt8, 0}, Core.LLVMPtr{Float64, 0}, Core.LLVMPtr{Float64, 0}, UInt64, Core.LLVMPtr{UInt8, 0}, UInt64, Core.LLVMPtr{Float64, 0}, Core.LLVMPtr{Float64, 0}}}, Any, Any, Any, Any, Any, Any, Any, Any, UInt64}}, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, UInt64, Bool, Bool, Bool, UInt64, Bool, Bool, Core.LLVMPtr{Bool, 0}, Core.LLVMPtr{Bool, 0}}}, NamedTuple{(Symbol("1"),), Tuple{Any}}, NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4"), Symbol("5"), Symbol("6"), Symbol("7"), Symbol("8"), Symbol("9"), Symbol("10"), Symbol("11"), Symbol("12"), Symbol("13"), Symbol("14"), Symbol("15")), Tuple{Any, Any, NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4"), Symbol("5"), Symbol("6"), Symbol("7"), Symbol("8"), Symbol("9"), Symbol("10")), Tuple{NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4"), Symbol("5"), Symbol("6"), Symbol("7"), Symbol("8")), Tuple{Core.LLVMPtr{UInt8, 0}, Core.LLVMPtr{Float64, 0}, 
Core.LLVMPtr{Float64, 0}, UInt64, Core.LLVMPtr{UInt8, 0}, UInt64, Core.LLVMPtr{Float64, 0}, Core.LLVMPtr{Float64, 0}}}, Any, Any, Any, Any, Any, Any, Any, Any, UInt64}}, Any, Any, Any, Any, Any, Any, Any, Any, Any, Bool, Bool, Bool}}, NamedTuple{(Symbol("1"),), Tuple{Any}}, Any, Any, Any}}, NamedTuple{(Symbol("1"),), Tuple{Any}}}}, Nothing, Tuple{Array{Float64, 1}, NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4), NTuple{4, NamedTuple{(), Tuple{}}}}}}
[71577] signal (6): Abort trap: 6
in expression starting at REPL[21]:1
__pthread_kill at /usr/lib/system/libsystem_kernel.dylib (unknown line)
Allocations: 288657556 (Pool: 288579160; Big: 78396); GC: 414
[1] 71577 abort /Applications/Julia-1.9.app/Contents/Resources/julia/bin/julia --project=.
This happens on an Apple M2 machine with Julia 1.9.
It does run on a different (AMD-based) machine with Julia 1.9, but only when the script is loaded from the REPL. In that case, the following warning appears at every optimizer step:
┌ Warning: EnzymeVJP tried and failed in the automated AD choice algorithm with the following error. (To turn off this printing, add `verbose = false` to the `solve` call)
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/E8w3Z/src/concrete_solve.jl:21
MethodError: no method matching asprogress(::Base.CoreLogging.LogLevel, ::String, ::Module, ::Symbol, ::Symbol, ::String, ::Int64)
The applicable method may be too new: running in world age 33663, while current world is 34006.
Closest candidates are:
asprogress(::Any, ::Any, ::Any, ::Any, ::Any, ::Any, ::Any; progress, kwargs...) (method too new to be called from this world context.)
@ ProgressLogging ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:156
asprogress(::Any, ::ProgressLogging.Progress, ::Any...; _...) (method too new to be called from this world context.)
@ ProgressLogging ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:155
asprogress(::Any, ::ProgressLogging.ProgressString, ::Any...; _...) (method too new to be called from this world context.)
@ ProgressLogging ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:200
Also, performance is quite slow: it runs on only a single core.
Any ideas on how to get this code running, and how to improve its performance (multithreading or GPU?), would be appreciated.