Computing the Jacobian of a NN while using Lux

I’m revisiting some code I worked on a while ago and rewriting it from Flux to Lux.
Previously, I was trying to use DiffEqFlux to solve for a neural network U where my differential equation contains both $U([x,y,t],p)$ and its derivative with respect to one of those variables, say $\frac{\partial}{\partial t} U([x,y,t],p)$ or $\frac{\partial}{\partial x} U([x,y,t],p)$.

An example of what these equations might look like is in a question I posted a year ago here, where I was able to get the code to run.

After converting this code to use Lux, I was still getting errors I didn’t understand. I know now that they are a result of trying to compute the Jacobian within the differential equation function. Using the Simultaneous Fitting of Multiple Neural Networks example, I’ve added two lines of code to illustrate my problem.

using Lux, DiffEqFlux, Optimization, OptimizationOptimJL, DifferentialEquations, Random

rng = Random.default_rng()

function fitz(du,u,p,t)
  v,w = u
  a,b,τinv,l = p
  du[1] = v - v^3/3 -w + l
  du[2] = τinv*(v +  a - b*w)
end

p_ = Float32[0.7,0.8,1/12.5,0.5]
u0 = [1f0;1f0]
tspan = (0f0,10f0)
prob = ODEProblem(fitz,u0,tspan,p_)
sol = solve(prob, Tsit5(), saveat = 0.5 )

# Ideal data
X = Array(sol)
Xₙ = X + Float32(1e-3)*randn(eltype(X), size(X))  #noisy data

# For xz term
NN_1 = Lux.Chain(Lux.Dense(2, 16, tanh), Lux.Dense(16, 1))
p1,st1 = Lux.setup(rng, NN_1)

# for xy term
NN_2 = Lux.Chain(Lux.Dense(3, 16, tanh), Lux.Dense(16, 1))
p2, st2 = Lux.setup(rng, NN_2)
scaling_factor = 1f0

p1 = Lux.ComponentArray(p1)
p2 = Lux.ComponentArray(p2)

p = Lux.ComponentArray{eltype(p1)}()
p = Lux.ComponentArray(p;p1)
p = Lux.ComponentArray(p;p2)
p = Lux.ComponentArray(p;scaling_factor)

function dudt_(u,p,t)
    v,w = u
    z1 = NN_1([v,w], p.p1, st1)[1]
    z2 = NN_2([v,w,t], p.p2, st2)[1]
    A = [v,w,t]
    jac_temp = jacobian(A->NN_2(A,p.p2,st2)[1],A)[1]
    [z1[1],p.scaling_factor*z2[1]]
end

prob_nn = ODEProblem(dudt_,u0, tspan, p)
sol_nn = solve(prob_nn, Tsit5(),saveat = sol.t)

function predict(θ)
    Array(solve(prob_nn, Vern7(), p=θ, saveat = sol.t,
                         abstol=1e-6, reltol=1e-6,
                         sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
end

# No regularisation right now
function loss(θ)
    pred = predict(θ)
    sum(abs2, Xₙ .- pred), pred
end

const losses = []
callback(θ,l,pred) = begin
    push!(losses, l)
    if length(losses)%50==0
        println(losses[end])
    end
    false
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p) -> loss(x), adtype)

optprob = Optimization.OptimizationProblem(optf, p)
res1_uode = Optimization.solve(optprob, ADAM(0.01), callback=callback, maxiters = 50)

The two lines I’ve added within the dudt_() function are the only difference from the DiffEqFlux example. With these, I’m trying to compute the Jacobian.

A = [v,w,t]
jac_temp = jacobian(A->NN_2(A,p.p2,st2)[1],A)[1]

At this stage, I’m not even trying to do anything with the Jacobian – I just want to calculate it. These lines of code work if I run them by themselves. However, within the optimization, I get the following error message:

ERROR: Need an adjoint for constructor ReverseDiff.TrackedArray{Float32, Float32, 2, Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}. Gradient is of type Matrix{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}}

Any help would be great. Thanks!

What are you using to compute the Jacobian there? ForwardDiff would make the most sense
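
As a minimal sketch of what a ForwardDiff call could look like (the weight matrix `W` and the function `f` below are hypothetical stand-ins for the network, not the Lux model from the post), `ForwardDiff.jacobian` takes a function and an input vector and returns the dense Jacobian:

```julia
using ForwardDiff

# Hypothetical stand-in for a small network layer: 3 inputs -> 2 outputs.
W = Float32[0.1 0.2 0.3; 0.4 0.5 0.6]
f(A) = tanh.(W * A)

A = Float32[1.0, 2.0, 3.0]

# Forward-mode Jacobian: one row per output, one column per input.
J = ForwardDiff.jacobian(f, A)

# Analytic Jacobian for comparison: each row of W scaled by 1 - tanh^2
# of the corresponding pre-activation.
J_exact = (1 .- tanh.(W * A) .^ 2) .* W
```

For a three-input, single-output network such as NN_2, the same call would give a 1×3 matrix whose columns are the partial derivatives with respect to v, w, and t. Whether this composes with the adjoint sensitivity method used in predict() is a separate question that this sketch does not address.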

Ah! I was just using whatever the default was for my import statements (maybe DiffEqFlux?). Admittedly, I don’t know the differences between the various ways to take the Jacobian.
I just tried using


(I recognize that I’m just pulling out rows of Jacobians here – I’m just trying to make something work.) It gave me this error:

ERROR: MethodError: no method matching Float32(::ReverseDiff.TrackedReal{Float32, Float32, Nothing})
Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat at rounding.jl:200
  (::Type{T})(::T) where T<:Number at boot.jl:772
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number} at char.jl:50

I found a discussion on GitHub suggesting that the sensealg is causing this. Removing “sensealg=…” makes the code run, though I get the warning:

Warning: `ForwardDiff.jacobian(f, x)` within Zygote cannot track gradients with respect to `f`,
│ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).

Is this something I should be concerned about? If I’m reading this correctly, I think I would want to track gradients with respect to f… I also tried using:

jac_temp = AD.gradient(AD.ForwardDiffBackend(), A->NN_2(A,p.p2,st2)[1][1],A)

This is from the Lux docs, but I get a warning that “Reverse-Mode AD VJP choices all failed. Falling back to numerical VJP”, and then a MethodError.

I’d need to play with your example a bit, which I won’t get to today. Could you open an issue on SciMLSensitivity.jl?

Sure – I’ve opened an issue.