Hi all, I am trying to construct a loss function in Flux that combines a typical mean squared error term with a derivative term. For some reason, I get a `Can't differentiate foreigncall expression` error when I try to compute the gradient of the loss function. A minimal working example is shown below. Does anybody know how to fix this, or have a workaround?
```julia
using Flux

m = Chain(Dense(3, 10, relu), Dense(10, 10, relu), Dense(10, 1))
ps = Flux.params(m)

function loss(x, y)
    fitloss = Flux.Losses.mse(m(x), y)  # typical loss function
    derivativeloss = abs2(gradient(a -> m(a)[1], x)[1][3])  # problematic term (only care about derivative of the 3rd input)
    return fitloss + derivativeloss
end

xt = rand(3)
yt = rand(1)

gs = gradient(ps) do
    loss(xt, yt)
end
```
Note that this is essentially the same unsolved error as the one found here. Thanks for helping!
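Here is a minimal sketch of one possible workaround, assuming an approximate derivative is acceptable. It replaces the inner `gradient` call with a central finite difference, so the outer `gradient` only has to differentiate ordinary forward passes of `m`. It never has to differentiate one Zygote `gradient` call nested inside another. The helper `dm_dx3`, the constant `e3`, and the step size `h` are hypothetical names and values chosen for illustration. The code also assumes a Flux version that still supports the implicit `Flux.params` API used in the question.

```julia
using Flux

m = Chain(Dense(3, 10, relu), Dense(10, 10, relu), Dense(10, 1))
ps = Flux.params(m)

# Unit vector selecting the 3rd input. It is defined outside the loss so that
# Zygote sees it as a captured constant rather than tracing its construction.
const e3 = Float32[0, 0, 1]

# Hypothetical helper: central finite-difference estimate of ∂m(x)[1]/∂x[3].
# Only forward passes of `m` appear here, so no nested `gradient` is needed.
dm_dx3(x, h = 1f-2) = (m(x .+ h .* e3)[1] - m(x .- h .* e3)[1]) / (2h)

function loss(x, y)
    fitloss = Flux.Losses.mse(m(x), y)  # typical loss function
    derivativeloss = abs2(dm_dx3(x))    # approximate derivative penalty
    return fitloss + derivativeloss
end

# Float32 data to match the Float32 weights of the Dense layers
xt = rand(Float32, 3)
yt = rand(Float32, 1)

gs = gradient(ps) do
    loss(xt, yt)
end
```

A ReLU network is piecewise linear, so the estimate can be off when `xt` lies within `h` of a kink. Shrinking `h` reduces that effect but increases Float32 round-off error. An exact alternative would need an AD setup that supports nested differentiation, which this sketch does not attempt.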