I’m running into an error that I can’t make sense of. It would be great if someone more experienced could help me understand what’s going on.
It happens when I try to take the pullback of a loss function with:
- the Rodas5() solver,
- a Lux neural network specifying the ODE (when I specify the ODE without Lux, the gradients are computed with no problem), and
- sensealg set manually to InterpolatingAdjoint(; autojacvec=ZygoteVJP()) (if I let the solver decide it automatically, I get no error even when using the neural network).
I found something similar in the FAQ (Frequently Asked Questions · DifferentialEquations.jl), but I’m not sure where the tmp cache is sneaking in or why the error is so specific.
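To check my understanding of the failure mode the FAQ describes, here is a minimal plain-Julia sketch (no packages; `Dualish` is a made-up stand-in for ForwardDiff.Dual, not the real type): a preallocated Float32 buffer cannot store a dual-number-like value, so writing into it throws, which looks like the same class of error I’m seeing.

```julia
# Made-up stand-in for ForwardDiff.Dual: carries a value plus a partial.
struct Dualish
    value::Float64
    partial::Float64
end

# A preallocated Float32 cache, standing in for a solver's internal tmp buffer.
cache = zeros(Float32, 2)

# Writing a Dualish into the Float32 cache fails: there is no
# convert(Float32, ::Dualish) method, so setindex! throws a MethodError.
err = try
    cache[1] = Dualish(1.0, 1.0)
    nothing
catch e
    e
end
println(typeof(err))  # MethodError
```

If this is indeed the mechanism, it would explain why the error only mentions Float32 and ForwardDiff.Dual rather than anything about the adjoint itself.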
I’ve attached an MWE below:
using DifferentialEquations, SciMLSensitivity, Zygote, PreallocationTools, Lux, Random, ComponentArrays
rng = Random.default_rng()
Random.seed!(rng, 0)
basic_tgrad(u, p, t) = zero(u)
############################ Lux NN #################################################
m = Lux.Dense(2, 1, tanh)
ps, st = Lux.setup(rng, m)
ps = ComponentArray(ps)
m([1.0f0, 0.0f0], ps, st)
############################### Creating ODE Function ###################################
function f(u, p, t)
    du_1 = m(u, p, st)[1]  # first element of the Lux output tuple (a 1-element vector)
    du_2 = u[2]
    return [du_1; du_2]
end
mass_matrix = Float32.([1.0 0.0; 0.0 0.0])
f_ = ODEFunction{false}(f, mass_matrix=mass_matrix, tgrad=basic_tgrad)
function g(u, p, t)
    du_1 = p * u[1]
    du_2 = u[2]
    return [du_1; du_2]
end
g_ = ODEFunction{false}(g, mass_matrix=mass_matrix, tgrad=basic_tgrad)
################################# Solve/loss Function ###########################################
function solve_de(p; sense=true, func=f_)
    prob = ODEProblem{false}(func, [1.0f0, 1.0f0], (0.0f0, 1.0f0), p)
    if sense
        sensealg = InterpolatingAdjoint(; autojacvec=ZygoteVJP())
        return solve(prob, Rodas5(); saveat=0.1, sensealg=sensealg)
    else
        return solve(prob, Rodas5(); saveat=0.1)
    end
end
solve_de(ps)
function loss(p; sense=true, func=f_)
    pred = Array(solve_de(p; sense=sense, func=func))
    return sum(abs2, pred)
end
loss(ps)
##################################### Autograd ###################################
#Lux NN. Sensealg is manual
l_1, back_1 = pullback(p -> loss(p), ps)
back_1(one(l_1)) # Error: Expected Float32, Got ForwardDiff.Dual
# Lux NN. Sensealg is auto
l_2, back_2 = pullback(p -> loss(p; sense=false), ps)
back_2(one(l_2)) # This works
# No neural net. Sensealg is manual
l_3, back_3 = pullback(p -> loss(p; sense=true, func=g_), [2.0f0])
back_3(one(l_3)) # This works
# No neural net. Sensealg is auto
l_4, back_4 = pullback(p -> loss(p; sense=false, func=g_), [2.0f0])
back_4(one(l_4)) # This works
I saw that in the tutorial here (MNIST Classification using NeuralODE - Lux.jl), sensealg was chosen automatically, and I’m assuming there was no error because it used the Tsit5() solver.
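One thing I noticed while comparing: Tsit5(), being an explicit method, doesn’t support singular mass matrices, so reproducing the tutorial’s setup would mean dropping the algebraic constraint entirely. A hedged sketch of what I’d try, reusing `f`, `basic_tgrad`, and `ps` from the MWE above (`f_plain` and `prob_plain` are names I made up; not verified to avoid the error):

```julia
# Sketch only: drop the mass matrix so an explicit solver like Tsit5 applies.
f_plain = ODEFunction{false}(f; tgrad=basic_tgrad)
prob_plain = ODEProblem{false}(f_plain, [1.0f0, 1.0f0], (0.0f0, 1.0f0), ps)
sol = solve(prob_plain, Tsit5(); saveat=0.1,
            sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()))
```

If that combination works, it would suggest the problem is specific to the stiff/mass-matrix path in Rodas5 rather than to the manual sensealg itself.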
What advantages does setting sensealg manually confer? Speed?
Any directions in helping me understand this are appreciated greatly. Thank you very much!