I think I’ve got it figured out now. The code below borrows heavily from the source code of DiffEqFlux.jl; big thanks to @ChrisRackauckas for figuring this out there.
using Flux
using ReverseDiff
#Create a modified Flux NN type that carries the destructured parameters
#alongside the reconstructor returned by Flux.destructure
struct ModNN{M, R, P}
    model::M
    re::R
    p::P
    function ModNN(model; p = nothing)
        #Flatten the model's parameters into a single vector _p and keep the
        #function re that rebuilds the model from such a vector
        _p, re = Flux.destructure(model)
        if p === nothing
            p = _p
        end
        return new{typeof(model), typeof(re), typeof(p)}(model, re, p)
    end
end
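Quick aside on what destructure gives you (a minimal sketch with throwaway names m, p0, re0, x0, not part of the code above): re(p) rebuilds the network from the flat parameter vector, so rebuilding with the original parameters reproduces the original model.
m = Chain(Dense(1 => 4, relu), Dense(4 => 3))
p0, re0 = Flux.destructure(m)   #flat parameter vector and reconstructor
x0 = rand(Float32, 1)
@assert re0(p0)(x0) ≈ m(x0)     #rebuilt model matches the original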
#Function to compute the derivative of the third NN output with respect to the input
function diff3destruct(re, p, x)
    #For each scalar input y, rebuild the network with re(p), take its third
    #output, and differentiate that output with respect to the input
    [Flux.gradient(w -> re(p)(w)[3], [y])[1][1] for y in x]
end
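This isn't part of the DiffEqFlux-derived code, but a quick way to sanity-check that derivative is to compare it against a central finite difference; fd_check and the step size h below are just illustrative names/values.
#Compare the AD input-derivative against a central finite difference
function fd_check(re, p, y; h = 1e-3)
    fd = (re(p)([y + h])[3] - re(p)([y - h])[3]) / (2h)
    ad = diff3destruct(re, p, [y])[1]
    return fd, ad
end
#e.g. fd_check(mod1.re, mod1.p, 0.5) once mod1 is defined below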
#Create an instance of the modified type
mod1 = ModNN(Chain(Dense(1 => 32, relu),
                   Dense(32 => 32, relu),
                   Dense(32 => 3)))
#Record parameter values to check that they're updating correctly
checkP = copy(mod1.p)
#Loss function (simple aggregation of the input-derivatives of the third NN output)
loss(re, p, x) = sum(diff3destruct(re, p, x))
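In a real application the loss would presumably compare those derivatives to data instead of just summing them; a minimal sketch of that idea, where fit_loss, xs, and targets are made-up names not used elsewhere in this post:
#Squared error between the input-derivatives of output 3 and target values
fit_loss(re, p, xs, targets) = sum(abs2, diff3destruct(re, p, xs) .- targets)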
#Get the gradients of the loss with respect to the parameters using ReverseDiff
gs = ReverseDiff.gradient(p -> loss(mod1.re, p, rand(3)), mod1.p)
opt = ADAM(0.01)
#Update the NN parameters in place
Flux.Optimise.update!(opt, mod1.p, gs)
Now checking that the neural network parameters updated correctly:
julia> maximum(mod1.p .- checkP)
0.010000005f0
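From here the same pieces compose into an actual training loop, and mod1.re(mod1.p) rebuilds a usable network from the updated parameter vector. A minimal sketch (the 100 steps and the rand(3) inputs are arbitrary placeholders):
for i in 1:100
    x = rand(3)
    g = ReverseDiff.gradient(q -> loss(mod1.re, q, x), mod1.p)
    Flux.Optimise.update!(opt, mod1.p, g)
end
#Rebuild the network from the trained parameter vector and evaluate it
trained = mod1.re(mod1.p)
trained([0.5f0])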