Zygote Type error when using Rodas5() solver but only when manually setting sensealg

I’m running into an error that I can't make sense of. It would be great if someone more experienced could help me understand what’s going on.

It happens when I call pullback on a loss function that combines all three of the following:

  1. The Rodas5() solver.

  2. A Lux neural network defining the ODE. When I define the ODE without Lux, the gradients are computed without any problem.

  3. sensealg set manually to InterpolatingAdjoint(; autojacvec=ZygoteVJP()). If I let the solver choose it automatically, I get no error, even with the neural network.

I found something similar in the FAQ (Frequently Asked Questions · DifferentialEquations.jl), but I’m not sure where a tmp cache would be sneaking in here, or why the error only shows up under this specific combination of conditions.
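For context, the FAQ pattern concerns in-place right-hand sides that write into a preallocated buffer: stiff solvers such as Rodas5() differentiate the RHS with ForwardDiff, so the buffer has to be able to hold ForwardDiff.Dual numbers, and a plain Float64 buffer fails with a conversion error. Below is a minimal sketch of the PreallocationTools-based fix, assuming its DiffCache/get_tmp API. The system (A, rhs!, buf_cache) is hypothetical and not taken from the MWE, which is out-of-place and has no explicit cache.

```julia
using DifferentialEquations, PreallocationTools, LinearAlgebra

# Hypothetical in-place RHS du = p[1] * (A * u) that reuses a preallocated buffer.
const A = [-0.5 0.0; 0.0 -1.0]
const buf_cache = DiffCache(zeros(2))  # wraps a Float64 buffer plus a Dual-compatible one

function rhs!(du, u, p, t)
    # get_tmp returns a buffer whose element type matches u, so this also works
    # when Rodas5() calls rhs! with ForwardDiff.Dual numbers to build the Jacobian.
    tmp = get_tmp(buf_cache, u)
    mul!(tmp, A, u)
    du .= p[1] .* tmp
    return nothing
end

prob = ODEProblem(rhs!, [1.0, 1.0], (0.0, 1.0), [1.0])
sol = solve(prob, Rodas5())
println(sol.u[end])  # close to [exp(-0.5), exp(-1.0)]
```

Writing into a plain `zeros(2)` buffer instead of `get_tmp(buf_cache, u)` would make `mul!` try to store Dual numbers in a Float64 array during the Jacobian computation, which is the kind of failure the FAQ describes.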

I’ve attached an MWE below:

```julia
using DifferentialEquations, SciMLSensitivity, Zygote, PreallocationTools, Lux, Random, ComponentArrays

rng = Random.default_rng()
Random.seed!(rng, 0)

basic_tgrad(u, p, t) = zero(u)

############################ Lux NN #################################################

m = Lux.Dense(2, 1, tanh)
ps, st = Lux.setup(rng, m)
ps = ComponentArray(ps)

m(Float32.([1.f0, 0.f0]), ps, st)

############################### Creating ODE Function ###################################

function f(u, p, t)
    du_1 = m(u, p, st)[1]
    du_2 = u[2]
    return [du_1 ; du_2]
end

mass_matrix = Float32.([1.0 0.0; 0.0 0.0])

f_ = ODEFunction{false}(f, mass_matrix=mass_matrix, tgrad=basic_tgrad)

function g(u, p, t)
    du_1 = p * u[1]
    du_2 = u[2]
    return [du_1; du_2]
end

g_ = ODEFunction{false}(g, mass_matrix=mass_matrix, tgrad=basic_tgrad)

################################# Solve/loss Function ###########################################

function solve_de(p; sense=true, func=f_)
    prob = ODEProblem{false}(func, [1.f0, 1.f0], (0.f0, 1.f0), p)
    if sense
        sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP())
        return solve(prob, Rodas5(); saveat=0.1, sensealg=sensealg)
    else
        return solve(prob, Rodas5(); saveat=0.1)
    end
end

solve_de(ps)

function loss(p; sense=true, func=f_)
    pred = Array(solve_de(p; sense=sense, func=func))
    return sum(abs2, pred)
end

loss(ps)

##################################### Autograd ###################################

#Lux NN. Sensealg is manual
(l_1,), back_1 = pullback(p -> loss(p), ps)
back_1(one(l_1)) #Error: Expected Float32, Got ForwardDiff.Dual

# Lux NN. Sensealg is auto
(l_2,), back_2 = pullback(p -> loss(p; sense=false), ps)
back_2(one(l_2)) #This works

# No neural net. Sensealg is manual
(l_3,), back_3 = pullback(p -> loss(p; sense=true, func=g_), [2.f0])
back_3(one(l_3)) #This works

# No neural net. Sensealg is auto
(l_4,), back_4 = pullback(p -> loss(p; sense=false, func=g_), [2.f0])
back_4(one(l_4)) #This works
```

I noticed that the tutorial (MNIST Classification using NeuralODE - Lux.jl) lets sensealg be chosen automatically. I’m assuming it ran without errors because it uses the Tsit5() solver.
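One concrete difference between the two solvers is relevant to the "Got ForwardDiff.Dual" part of the error: Rodas5() is a Rosenbrock method that needs the Jacobian of the right-hand side and, by default, computes it with ForwardDiff, so the RHS gets called with Dual numbers; Tsit5() is explicit and only ever sees plain floats. A minimal sketch that records the element types the RHS is called with, assuming a hypothetical toy problem (decay_rhs and seen_types are illustrations, not part of the MWE). It only shows where Dual numbers can enter; it does not by itself explain the adjoint error.

```julia
using DifferentialEquations

# Hypothetical toy RHS (not from the MWE): u' = -p .* u.
# It records the element type of every `u` it is called with.
const seen_types = Set{Type}()

function decay_rhs(u, p, t)
    push!(seen_types, eltype(u))
    return -p .* u
end

prob = ODEProblem{false}(decay_rhs, [1.0, 2.0], (0.0, 1.0), 0.5)

# Explicit Runge-Kutta method: only plain Float64 states reach the RHS.
empty!(seen_types)
solve(prob, Tsit5())
tsit5_types = copy(seen_types)

# Rosenbrock method: the default Jacobian is computed with ForwardDiff,
# so the RHS is also called with vectors of ForwardDiff.Dual numbers.
empty!(seen_types)
solve(prob, Rodas5())
rodas5_types = copy(seen_types)

println("Tsit5:  ", tsit5_types)
println("Rodas5: ", rodas5_types)
```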

What advantages does setting sensealg manually offer? Is it mainly speed?
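The relative speed of different sensealg choices depends on the problem, so one way to answer the speed question for a given model is to measure it. A minimal sketch, assuming a hypothetical toy problem (decay, toy_loss, and this particular list of sensealg choices are illustrations, not recommendations): it computes the same gradient with several sensealg settings, checks that they agree, and times a second, already-compiled call of each.

```julia
using DifferentialEquations, SciMLSensitivity, Zygote

# Hypothetical toy problem (not from the MWE): exponential decay u' = -p[1] * u.
decay(u, p, t) = -p[1] .* u

function toy_loss(p, sensealg)
    prob = ODEProblem{false}(decay, [1.0, 2.0], (0.0, 1.0), p)
    sol = solve(prob, Tsit5(); saveat = 0.1, sensealg = sensealg)
    return sum(abs2, Array(sol))
end

# A few sensealg choices from SciMLSensitivity; they trade off memory use,
# speed and robustness differently, so the fastest one is problem-dependent.
sensealgs = [
    InterpolatingAdjoint(; autojacvec = ZygoteVJP()),
    QuadratureAdjoint(; autojacvec = ZygoteVJP()),
    ForwardDiffSensitivity(),
]

p0 = [0.5]
grads = Any[]
for s in sensealgs
    loss_s = p -> toy_loss(p, s)
    g = Zygote.gradient(loss_s, p0)[1]        # first call also pays the compilation cost
    t = @elapsed Zygote.gradient(loss_s, p0)  # time a second, already-compiled call
    push!(grads, g)
    println(nameof(typeof(s)), ": gradient = ", g, ", time = ", t, " s")
end
```

For this toy loss the exact gradient is the sum over the saved times t of 5 * (-2t) * exp(-2 * p * t), which gives a reference to compare all three results against.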

Any pointers that help me understand this would be greatly appreciated. Thank you very much!

Can you open an issue on SciMLSensitivity.jl and share the full stack trace? I think this might be related to Enzyme stuff I need to handle.

Hi! Sorry for the slow response. And yes, of course. I’ll do it ASAP.

Hi. I’ve created the issue here: Zygote Type error when using Rodas5() solver but only when manually setting sensealg · Issue #694 · SciML/SciMLSensitivity.jl · GitHub