Is it possible to do Nested AD ~elegantly~ in Julia? (PINNs)

Zygote does indeed have issues with nested differentiation. Here, though, another problem seems to be its implicit handling of parameters. After switching to the newer, more explicit API of Optimisers, your example works with ReverseDiff over ForwardDiff:

using Flux, ForwardDiff, ReverseDiff, Zygote, Optimisers, Statistics

NN = Chain(Dense(1 => 12, tanh),
           Dense(12 => 12, tanh),
           Dense(12 => 12, tanh),
           Dense(12 => 1))
# Flatten the parameters into a vector θ; reNN(θ) rebuilds the network from it
θ, reNN = Flux.destructure(NN)

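# Multiplying by x enforces net(0, θ) == 0
net(x, θ) = x * first(reNN(θ)([x]))
net(1, θ) # works

# Inner derivative with respect to the input x, using forward mode
net_xAD(x, θ) = ForwardDiff.derivative(x -> net(x, θ), x)
net_xAD(1, θ) # works

ts = 1f-2:1f-2:1f0
# Mean squared residual of d(net)/dx = cos(t) over the sample points
loss(θ) = mean(abs2(net_xAD(t, θ) - cos(t)) for t in ts)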

# Both of these work
ReverseDiff.gradient(loss, θ)
ForwardDiff.gradient(loss, θ)

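function train(θ; opt = Optimisers.Adam(), steps = 250)
    state = Optimisers.setup(opt, θ)
    for i = 1:steps
        # Print the current loss every 50 steps
        if i % 50 == 0
            display(loss(θ))
        end
        # Outer gradient with respect to the parameters, using reverse mode
        ∇θ = ReverseDiff.gradient(loss, θ)
        # Explicit update: returns the new optimiser state and the new parameters
        state, θ = Optimisers.update(state, θ, ∇θ)
    end
    return θ
end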
julia> train(θ);
0.083571285f0
0.022315795f0
0.0012668703f0
0.00012052046f0
0.00011890126f0
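
For what it's worth, the nesting itself (an outer gradient over an inner ForwardDiff.derivative) can be tried in isolation without Flux. Here is a minimal sketch, assuming a hypothetical two-parameter toy model u(x, p) = p[1]*sin(p[2]*x) in place of the network. The names u, u_x, xs and toy_loss are made up for illustration and are not part of the example above. It uses ForwardDiff for both levels, since that combination is the simplest one to check by hand:

```julia
using ForwardDiff, Statistics

# Hypothetical toy model standing in for the network (an assumption, not the NN above)
u(x, p) = p[1] * sin(p[2] * x)

# Inner derivative with respect to the input x
u_x(x, p) = ForwardDiff.derivative(x -> u(x, p), x)

xs = 0.01:0.01:1.0
# Same residual as above: the derivative of the model should match cos(x)
toy_loss(p) = mean(abs2(u_x(x, p) - cos(x)) for x in xs)

# Outer gradient with respect to the parameters, nested over the inner derivative
g = ForwardDiff.gradient(toy_loss, [0.5, 2.0])

# At p = [1, 1] the model is exactly sin(x), so the residual vanishes
toy_loss([1.0, 1.0])
```

The closure x -> u(x, p) captures p just like the closure in net_xAD captures θ, which is the part the outer AD has to be able to differentiate through.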

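Interestingly, trying Zygote for the outer differentiation produces a warning that explains why it fails. The closure passed to ForwardDiff.derivative captures θ, and Zygote cannot track gradients with respect to it, so the returned gradient is nothing: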

julia> Zygote.gradient(loss, θ)
┌ Warning: `ForwardDiff.derivative(f, x)` within Zygote cannot track gradients with respect to `f`,
│ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
│ typeof(f) = var"#23#24"{Vector{Float32}}
└ @ Zygote ~/.julia/packages/Zygote/oGI57/src/lib/forward.jl:158
(nothing,)