This is similar to this question.
I am training a neural network NN(), where the loss function involves the gradient of NN(.) with respect to the input. For example:
using Flux, LinearAlgebra

m = Flux.Chain(Dense(5, 5, relu), Dense(5, 5, relu), Dense(5, 1))

# Scalar-valued wrapper around the network.
g(z) = only(m(z))

function loss(x, y)
    # Gradient of the network output with respect to the input x.
    w = Flux.gradient(g, x)[1]
    yhat = dot(w, x) / g(x)
    return (yhat - y)^2
end
This nested differentiation does not work within the Flux framework. Currently, I train the network either (1) "manually", using ForwardDiff.gradient where required, or (2) for small networks, using Optim.
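For reference, here is a minimal sketch of what those two workarounds look like in my case; flattening the parameters with Flux.destructure, the made-up data point, the step size, and the LBFGS choice are just one way to set it up:

using Flux, ForwardDiff, LinearAlgebra, Optim

m = Flux.Chain(Dense(5, 5, relu), Dense(5, 5, relu), Dense(5, 1))

# Flatten the parameters so the whole loss is a function of one vector.
θ, re = Flux.destructure(m)

x, y = rand(Float32, 5), 1.0f0  # made-up data point for illustration

function loss_θ(θ)
    model = re(θ)
    h(z) = only(model(z))
    w = ForwardDiff.gradient(h, x)  # inner gradient wrt the input
    yhat = dot(w, x) / h(x)
    return (yhat - y)^2
end

# (1) "Manual" training: forward-over-forward, since ForwardDiff nests cleanly.
∇θ = ForwardDiff.gradient(loss_θ, θ)
θ .-= 0.01 .* ∇θ  # one plain gradient-descent step

# (2) For small networks, hand the same loss to Optim.
res = Optim.optimize(loss_θ, θ, LBFGS(); autodiff = :forward)

This works, but it gives up Flux's usual training loop and scales poorly with the number of parameters, since everything runs in forward mode.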
Is there a way to handle this within Flux? Many thanks!