This is similar to this question.
I am training a neural network NN(), where the loss function involves the gradient of NN() with respect to the input. For example:
using Flux, LinearAlgebra

m = Flux.Chain(Dense(5, 5, relu), Dense(5, 5, relu), Dense(5, 1))
g(z) = only(m(z))                  # scalar output of the network

function loss(x, y)
    w = Flux.gradient(g, x)[1]     # gradient of g with respect to the input x
    yhat = dot(w, x) / g(x)
    return (yhat - y)^2
end
This does not work within the Flux framework. Currently, I train the network either 1) “manually”, using ForwardDiff.gradient where required, or 2) when the network is small, using Optim.
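For reference, the ForwardDiff variant is essentially the same loss with the inner gradient call swapped out. A rough sketch, not my exact code (the name loss_fd is just for illustration):

using ForwardDiff

function loss_fd(x, y)
    w = ForwardDiff.gradient(g, x)   # forward-mode gradient of g with respect to the input x
    yhat = dot(w, x) / g(x)
    return (yhat - y)^2
end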
Is there a way we can handle this within Flux? Many thanks!
I use such loss functions, but they are finicky. You need to be certain that all the gradients are themselves differentiable, which in practice means re-writing some gradient rules, as they are not friendly to second-order differentiation. In theory, though, Zygote can do second-order gradients. Be aware that compilation time might be long.
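As a minimal sketch of what the nested call looks like in principle, assuming the implicit-parameters API of older Flux versions (whether it actually runs depends on every adjoint along the way being itself differentiable):

using Flux, Zygote, LinearAlgebra

m = Chain(Dense(5, 5, relu), Dense(5, 5, relu), Dense(5, 1))
g(z) = only(m(z))

function loss(x, y)
    w = Zygote.gradient(g, x)[1]   # inner gradient with respect to the input
    yhat = dot(w, x) / g(x)
    return (yhat - y)^2
end

ps = Flux.params(m)
x, y = rand(Float32, 5), 0.5f0
gs = Flux.gradient(() -> loss(x, y), ps)   # outer gradient: Zygote differentiates through the inner gradient call
Flux.Optimise.update!(Descent(0.01), ps, gs)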
@balaji1975 How do you do the training with ForwardDiff.gradient?
What code comes after your loss? Specifically, how do you make the gradient call and then update the parameters?
@lazarusA
It has been a long time and my memory is sketchy. I ended up using just Flux with a custom loss function.
Essentially, I was trying to find a neural network N that minimized the error: