This is similar to this question.
I am training a neural net `NN()` whose loss function involves the gradient of `NN(·)` with respect to the input. For example:
```julia
using Flux, LinearAlgebra

m = Chain(Dense(5, 5, relu), Dense(5, 5, relu), Dense(5, 1))
g(z) = only(m(z))

function loss(x, y)
    # gradient of the network output with respect to the input x;
    # Flux.gradient returns a tuple, so unwrap its single element
    w = only(Flux.gradient(g, x))
    yhat = dot(w, x) / g(x)
    return (yhat - y)^2
end
```
This does not work within the Flux framework, since training requires differentiating through the inner `Flux.gradient` call. Currently I train the network either (1) "manually", using `ForwardDiff.gradient` where required, or (2) when the network is small, with Optim.
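For reference, here is a minimal sketch of the "manual" workaround I mean: computing the inner (input) gradient with forward-mode `ForwardDiff.gradient`, so that the outer reverse-mode pass over the parameters only has to differentiate through plain dual-number arithmetic. This is my own illustrative setup (the toy data `x`, `y` and the closure `h` are made up), not a claim that nested AD is officially supported:

```julia
using Flux, ForwardDiff, LinearAlgebra

m = Chain(Dense(5, 5, relu), Dense(5, 5, relu), Dense(5, 1))

# toy data, purely for illustration
x = rand(Float32, 5)
y = 1.0f0

opt = Flux.setup(Adam(1f-3), m)

# one training step: outer gradient wrt the parameters,
# inner gradient wrt the input taken with ForwardDiff
grads = Flux.gradient(m) do model
    h(z) = only(model(z))           # scalar-valued network
    w = ForwardDiff.gradient(h, x)  # ∂h/∂x via forward mode
    yhat = dot(w, x) / h(x)
    (yhat - y)^2
end
Flux.update!(opt, m, grads[1])
```

Whether Zygote can differentiate cleanly through the `ForwardDiff.gradient` call is exactly the fragile part; it often works because ForwardDiff is plain Julia code, but it is slow and breaks on some layer types, which is why I am asking for a supported way to do this inside Flux.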
Is there a way we can handle this within Flux? Many thanks!