Loss functions that involve gradients

I am training a neural net NN(.) whose loss function involves the gradient of NN(.) with respect to the input. For example:

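using Flux
using LinearAlgebra: dot

m = Flux.Chain(Dense(5,5,relu), Dense(5,5,relu), Dense(5,1))
g(z) = only(m(z))  # scalar network output

function loss(x,y)
    w = Flux.gradient(g, x)[1]  # gradient of the network w.r.t. its input
    yhat = dot(w, x) / g(x)
    return (yhat - y)^2
end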

This does not work within the Flux framework. Currently, I train the network either (1) “manually”, using ForwardDiff.gradient where required, or (2) with Optim when the network is small.

Is there a way we can handle this within Flux? Many thanks!

I use such loss functions, but they are finicky. You need to be certain that every gradient is itself differentiable, which in practice means rewriting some gradients that are not second-order friendly. In theory, Zygote can do second-order gradients, but be aware that compilation time might be long.
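
For readers who want something that sidesteps the second-order issue entirely, here is a minimal sketch of the "manual" route with nested ForwardDiff (forward-over-forward). It is not anyone's actual code from this thread. The hand-written network `net`, the flat parameter vector `θ`, the sizes `NIN` and `NHID`, the `tanh` activation (chosen because it is smooth, unlike relu) and the helper `sgd_step` are all assumptions made for illustration. The loss has the same shape as the one in the question.

```julia
using ForwardDiff
using LinearAlgebra: dot

# Hypothetical sizes: 5 inputs, 3 hidden units, 1 output.
const NIN = 5
const NHID = 3
nparams() = NHID * NIN + NHID + NHID + 1

# Tiny one-hidden-layer network written with scalar operations only,
# so it accepts any mix of Float64 and ForwardDiff dual numbers.
# Layout of θ: W1 (column-major), then b1, then w2, then the output bias.
function net(θ, x)
    out = θ[end]                                   # output bias
    for i in 1:NHID
        pre = θ[NHID * NIN + i]                    # hidden bias b1[i]
        for j in 1:NIN
            pre += θ[(j - 1) * NHID + i] * x[j]    # W1[i, j] * x[j]
        end
        out += θ[NHID * NIN + NHID + i] * tanh(pre)  # w2[i] * h[i]
    end
    return out
end

# Same structure as the loss in the question: the inner gradient is
# taken with respect to the input x, not the parameters.
function loss(θ, x, y)
    w = ForwardDiff.gradient(z -> net(θ, z), x)
    yhat = dot(w, x) / net(θ, x)
    return (yhat - y)^2
end

# One plain gradient-descent step; the outer gradient is with respect to θ.
function sgd_step(θ, x, y; lr = 1e-2)
    grad = ForwardDiff.gradient(p -> loss(p, x, y), θ)
    return θ .- lr .* grad
end
```

Forward mode scales with the number of parameters, so this is only reasonable for small networks, which matches the "where the network is small" caveat in the question.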

Thanks Tomas.

@balaji1975 how do you do the training with ForwardDiff.gradient?

What code follows your loss? Specifically, how do you call for the gradients and then update the parameters?


It has been a long time and my memory is sketchy. I ended up using just Flux with a custom loss function.

Essentially, I was trying to find a neural network N that minimized the error:
