Zygote indeed has issues with nested differentiation. Here, though, another problem seems to be its implicit handling of parameters. If you switch to the newer, more explicit API of Optimisers, your example works with ReverseDiff over ForwardDiff:
```julia
using Flux, ForwardDiff, ReverseDiff, Zygote, Optimisers, Statistics

NN = Chain(Dense(1 => 12, tanh),
           Dense(12 => 12, tanh),
           Dense(12 => 12, tanh),
           Dense(12 => 1))

# Flatten all parameters into one vector θ; reNN(θ) rebuilds the network from it
θ, reNN = Flux.destructure(NN)

# Network output at a scalar x, multiplied by x
net(x, θ) = x * first(reNN(θ)([x]))
net(1, θ)  # works

# Derivative of net with respect to x (inner differentiation, forward mode)
net_xAD(x, θ) = ForwardDiff.derivative(y -> net(y, θ), x)
net_xAD(1, θ)  # works

ts = 1f-2:1f-2:1f0
loss(θ) = mean(abs2(net_xAD(t, θ) - cos(t)) for t in ts)

# Both of these work (outer differentiation with respect to θ)
ReverseDiff.gradient(loss, θ)
ForwardDiff.gradient(loss, θ)

function train(θ; opt = Optimisers.Adam(), steps = 250)
    state = Optimisers.setup(opt, θ)
    for i = 1:steps
        if i % 50 == 0
            display(loss(θ))
        end
        ∇θ = ReverseDiff.gradient(loss, θ)
        # update is non-mutating: it returns the new optimiser state and parameters
        state, θ = Optimisers.update(state, θ, ∇θ)
    end
    θ
end
```
```julia
julia> train(θ);
0.083571285f0
0.022315795f0
0.0012668703f0
0.00012052046f0
0.00011890126f0
```
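If you want to convince yourself that a nested gradient like this is actually correct, and not just finite, a cheap sanity check is to compare it against central finite differences. The sketch below is a Flux-free toy: the model `u(x, θ) = x * tanh(θ[1]*x + θ[2])`, the helper `fd_gradient`, the step `h` and the point `θ0` are all hypothetical stand-ins for the network above. It uses ForwardDiff for both the inner derivative in `x` (mirroring `net_xAD`) and the outer gradient in `θ`. The same finite-difference comparison can be applied to the `ReverseDiff.gradient(loss, θ)` result with the real model. I have not checked ReverseDiff on this scalar toy, so the sketch sticks to ForwardDiff.

```julia
using ForwardDiff, Statistics

# Hypothetical stand-in for the Flux network: u(x, θ) = x * tanh(θ[1]*x + θ[2])
u(x, θ) = x * tanh(θ[1] * x + θ[2])

# Inner derivative du/dx in forward mode, mirroring net_xAD
du_dx(x, θ) = ForwardDiff.derivative(y -> u(y, θ), x)

# Closed form of du/dx, used only to check the inner derivative
du_dx_exact(x, θ) = tanh(θ[1] * x + θ[2]) + θ[1] * x * sech(θ[1] * x + θ[2])^2

const ts = 0.01:0.01:1.0
loss(θ) = mean(abs2(du_dx(t, θ) - cos(t)) for t in ts)

# Central finite differences of f at θ (a check only, not for training)
function fd_gradient(f, θ; h = 1e-6)
    g = similar(θ)
    for i in eachindex(θ)
        e = zeros(length(θ))
        e[i] = h
        g[i] = (f(θ .+ e) - f(θ .- e)) / (2h)
    end
    return g
end

θ0 = [0.5, 0.1]
g_ad = ForwardDiff.gradient(loss, θ0)  # outer forward mode over inner forward mode
g_fd = fd_gradient(loss, θ0)
```

ForwardDiff tags the inner and outer dual numbers separately, so the closure `y -> u(y, θ)` can capture `θ` while `θ` itself carries outer duals.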
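Interestingly, trying Zygote for the outer differentiation shows why it fails. Zygote cannot track gradients through the closure passed to `ForwardDiff.derivative`, which captures θ, so the dependence on θ is dropped and the gradient comes back as `nothing`: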
```julia
julia> Zygote.gradient(loss, θ)
┌ Warning: `ForwardDiff.derivative(f, x)` within Zygote cannot track gradients with respect to `f`,
│ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
│ typeof(f) = var"#23#24"{Vector{Float32}}
└ @ Zygote ~/.julia/packages/Zygote/oGI57/src/lib/forward.jl:158
(nothing,)
```
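To see what the `issingletontype(typeof(f))` test in that warning distinguishes, compare a plain function with a closure that captures a parameter vector. The names `square` and `make_inner` below are hypothetical; `make_inner(θ)` plays the role of the anonymous function built inside `net_xAD`. A singleton function type has no fields, so there is nothing inside it to differentiate. The closure stores `θ` as a field, and that field is exactly what Zygote's rule cannot track, according to the warning.

```julia
# A plain generic function: its type is a singleton with no fields
square(x) = x^2

# Hypothetical helper that builds a closure capturing θ, like the one in net_xAD
make_inner(θ) = x -> θ[1] * x^2

θ = Float32[2.0]
f = make_inner(θ)

Base.issingletontype(typeof(square))  # true: nothing captured
Base.issingletontype(typeof(f))       # false: θ is stored inside f
getfield(f, :θ) === θ                 # the captured parameters live in a field of f
```

The closure type is parametrised by `Vector{Float32}`, matching the `typeof(f)` line in the warning.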