Loss functions that involve gradients

This is similar to this question.

I am training a neural net NN(.), where the loss function involves the gradient of NN(.) with respect to the input. For example:

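using Flux, LinearAlgebra   # LinearAlgebra provides dot

m = Flux.Chain(Dense(5,5,relu), Dense(5,5,relu), Dense(5,1))
g(z) = only(m(z))   # scalar network output

function loss(x,y)
    w = Flux.gradient(g, x)[1]   # gradient of the network output with respect to the input x
    yhat = dot(w, x) / g(x)
    return (yhat - y)^2
end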

This does not work within the Flux framework. Currently, I train the network either 1) “manually”, using ForwardDiff.gradient where required, or 2) with Optim, where the network is small.

Is there a way we can handle this within Flux? Many thanks!

I use such loss functions, but they are finicky. You need to be certain that every gradient involved is itself differentiable, which in practice means rewriting some gradient definitions, because they are not second-order friendly. In theory, Zygote can compute second-order gradients. Be aware that compilation time might be long.
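
To make the nesting concrete, here is a minimal sketch of one way to differentiate such a loss twice. It does not use Flux or Zygote: it nests ForwardDiff.gradient inside ForwardDiff.gradient (forward-over-forward), which is close to the "manual" approach described in the question. The tiny network, the flat parameter vector θ, the tanh activation, and the sample values are all assumptions made for illustration, not the model from the question.

```julia
using ForwardDiff, LinearAlgebra

# Hypothetical tiny network with all parameters in one flat vector θ:
# a 2x2 weight matrix W, a bias b (length 2), and output weights v (length 2).
function g(θ, x)
    W = reshape(θ[1:4], 2, 2)
    b = θ[5:6]
    v = θ[7:8]
    return dot(v, tanh.(W * x .+ b))   # scalar output
end

# Same shape of loss as in the question: it uses the gradient of g
# with respect to the input x.
function loss(θ, x, y)
    w = ForwardDiff.gradient(z -> g(θ, z), x)   # inner derivative, w.r.t. the input
    yhat = dot(w, x) / g(θ, x)
    return (yhat - y)^2
end

# Arbitrary illustrative values.
θ = collect(range(0.1, 0.8; length = 8))
x = [0.5, -0.3]
y = 1.0

# Outer derivative, w.r.t. the parameters. ForwardDiff is differentiating
# through its own inner gradient call here.
∇θ = ForwardDiff.gradient(p -> loss(p, x, y), θ)

# One plain gradient-descent step with an arbitrary step size.
θ_new = θ .- 0.01 .* ∇θ
```

Forward mode costs grow with the number of parameters, so this is only practical for small networks. For larger ones the outer derivative would normally be reverse mode, which is where the second-order issues mentioned above show up.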

Thanks, Tomas.