Hello, I am trying to perform AD in a neural ODE setup using a `SplitODEProblem`. The problem is a PDE discretization of the diffusion equation, which is rather stiff for explicit timesteppers. The neural network appears in the `rhs` part of the `SplitODEProblem` to capture missing fluxes. Here is an MWE:
using LinearAlgebra
using DiffEqBase
using OrdinaryDiffEq: SplitODEProblem, solve, IMEXEuler
import SciMLBase
using Lux, OptimizationOptimisers, ComponentArrays, Random, SciMLSensitivity
using Optimization
using Printf
# Grid and neural-network setup. Every size below is derived from `n`, so changing the
# resolution only requires editing this one line.
n = 32                       # number of cell centres
zC = collect(1.0:1.0:n)      # unit-spaced cell-centre coordinates (previously hard-coded to 32)
Δz = zC[2] - zC[1]
rng = Random.default_rng()
# Input size n + 1: `rhs` feeds the network the n-element state with the scalar
# `params.f` appended.
# Output size n - 1: `rhs` pads the output with `params.bottom` and `params.top` to obtain
# the n + 1 face fluxes that `Dᶜ(n, Δz)` (an n×(n+1) matrix) acts on.
NN = Chain(Dense(n + 1, 10, leakyrelu), Dense(10, n - 1))
ps, st = Lux.setup(rng, NN)
ps = ps |> ComponentArray .|> Float64
# All weights and biases zero ⇒ the network initially predicts zero interior flux.
ps .*= 0
"""
    Dᶜ(N, Δ)

Build the `N×(N+1)` difference matrix that maps `N + 1` face values to `N` cell values.
Row `k` holds `-1/Δ` in column `k` and `+1/Δ` in column `k + 1`; every other entry is zero.
"""
function Dᶜ(N, Δ)
    op = zeros(N, N + 1)
    op[diagind(op, 0)] .= -1.0   # (k, k)   for k = 1:N
    op[diagind(op, 1)] .= 1.0    # (k, k+1) for k = 1:N
    op .*= 1 / Δ
    return op
end
"""
    Dᶠ(N, Δ)

Build the `(N+1)×N` difference matrix that maps `N` cell values to `N + 1` face values.
For each interior face `k = 2:N`, row `k` holds `-1/Δ` in column `k - 1` and `+1/Δ` in
column `k`. The first and last rows (boundary faces) are left entirely zero.
"""
function Dᶠ(N, Δ)
    op = zeros(N + 1, N)
    for face in 2:N
        op[face, face-1:face] .= (-1.0, 1.0)
    end
    op .*= 1 / Δ
    return op
end
# Discrete difference operators on the uniform grid.
D_face = Dᶠ(n, Δz)                        # cells → faces, (n+1)×n
D_center = Dᶜ(n, Δz)                      # faces → cells, n×(n+1)
u0 = collect(range(0, 1; length = n))     # linear initial profile from 0 to 1
"""
    compute_diffusivity(∂u∂z)

Return the face diffusivity for the vertical gradient `∂u∂z`.
The result is `0.2` when the gradient is negative and `1e-5` otherwise, including zero and `NaN`.
"""
compute_diffusivity(∂u∂z) = ∂u∂z < 0 ? 0.2 : 1e-5
# Diffusion matrix for the initial state. Each face's diffusivity scales the corresponding
# row of the face-gradient operator before the face → cell difference is applied.
α_initial = compute_diffusivity.(D_face * u0)
D = Tridiagonal(D_center * (Diagonal(α_initial) * D_face))
# Settings read directly by `rhs` as a global (they are not passed through `p`).
params = (; top = 3e-3, bottom = 0.0, NN, st, f = 1e-4)
"""
    rhs(u, p, t)

Second (`f2`) component of the `SplitODEProblem`: the negative discrete divergence of a
face flux.
- The interior face fluxes come from the neural network `params.NN` with parameters `p.ps`.
- The boundary faces use the fixed values `params.bottom` and `params.top`.
- `t` is unused.
"""
function rhs(u, p, t)
    # Network input: the state with the scalar forcing `params.f` appended.
    x′ = vcat(u, params.f)
    # A Lux model returns `(output, state)`; only the output is needed here.
    interior_flux = first(params.NN(x′, p.ps, params.st))
    residual_flux = vcat(params.bottom, interior_flux, params.top)
    # Negate the product, not the matrix. `-D_center * x` parses as `(-D_center) * x`,
    # which allocated a negated copy of `D_center` on every RHS evaluation.
    return -(D_center * residual_flux)
end
"""
    update_diffusivity(A, u, p, t)

Rebuild the tridiagonal diffusion matrix for the current state `u`:
face gradients → face diffusivities → `D_center * (α .* D_face)`.

The arguments follow the `update_func(A, u, p, t)` form passed to `MatrixOperator`.
Only `u` is used; `A`, `p` and `t` are ignored.
"""
function update_diffusivity(A, u, p, t)
    face_gradient = D_face * u
    face_diffusivity = compute_diffusivity.(face_gradient)
    weighted_gradient_op = face_diffusivity .* D_face
    return Tridiagonal(D_center * weighted_gradient_op)
end
# First (`f1`) component of the split problem: the diffusion matrix as an operator.
# It is rebuilt from the state through `update_diffusivity(A, u, p, t)`.
D2 = SciMLBase.MatrixOperator(D, update_func=update_diffusivity)
ps_training = ComponentArray(;ps)
times = collect(0:0.1:1)
tspan = (times[1], times[end])
prob = SplitODEProblem(D2, rhs, u0, tspan, ps_training)
alg = IMEXEuler()
println("Solving...")
sol = solve(prob, alg, dt = 1e-3, saveat = times)
# Synthetic target with one column per save time. The shape follows `n` and `times`
# rather than the previously hard-coded (32, 11), so `loss` stays consistent when the
# grid or the save times change.
truth = rand(n, length(times))
"""
    loss(p)

Return the sum of squared differences between the global `truth` array and the IMEX
solution saved at `times`, computed with network parameters `p`.
"""
function loss(p)
    # Build a fresh problem so that the solve, and therefore the gradient, depends on `p`.
    # NOTE(review): the stack traces in this post show Zygote/ReverseDiff reaching
    # SciMLSensitivity's default InterpolatingAdjoint. Its backward solve runs
    # `IMEXEuler()` (shown as `SBDF`) on a plain `ODEFunction` and fails on the missing
    # `f1` field. An explicit `sensealg` keyword may be needed here; confirm against the
    # SciMLSensitivity docs.
    split_prob = SplitODEProblem(D2, rhs, u0, tspan, p)
    prediction = Array(solve(split_prob, alg; dt = 1e-3, saveat = times))
    return sum(abs2, prediction - truth)
end
iter = 0
maxiter = 3
# Print the current loss on every optimizer step; always returns `false`.
callback = (p, l) -> begin
    @printf("loss total %6.10e\n", l)
    false
end
# One forward evaluation before training starts.
loss(ps_training)
adtype = Optimization.AutoZygote()
# The optimizer passes (parameters, hyperparameters); only the parameters matter here.
optf = Optimization.OptimizationFunction((x, _) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ps_training)
res = Optimization.solve(optprob, OptimizationOptimisers.Adam(1e-5);
    callback = callback, maxiters = maxiter)
When I use `adtype = Optimization.AutoForwardDiff()`, the code runs without issues. However, when I use `adtype = Optimization.AutoReverseDiff()` or `adtype = Optimization.AutoZygote()`, the following error is thrown:
ERROR: type ODEFunction has no field f1
Stacktrace:
getproperty at Base.jl
initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{…}, cache::OrdinaryDiffEq.SBDFCache{…}) at bdf_perform_step.jl
__init(prob::ODEProblem{…}, alg::OrdinaryDiffEq.SBDF{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Vector{…}, tstops::Vector{…}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::CallbackSet{…}, dense::Bool, calck::Bool, dt::Float64, dtmin::Float64, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Int64, abstol::Float64, reltol::Float64, qmin::Int64, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Int64, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::Base.Pairs{…}) at solve.jl
__init at solve.jl (repeats 5 times)
#__solve#757 at solve.jl
__solve at solve.jl
solve_call(_prob::ODEProblem{…}, args::OrdinaryDiffEq.SBDF{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::Base.Pairs{…}) at solve.jl
kwcall(::@NamedTuple{…}, ::typeof(DiffEqBase.solve_call), _prob::ODEProblem{…}, args::OrdinaryDiffEq.SBDF{…}) at solve.jl
solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::ComponentVector{…}, args::OrdinaryDiffEq.SBDF{…}; kwargs::Base.Pairs{…}) at solve.jl
solve_up at solve.jl
solve(prob::ODEProblem{…}, args::OrdinaryDiffEq.SBDF{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::Base.Pairs{…}) at solve.jl
kwcall(::@NamedTuple{…}, ::typeof(solve), prob::ODEProblem{…}, args::OrdinaryDiffEq.SBDF{…}) at solve.jl
_adjoint_sensitivities(sol::ODESolution{…}, sensealg::InterpolatingAdjoint{…}, alg::OrdinaryDiffEq.SBDF{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Nothing, callback::Nothing, kwargs::Base.Pairs{…}) at sensitivity_interface.jl
_adjoint_sensitivities at sensitivity_interface.jl
#adjoint_sensitivities#63 at sensitivity_interface.jl
kwcall(::@NamedTuple{…}, ::typeof(adjoint_sensitivities), sol::ODESolution{…}, args::OrdinaryDiffEq.SBDF{…}) at sensitivity_interface.jl
(::SciMLSensitivity.var"#adjoint_sensitivity_backpass#307"{…})(Δ::ODESolution{…}) at concrete_solve.jl
ZBack at chainrules.jl
(::Zygote.var"#kw_zpullback#53"{…})(dy::ODESolution{…}) at chainrules.jl
#291 at lib.jl
(::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…}) at adjoint.jl
#solve#41 at solve.jl
(::Zygote.Pullback{…})(Δ::ODESolution{…}) at interface2.jl
#291 at lib.jl
(::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…}) at adjoint.jl
solve at solve.jl
(::Zygote.Pullback{…})(Δ::ODESolution{…}) at interface2.jl
loss at IMEX_SplitODE_MWE.jl
(::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64) at interface2.jl
#11 at IMEX_SplitODE_MWE.jl
#291 at lib.jl
#2169#back at adjoint.jl
OptimizationFunction at scimlfunctions.jl
#291 at lib.jl
(::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{…}, Zygote.Pullback{…}}})(Δ::Float64) at adjoint.jl
#37 at OptimizationZygoteExt.jl
(::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64) at interface2.jl
#291 at lib.jl
#2169#back at adjoint.jl
#39 at OptimizationZygoteExt.jl
(::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64) at interface2.jl
(::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64) at interface.jl
gradient(f::Function, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{…}}}) at interface.jl
(::OptimizationZygoteExt.var"#38#56"{…})(::ComponentVector{…}, ::ComponentVector{…}) at OptimizationZygoteExt.jl
macro expansion at OptimizationOptimisers.jl
macro expansion at utils.jl
__solve(cache::OptimizationCache{…}) at OptimizationOptimisers.jl
solve!(cache::OptimizationCache{…}) at solve.jl
solve(::OptimizationProblem{…}, ::Adam; kwargs::Base.Pairs{…}) at solve.jl
kwcall(::@NamedTuple{…}, ::typeof(solve), ::OptimizationProblem{…}, ::Adam) at solve.jl
top-level scope at IMEX_SplitODE_MWE.jl
Some type information was truncated. Use show(err) to see complete types.
ERROR: type ODEFunction has no field f1
Stacktrace:
getproperty at Base.jl
initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{…}, cache::OrdinaryDiffEq.SBDFCache{…}) at bdf_perform_step.jl
__init(prob::ODEProblem{…}, alg::OrdinaryDiffEq.SBDF{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Vector{…}, tstops::Vector{…}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::CallbackSet{…}, dense::Bool, calck::Bool, dt::Float64, dtmin::Float64, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Int64, abstol::Float64, reltol::Float64, qmin::Int64, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Int64, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::Base.Pairs{…}) at solve.jl
__init at solve.jl (repeats 5 times)
#__solve#757 at solve.jl
__solve at solve.jl
solve_call(_prob::ODEProblem{…}, args::OrdinaryDiffEq.SBDF{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::Base.Pairs{…}) at solve.jl
kwcall(::@NamedTuple{…}, ::typeof(DiffEqBase.solve_call), _prob::ODEProblem{…}, args::OrdinaryDiffEq.SBDF{…}) at solve.jl
solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::ComponentVector{…}, args::OrdinaryDiffEq.SBDF{…}; kwargs::Base.Pairs{…}) at solve.jl
solve_up at solve.jl
solve(prob::ODEProblem{…}, args::OrdinaryDiffEq.SBDF{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::Base.Pairs{…}) at solve.jl
kwcall(::@NamedTuple{…}, ::typeof(solve), prob::ODEProblem{…}, args::OrdinaryDiffEq.SBDF{…}) at solve.jl
_adjoint_sensitivities(sol::ODESolution{…}, sensealg::InterpolatingAdjoint{…}, alg::OrdinaryDiffEq.SBDF{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Nothing, callback::Nothing, kwargs::Base.Pairs{…}) at sensitivity_interface.jl
_adjoint_sensitivities at sensitivity_interface.jl
#adjoint_sensitivities#63 at sensitivity_interface.jl
kwcall(::@NamedTuple{…}, ::typeof(adjoint_sensitivities), sol::ODESolution{…}, args::OrdinaryDiffEq.SBDF{…}) at sensitivity_interface.jl
(::SciMLSensitivity.var"#adjoint_sensitivity_backpass#307"{…})(Δ::ODESolution{…}) at concrete_solve.jl
ZBack at chainrules.jl
(::Zygote.var"#kw_zpullback#53"{…})(dy::ODESolution{…}) at chainrules.jl
#291 at lib.jl
(::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…}) at adjoint.jl
#solve#41 at solve.jl
(::Zygote.Pullback{…})(Δ::ODESolution{…}) at interface2.jl
#291 at lib.jl
(::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…}) at adjoint.jl
solve at solve.jl
(::Zygote.Pullback{…})(Δ::ODESolution{…}) at interface2.jl
loss at IMEX_SplitODE_MWE.jl
(::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64) at interface2.jl
#11 at IMEX_SplitODE_MWE.jl
#291 at lib.jl
#2169#back at adjoint.jl
OptimizationFunction at scimlfunctions.jl
#291 at lib.jl
(::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{…}, Zygote.Pullback{…}}})(Δ::Float64) at adjoint.jl
#37 at OptimizationZygoteExt.jl
(::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64) at interface2.jl
#291 at lib.jl
#2169#back at adjoint.jl
#39 at OptimizationZygoteExt.jl
(::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64) at interface2.jl
(::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64) at interface.jl
gradient(f::Function, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{…}}}) at interface.jl
(::OptimizationZygoteExt.var"#38#56"{…})(::ComponentVector{…}, ::ComponentVector{…}) at OptimizationZygoteExt.jl
macro expansion at OptimizationOptimisers.jl
macro expansion at utils.jl
__solve(cache::OptimizationCache{…}) at OptimizationOptimisers.jl
solve!(cache::OptimizationCache{…}) at solve.jl
solve(::OptimizationProblem{…}, ::Adam; kwargs::Base.Pairs{…}) at solve.jl
kwcall(::@NamedTuple{…}, ::typeof(solve), ::OptimizationProblem{…}, ::Adam) at solve.jl
top-level scope at IMEX_SplitODE_MWE.jl
When I use `adtype = AutoEnzyme()`, the following error is thrown:
ERROR: Enzyme execution failed.
Mismatched activity for: store {} addrspace(10)* %.fca.0.extract3, {} addrspace(10)* addrspace(10)* %.sroa.04.0..sroa_cast, align 8, !dbg !15, !alias.scope !17, !noalias !21 const val: %.fca.0.extract3 = extractvalue { {} addrspace(10)* } %1, 0, !dbg !8
Type tree: {[-1]:Pointer, [-1,-1]:Pointer, [-1,-1,0]:Pointer, [-1,-1,0,-1]:Float@double, [-1,-1,8]:Integer, [-1,-1,9]:Integer, [-1,-1,10]:Integer, [-1,-1,11]:Integer, [-1,-1,12]:Integer, [-1,-1,13]:Integer, [-1,-1,14]:Integer, [-1,-1,15]:Integer, [-1,-1,16]:Integer, [-1,-1,17]:Integer, [-1,-1,18]:Integer, [-1,-1,19]:Integer, [-1,-1,20]:Integer, [-1,-1,21]:Integer, [-1,-1,22]:Integer, [-1,-1,23]:Integer, [-1,-1,24]:Integer, [-1,-1,25]:Integer, [-1,-1,26]:Integer, [-1,-1,27]:Integer, [-1,-1,28]:Integer, [-1,-1,29]:Integer, [-1,-1,30]:Integer, [-1,-1,31]:Integer, [-1,-1,32]:Integer, [-1,-1,33]:Integer, [-1,-1,34]:Integer, [-1,-1,35]:Integer, [-1,-1,36]:Integer, [-1,-1,37]:Integer, [-1,-1,38]:Integer, [-1,-1,39]:Integer}
You may be using a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/#Activity-of-temporary-storage). If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now
Stacktrace:
[1] SplitODEProblem
@ C:\Users\xinle\.julia\packages\SciMLBase\Dwomw\src\problems\ode_problems.jl:444
[2] SplitODEProblem
@ C:\Users\xinle\.julia\packages\SciMLBase\Dwomw\src\problems\ode_problems.jl:0
Stacktrace:
throwerr(cstr::Cstring) at compiler.jl
Thank you so much in advance for your help! I am trying to train the neural ODE more efficiently, which is why I am using a `SplitODEProblem`, in the hope of being able to use Zygote or Enzyme.