I’m revisiting some code I worked on a while ago and rewriting it from Flux to Lux.
Previously, I was using DiffEqFlux to solve for a neural network U where my differential equation contains both U([x,y,t],p) and its derivative with respect to one of those variables, say \frac{\partial}{\partial t} U([x,y,t],p) or \frac{\partial}{\partial x} U([x,y,t],p).
An example of what these equations might look like is in a question I posted a year ago here, where I was able to get the code to run.
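Purely for illustration (this is not my actual system, just the general shape of equation I mean): something like \frac{\partial u}{\partial t} = f(u,x,y,t) + U([x,y,t],p)\,\frac{\partial}{\partial x} U([x,y,t],p), where the unknown network U enters both directly and through one of its partial derivatives, so evaluating the right-hand side requires (at least part of) the Jacobian of U with respect to its inputs.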
After converting this code to Lux, I was getting errors I didn’t understand. I now know they come from trying to compute the Jacobian inside the differential equation function. To illustrate the problem, I’ve taken the Simultaneous Fitting of Multiple Neural Networks example and added two lines of code.
using Lux, DiffEqFlux, Optimization, OptimizationOptimJL, DifferentialEquations, Random
rng = Random.default_rng()
Random.seed!(rng,1)
function fitz(du,u,p,t)
    v,w = u
    a,b,τinv,l = p
    du[1] = v - v^3/3 - w + l
    du[2] = τinv*(v + a - b*w)
end
p_ = Float32[0.7,0.8,1/12.5,0.5]
u0 = [1f0;1f0]
tspan = (0f0,10f0)
prob = ODEProblem(fitz,u0,tspan,p_)
sol = solve(prob, Tsit5(), saveat = 0.5 )
# Ideal data
X = Array(sol)
Xₙ = X + Float32(1e-3)*randn(eltype(X), size(X)) #noisy data
# For xz term
NN_1 = Lux.Chain(Lux.Dense(2, 16, tanh), Lux.Dense(16, 1))
p1,st1 = Lux.setup(rng, NN_1)
# for xy term
NN_2 = Lux.Chain(Lux.Dense(3, 16, tanh), Lux.Dense(16, 1))
p2, st2 = Lux.setup(rng, NN_2)
scaling_factor = 1f0
p1 = Lux.ComponentArray(p1)
p2 = Lux.ComponentArray(p2)
p = Lux.ComponentArray{eltype(p1)}()
p = Lux.ComponentArray(p;p1)
p = Lux.ComponentArray(p;p2)
p = Lux.ComponentArray(p;scaling_factor)
function dudt_(u,p,t)
    v,w = u
    z1 = NN_1([v,w], p.p1, st1)[1]
    z2 = NN_2([v,w,t], p.p2, st2)[1]
    # The next two lines are the ones I added to the example:
    # compute the Jacobian of NN_2 with respect to its inputs [v,w,t].
    A = [v,w,t]
    jac_temp = jacobian(A->NN_2(A,p.p2,st2)[1],A)[1]
    [z1[1],p.scaling_factor*z2[1]]
end
prob_nn = ODEProblem(dudt_,u0, tspan, p)
sol_nn = solve(prob_nn, Tsit5(),saveat = sol.t)
function predict(θ)
    Array(solve(prob_nn, Vern7(), p=θ, saveat = sol.t,
                abstol=1e-6, reltol=1e-6,
                sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
end
# No regularisation right now
function loss(θ)
    pred = predict(θ)
    sum(abs2, Xₙ .- pred), pred
end
loss(p)
const losses = []
callback(θ,l,pred) = begin
    push!(losses, l)
    if length(losses)%50 == 0
        println(losses[end])
    end
    false
end
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p)
res1_uode = Optimization.solve(optprob, ADAM(0.01), callback=callback, maxiters = 50)
I’ve added two lines of code inside the dudt_() function that differ from the DiffEqFlux example; with them I’m trying to compute the Jacobian.
A = [v,w,t]
jac_temp = jacobian(A->NN_2(A,p.p2,st2)[1],A)[1]
At this stage I’m not even trying to do anything with the Jacobian; I just want to compute it. These lines work fine when I run them on their own (a standalone sketch of that check is included after the error message below). However, within the optimization, I get the following error message:
ERROR: Need an adjoint for constructor ReverseDiff.TrackedArray{Float32, Float32, 2, Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}. Gradient is of type Matrix{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}}
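For completeness, this is roughly what I mean by running the two lines on their own, where everything works. It's a minimal sketch: I'm assuming here that jacobian refers to Zygote.jacobian, which returns a tuple, hence the trailing [1]; with ForwardDiff.jacobian you'd drop that index.
using Zygote  # assumption: `jacobian` in dudt_ above is Zygote.jacobian
# Any test point with the same element type as u0
v, w, t = 1f0, 1f0, 0f0
A = [v, w, t]
# Jacobian of the scalar-output network NN_2 with respect to its 3 inputs: a 1×3 matrix
jac_temp = Zygote.jacobian(A -> NN_2(A, p.p2, st2)[1], A)[1]
size(jac_temp)  # (1, 3)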
Any help would be great. Thanks!