AD of Neural ODEs with `SplitODEProblem` using Zygote or Enzyme

Hello, I am trying to perform AD in the Neural ODE setup using a SplitODEProblem as it is a PDE discretization of the diffusion equation which is rather stiff when using explicit timesteppers. The neural network appears on the rhs part of the SplitODEProblem to capture missing fluxes. Here is a MWE:

using LinearAlgebra
using DiffEqBase
using OrdinaryDiffEq: SplitODEProblem, solve, IMEXEuler
import SciMLBase
using Lux, OptimizationOptimisers, ComponentArrays, Random, SciMLSensitivity
using Optimization
using Printf

n = 32
zC = collect(1.:1:32)
Δz = zC[2] - zC[1]

rng = Random.default_rng()

NN = Chain(Dense(33, 10, leakyrelu), Dense(10, 31))
ps, st = Lux.setup(rng, NN)

ps = ps |> ComponentArray .|> Float64
ps .*= 0

function Dᶜ(N, Δ)
    D = zeros(N, N+1)
    for k in 1:N
        D[k, k]   = -1.0
        D[k, k+1] =  1.0
    end
    D .= 1/Δ .* D
    return D
end

function Dᶠ(N, Δ)
    D = zeros(N+1, N)
    for k in 2:N
        D[k, k-1] = -1.0
        D[k, k]   =  1.0
    end
    D .= 1/Δ .* D
    return D
end

D_center = Dᶜ(n, Δz)
D_face = Dᶠ(n, Δz)

u0 = collect(range(0, 1, length=n))

function compute_diffusivity(∂u∂z)
    return ifelse(∂u∂z < 0, 0.2, 1e-5)
end

α_initial = compute_diffusivity.(D_face * u0)

D = Tridiagonal(D_center * (α_initial .* D_face))

params = (top=3e-3, bottom=0., NN=NN, st=st, f=1e-4)

function rhs(u, p, t)
    x′ = vcat(u, params.f)

    residual_flux = vcat(params.bottom, first(params.NN(x′, p.ps, params.st)), params.top)

    du = -D_center * residual_flux
    return du
end

function update_diffusivity(A, u, p, t)
    ∂u∂z = D_face * u
    α = compute_diffusivity.(∂u∂z)
    return Tridiagonal(D_center * (α .* D_face))
end

D2 = SciMLBase.MatrixOperator(D, update_func=update_diffusivity)

ps_training = ComponentArray(;ps)

times = collect(0:0.1:1)
tspan = (times[1], times[end])

prob = SplitODEProblem(D2, rhs, u0, tspan, ps_training)

alg = IMEXEuler()
println("Solving...")
sol = solve(prob, alg, dt = 1e-3, saveat = times)

truth = rand(32, 11)

function loss(p)
    prob = SplitODEProblem(D2, rhs, u0, tspan, p)
    sol = Array(solve(prob, alg, dt = 1e-3, saveat = times))
    return sum(abs2, sol - truth)
end

iter = 0
maxiter = 3

callback = function (p, l)
    @printf("loss total %6.10e\n", l,)
    return false
end

loss(ps_training)

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ps_training)
res = Optimization.solve(optprob, OptimizationOptimisers.Adam(1e-5), callback=callback, maxiters=maxiter)

When I use adtype = Optimization.AutoForwardDiff(), the code runs without issues. However when I use adtype = Optimization.AutoReverseDiff() or adtype = Optimization.AutoZygote(), the error below is thrown:

ERROR: type ODEFunction has no field f1
Stacktrace:

getproperty at Base.jl

initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{…}, cache::OrdinaryDiffEq.SBDFCache{…}) at bdf_perform_step.jl

__init(prob::ODEProblem{…}, alg::OrdinaryDiffEq.SBDF{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Vector{…}, tstops::Vector{…}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::CallbackSet{…}, dense::Bool, calck::Bool, dt::Float64, dtmin::Float64, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Int64, abstol::Float64, reltol::Float64, qmin::Int64, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Int64, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::Base.Pairs{…}) at solve.jl

__init at solve.jl (repeats 5 times)

#__solve#757 at solve.jl

__solve at solve.jl

solve_call(_prob::ODEProblem{…}, args::OrdinaryDiffEq.SBDF{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::Base.Pairs{…}) at solve.jl

kwcall(::@NamedTuple{…}, ::typeof(DiffEqBase.solve_call), _prob::ODEProblem{…}, args::OrdinaryDiffEq.SBDF{…}) at solve.jl

solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::ComponentVector{…}, args::OrdinaryDiffEq.SBDF{…}; kwargs::Base.Pairs{…}) at solve.jl

solve_up at solve.jl

solve(prob::ODEProblem{…}, args::OrdinaryDiffEq.SBDF{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::Base.Pairs{…}) at solve.jl

kwcall(::@NamedTuple{…}, ::typeof(solve), prob::ODEProblem{…}, args::OrdinaryDiffEq.SBDF{…}) at solve.jl

_adjoint_sensitivities(sol::ODESolution{…}, sensealg::InterpolatingAdjoint{…}, alg::OrdinaryDiffEq.SBDF{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Nothing, callback::Nothing, kwargs::Base.Pairs{…}) at sensitivity_interface.jl

_adjoint_sensitivities at sensitivity_interface.jl

#adjoint_sensitivities#63 at sensitivity_interface.jl

kwcall(::@NamedTuple{…}, ::typeof(adjoint_sensitivities), sol::ODESolution{…}, args::OrdinaryDiffEq.SBDF{…}) at sensitivity_interface.jl

(::SciMLSensitivity.var"#adjoint_sensitivity_backpass#307"{…})(Δ::ODESolution{…}) at concrete_solve.jl

ZBack at chainrules.jl

(::Zygote.var"#kw_zpullback#53"{…})(dy::ODESolution{…}) at chainrules.jl

#291 at lib.jl

(::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…}) at adjoint.jl

#solve#41 at solve.jl

(::Zygote.Pullback{…})(Δ::ODESolution{…}) at interface2.jl

#291 at lib.jl

(::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…}) at adjoint.jl

solve at solve.jl

(::Zygote.Pullback{…})(Δ::ODESolution{…}) at interface2.jl

loss at IMEX_SplitODE_MWE.jl

(::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64) at interface2.jl

#11 at IMEX_SplitODE_MWE.jl

#291 at lib.jl

#2169#back at adjoint.jl

OptimizationFunction at scimlfunctions.jl

#291 at lib.jl

(::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{…}, Zygote.Pullback{…}}})(Δ::Float64) at adjoint.jl

#37 at OptimizationZygoteExt.jl

(::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64) at interface2.jl

#291 at lib.jl

#2169#back at adjoint.jl

#39 at OptimizationZygoteExt.jl

(::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64) at interface2.jl

(::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64) at interface.jl

gradient(f::Function, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{…}}}) at interface.jl

(::OptimizationZygoteExt.var"#38#56"{…})(::ComponentVector{…}, ::ComponentVector{…}) at OptimizationZygoteExt.jl

macro expansion at OptimizationOptimisers.jl

macro expansion at utils.jl

__solve(cache::OptimizationCache{…}) at OptimizationOptimisers.jl

solve!(cache::OptimizationCache{…}) at solve.jl

solve(::OptimizationProblem{…}, ::Adam; kwargs::Base.Pairs{…}) at solve.jl

kwcall(::@NamedTuple{…}, ::typeof(solve), ::OptimizationProblem{…}, ::Adam) at solve.jl

top-level scope at IMEX_SplitODE_MWE.jl

Some type information was truncated. Use show(err) to see complete types.

ERROR: type ODEFunction has no field f1
Stacktrace:

getproperty at Base.jl

initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{…}, cache::OrdinaryDiffEq.SBDFCache{…}) at bdf_perform_step.jl

__init(prob::ODEProblem{…}, alg::OrdinaryDiffEq.SBDF{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Vector{…}, tstops::Vector{…}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::CallbackSet{…}, dense::Bool, calck::Bool, dt::Float64, dtmin::Float64, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Int64, abstol::Float64, reltol::Float64, qmin::Int64, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Int64, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::Base.Pairs{…}) at solve.jl

__init at solve.jl (repeats 5 times)

#__solve#757 at solve.jl

__solve at solve.jl

solve_call(_prob::ODEProblem{…}, args::OrdinaryDiffEq.SBDF{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::Base.Pairs{…}) at solve.jl

kwcall(::@NamedTuple{…}, ::typeof(DiffEqBase.solve_call), _prob::ODEProblem{…}, args::OrdinaryDiffEq.SBDF{…}) at solve.jl

solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::ComponentVector{…}, args::OrdinaryDiffEq.SBDF{…}; kwargs::Base.Pairs{…}) at solve.jl

solve_up at solve.jl

solve(prob::ODEProblem{…}, args::OrdinaryDiffEq.SBDF{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::Base.Pairs{…}) at solve.jl

kwcall(::@NamedTuple{…}, ::typeof(solve), prob::ODEProblem{…}, args::OrdinaryDiffEq.SBDF{…}) at solve.jl

_adjoint_sensitivities(sol::ODESolution{…}, sensealg::InterpolatingAdjoint{…}, alg::OrdinaryDiffEq.SBDF{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Nothing, callback::Nothing, kwargs::Base.Pairs{…}) at sensitivity_interface.jl

_adjoint_sensitivities at sensitivity_interface.jl

#adjoint_sensitivities#63 at sensitivity_interface.jl

kwcall(::@NamedTuple{…}, ::typeof(adjoint_sensitivities), sol::ODESolution{…}, args::OrdinaryDiffEq.SBDF{…}) at sensitivity_interface.jl

(::SciMLSensitivity.var"#adjoint_sensitivity_backpass#307"{…})(Δ::ODESolution{…}) at concrete_solve.jl

ZBack at chainrules.jl

(::Zygote.var"#kw_zpullback#53"{…})(dy::ODESolution{…}) at chainrules.jl

#291 at lib.jl

(::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…}) at adjoint.jl

#solve#41 at solve.jl

(::Zygote.Pullback{…})(Δ::ODESolution{…}) at interface2.jl

#291 at lib.jl

(::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…}) at adjoint.jl

solve at solve.jl

(::Zygote.Pullback{…})(Δ::ODESolution{…}) at interface2.jl

loss at IMEX_SplitODE_MWE.jl

(::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64) at interface2.jl

#11 at IMEX_SplitODE_MWE.jl

#291 at lib.jl

#2169#back at adjoint.jl

OptimizationFunction at scimlfunctions.jl

#291 at lib.jl

(::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{…}, Zygote.Pullback{…}}})(Δ::Float64) at adjoint.jl

#37 at OptimizationZygoteExt.jl

(::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64) at interface2.jl

#291 at lib.jl

#2169#back at adjoint.jl

#39 at OptimizationZygoteExt.jl

(::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64) at interface2.jl

(::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64) at interface.jl

gradient(f::Function, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{…}}}) at interface.jl

(::OptimizationZygoteExt.var"#38#56"{…})(::ComponentVector{…}, ::ComponentVector{…}) at OptimizationZygoteExt.jl

macro expansion at OptimizationOptimisers.jl

macro expansion at utils.jl

__solve(cache::OptimizationCache{…}) at OptimizationOptimisers.jl

solve!(cache::OptimizationCache{…}) at solve.jl

solve(::OptimizationProblem{…}, ::Adam; kwargs::Base.Pairs{…}) at solve.jl

kwcall(::@NamedTuple{…}, ::typeof(solve), ::OptimizationProblem{…}, ::Adam) at solve.jl

top-level scope at IMEX_SplitODE_MWE.jl

When I use adtype = AutoEnzyme(), the error below is thrown:

ERROR: Enzyme execution failed.
Mismatched activity for:   store {} addrspace(10)* %.fca.0.extract3, {} addrspace(10)* addrspace(10)* %.sroa.04.0..sroa_cast, align 8, !dbg !15, !alias.scope !17, !noalias !21 const val:   %.fca.0.extract3 = extractvalue { {} addrspace(10)* } %1, 0, !dbg !8
Type tree: {[-1]:Pointer, [-1,-1]:Pointer, [-1,-1,0]:Pointer, [-1,-1,0,-1]:Float@double, [-1,-1,8]:Integer, [-1,-1,9]:Integer, [-1,-1,10]:Integer, [-1,-1,11]:Integer, [-1,-1,12]:Integer, [-1,-1,13]:Integer, [-1,-1,14]:Integer, [-1,-1,15]:Integer, [-1,-1,16]:Integer, [-1,-1,17]:Integer, [-1,-1,18]:Integer, [-1,-1,19]:Integer, [-1,-1,20]:Integer, [-1,-1,21]:Integer, [-1,-1,22]:Integer, [-1,-1,23]:Integer, [-1,-1,24]:Integer, [-1,-1,25]:Integer, [-1,-1,26]:Integer, [-1,-1,27]:Integer, [-1,-1,28]:Integer, [-1,-1,29]:Integer, [-1,-1,30]:Integer, [-1,-1,31]:Integer, [-1,-1,32]:Integer, [-1,-1,33]:Integer, [-1,-1,34]:Integer, [-1,-1,35]:Integer, [-1,-1,36]:Integer, [-1,-1,37]:Integer, [-1,-1,38]:Integer, [-1,-1,39]:Integer}
You may be using a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/#Activity-of-temporary-storage). If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now

Stacktrace:
 [1] SplitODEProblem
   @ C:\Users\xinle\.julia\packages\SciMLBase\Dwomw\src\problems\ode_problems.jl:444
 [2] SplitODEProblem
   @ C:\Users\xinle\.julia\packages\SciMLBase\Dwomw\src\problems\ode_problems.jl:0

Stacktrace:

throwerr(cstr::Cstring) at compiler.jl

Thank you so much in advance for your help! I am trying to train the Neural ODEs in a more performant way, hence I am trying the SplitODEProblem with the hope of using Zygote or Enzyme.

I think there is already an issue on this. It is indeed a missing part of the adjoints that it doesn’t support a split ODE problem with a linear operator as the first part. It is something that is possible to support but we just haven’t gotten around to it yet.