ReverseDiff for loss function with Zygote derivatives

I’m trying to use Flux.jl to find the solution to a functional equation with a form a bit like g(f(x)[1]) = h(f’(x)), where f(x) is an unknown function to be approximated by a neural network.

The loss function for this setup needs to include the derivative of the neural network (with respect to inputs).

From reading posts by others who have had difficulty with this, I understand that currently the most reliable way to make that work is to use Zygote to take the derivative inside the loss function, and then ReverseDiff for the derivative of the loss function with respect to the neural network parameters.

I’m having trouble figuring out the syntax to make that work, though; I keep getting errors.

Here’s a small example that I’m working on:

using Flux, Zygote, ReverseDiff

mod1 = Chain(Dense(1 => 4, relu), 
              Dense(4 => 4, relu), 
              Dense(4 => 3))   #Full-scale example needs 3 outputs, so including them here to make sure it works

dMdx3(m,x) = [Zygote.gradient( w -> m(w)[3], y)[1][1] for y in x]
julia> dMdx3(mod1,rand(3))  #checking to make sure this works
3-element Vector{Float64}:
 -0.35708116668549555
 -0.35708116668549555
 -0.35708116668549555
loss(f,x) = sum(dMdx3(f,x))
julia> loss(mod1,rand(3))  #again checking functionality
-1.0712435000564866
julia> ReverseDiff.gradient(m -> loss(m,rand(3)), mod1)
ERROR: MethodError: no method matching ReverseDiff.GradientConfig(::Chain{Tuple{Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
Closest candidates are:
  ReverseDiff.GradientConfig(::Tuple) at C:\Users\Patrick\.julia\packages\ReverseDiff\YkVxM\src\api\Config.jl:37
  ReverseDiff.GradientConfig(::Tuple, ::Type{D}) where D at C:\Users\Patrick\.julia\packages\ReverseDiff\YkVxM\src\api\Config.jl:45
  ReverseDiff.GradientConfig(::Tuple, ::Type{D}, ::Vector{ReverseDiff.AbstractInstruction}) where D at C:\Users\Patrick\.julia\packages\ReverseDiff\YkVxM\src\api\Config.jl:45
  ...
Stacktrace:
 [1] gradient(f::Function, input::Chain{Tuple{Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
   @ ReverseDiff C:\Users\Patrick\.julia\packages\ReverseDiff\YkVxM\src\api\gradients.jl:22
 [2] top-level scope
   @ REPL[19]:1
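
My guess from the MethodError is that ReverseDiff.gradient wants its input to be an array of reals (or a tuple of arrays), not a Chain, since that’s all it can build a GradientConfig for. For reference, this is the kind of call that does work (just a toy scalar function of a plain vector, unrelated to my actual loss):

using ReverseDiff

#Gradient of a scalar function of a plain Vector input
ReverseDiff.gradient(v -> sum(abs2, v), rand(3))   #returns 2 .* the input vector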

I got the basic syntax for using ReverseDiff to get the gradient of a neural network from the DiffEqFlux.jl documentation here. But that was in the context of a Hamiltonian neural network, so I’m not sure it’s correct here.

Would greatly value any advice on how to make this work properly.

I think I’ve got it figured out now. The code below borrows heavily from the source code for DiffEqFlux.jl.

Big thanks to @ChrisRackauckas for figuring this out in DiffEqFlux.jl
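
The key piece is Flux.destructure, which splits the model into a flat parameter vector p and a reconstructor re, so that ReverseDiff can differentiate with respect to a plain Vector instead of a Chain. A quick sketch of what destructure gives you (toy model purely for illustration):

using Flux

m = Chain(Dense(2 => 3, tanh), Dense(3 => 1))
p, re = Flux.destructure(m)          #p is a flat Vector{Float32} of all weights and biases
re(p)([0.5, 0.5]) == m([0.5, 0.5])   #re rebuilds the model from p, so this is true

With that in hand, the full solution looks like this: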

using Flux
using ReverseDiff

#Create a modified Flux NN type that carries the destructured parameters
struct ModNN{M, R, P}
    model::M
    re::R
    p::P

    function ModNN(model; p = nothing)
        _p, re = Flux.destructure(model)
        if p === nothing
            p = _p
        end
        return new{typeof(model), typeof(re), typeof(p)}(model, re, p)
    end
end

#Function to compute the derivative of the third NN output with respect to each input
function diff3destruct(re, p, x)
    [Flux.gradient(w -> re(p)(w)[3], [y])[1][1] for y in x]
end

#Create an instance of modified type
mod1 = ModNN(Chain(Dense(1 => 32, relu),
                   Dense(32 => 32, relu),
                   Dense(32 => 3)))

#Record parameter values to check that they're updating correctly
checkP = copy(mod1.p)

#Loss function (simple sum of the input-derivatives of the third NN output)
loss(re, p, x) = sum(diff3destruct(re, p, x))

#Get the gradient with respect to the parameters using ReverseDiff
gs = ReverseDiff.gradient(p -> loss(mod1.re, p, rand(3)), mod1.p)

opt = ADAM(0.01)

#Update the NN parameters
Flux.Optimise.update!(opt, mod1.p, gs)

Now checking that the neural network updated correctly:

julia> maximum(mod1.p .- checkP)
0.010000005f0
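
From here it’s just a standard manual training loop; a minimal sketch reusing the definitions above (fresh random inputs each step, purely illustrative):

for step in 1:200
    x = rand(3)
    gs = ReverseDiff.gradient(p -> loss(mod1.re, p, x), mod1.p)
    Flux.Optimise.update!(opt, mod1.p, gs)
end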