Hi all, I am trying to construct a loss function in Flux that combines a typical mean squared error term with a derivative term. For some reason I get a `Can't differentiate foreigncall expression` error when I try to compute the gradient of the loss function. A minimal working example is shown below. Does anybody know how to fix this, or have a workaround?
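Written out, the intended loss reads the code comment as "penalise the squared derivative of the (scalar) network output with respect to the third input". With model parameters $\theta$ it is:

$$
\mathcal{L}(\theta) = \operatorname{MSE}\big(m_\theta(x),\, y\big) + \left(\frac{\partial m_\theta(x)}{\partial x_3}\right)^2,
\qquad
\frac{\partial \mathcal{L}}{\partial \theta} \ni 2\,\frac{\partial m_\theta(x)}{\partial x_3}\,\frac{\partial^2 m_\theta(x)}{\partial \theta\,\partial x_3}.
$$

Training therefore needs a mixed second derivative. This is why one `gradient` call ends up nested inside another, and nested reverse-mode AD is where the `foreigncall` error appears.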
```julia
using Flux

m = Chain(Dense(3, 10, relu), Dense(10, 10, relu), Dense(10, 1))
ps = Flux.params(m)

function loss(x, y)
    fitloss = Flux.Losses.mse(m(x), y)  # typical loss function
    # problematic term: `gradient` needs a scalar output and returns a tuple,
    # and I only care about the derivative with respect to the 3rd input
    dx = gradient(a -> sum(m(a)), x)[1]
    derivativeloss = abs2(dx[3])
    return fitloss + derivativeloss
end

xt = rand(3)
yt = rand(1)

gs = gradient(ps) do
    loss(xt, yt)
end
```
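One possible workaround, sketched under some assumptions, avoids nested AD entirely. It replaces the exact inner derivative with a central finite-difference estimate (a swapped-in technique, not an exact derivative), so Zygote only has to differentiate ordinary forward passes of the model. This sketch assumes Flux 0.13 or later (for the `in => out` layer syntax and explicit-style `Flux.gradient`). The names `e3`, `dm_dx3`, and the step size `h = 1f-3` are hypothetical choices for illustration.

```julia
using Flux

# Same architecture as the question (assumes Flux 0.13+ `in => out` syntax).
model = Chain(Dense(3 => 10, relu), Dense(10 => 10, relu), Dense(10 => 1))

# Hypothetical unit vector selecting the 3rd input; built once, outside the
# loss, so Zygote never encounters array mutation.
const e3 = Float32[0, 0, 1]

# Hypothetical helper: central finite-difference estimate of d model(x) / d x[3].
# Only forward passes are used, so no nested `gradient` call is required.
dm_dx3(model, x; h = 1f-3) =
    (sum(model(x .+ h .* e3)) - sum(model(x .- h .* e3))) / (2h)

function loss(model, x, y)
    fitloss = Flux.Losses.mse(model(x), y)   # typical loss term
    derivativeloss = abs2(dm_dx3(model, x))  # squared derivative w.r.t. x[3]
    return fitloss + derivativeloss
end

xt = rand(Float32, 3)
yt = rand(Float32, 1)

# Explicit-style gradient with respect to the whole model (not implicit params).
gs = Flux.gradient(mdl -> loss(mdl, xt, yt), model)
```

The step `h` trades truncation error against Float32 rounding error, so it may need tuning for a given model. If the exact derivative is required, a mixed-mode approach (for example, forward-mode AD for the inner derivative) is the usual alternative, but how well that composes with Zygote depends on package versions.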
Note that this is essentially the same unsolved error as the one found here. Thanks for helping!