Hello, I have an MRE where if I use any sensitivity algorithm apart from ForwardDiffSensitivity() I get an error. In particular, when I use the sensitivity algorithm InterpolatingAdjoint(autojacvec=ZygoteVJP()) I get the following error " MethodError: no method matching vec(::Nothing)".
Here is my MRE:
cd(@__DIR__)
using Pkg
Pkg.activate(".")
using Lux
using ComponentArrays
using Zygote
using ForwardDiff
using DifferentialEquations
using Optimization
using OptimizationOptimJL
using OptimizationOptimisers
using Random
using CairoMakie
# Simulate Data
function true_ode(du,u, p, t)
x, y = u
du[1] = -y
du[2] = x
end
# number of simulated trajectories
Ns = 50
u0s = randn(2, Ns)
tspan = (0.0, 10.0)
Nt = 100
saveat = LinRange(tspan[1], tspan[2], Nt)
sols = zeros(2*Nt, Ns)
for i = 1:Ns
println(i)
prob = ODEProblem(true_ode, u0s[:, i], tspan)
sols[:, i] = Array(solve(prob, Tsit5(); saveat=saveat))[:]
end
# Define Neural ODE
weight_init_mag = 0.1
f = Lux.Chain(Lux.Dense(2, 10, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
Lux.Dense(10, 10, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
Lux.Dense(10, 10, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
Lux.Dense(10, 2; init_weight=Lux.glorot_uniform(gain=weight_init_mag)))
rng = MersenneTwister(1111)
ps, st = Lux.setup(rng, f)
ps = ComponentArray(ps)
function f_ode(du, u, p, t)
du, _ = f(u, p, st)
end
# Define Neural ODE Objective Function
#sensealg = ForwardDiffSensitivity()
#sensealg = BacksolveAdjoint(autojacvec=ReverseDiffVJP())
#sensealg = QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))
sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())
function cost(p)
loss = 0
for i = 1:Ns
prob = ODEProblem(f_ode, u0s[:, i], tspan, p)
pred = Array(solve(prob, Tsit5(); sensealg=sensealg, saveat=saveat))[:]
loss += sum(abs2, pred - sols[:, i]) / Nt
end
return loss
end
# Train Loop
opt = Optimisers.ADAM(1e-2)
st_opt = Optimisers.setup(opt, ps)
numiters = Int(1e3)
for iter = 1:numiters
println(iter)
@time begin
l, back = pullback(cost, ps)
gs = back(one(l))[1]
st_opt, ps = Optimisers.update(st_opt, ps, gs)
end
end