How to put a derivative into a loss function?

Hi there,

I am wondering whether there is a convenient way to put derivatives obtained from AD directly into a loss function.
I tried the following code, but it fails with Zygote.

using Flux, Zygote

m = Chain(Dense(1, 20, tanh), Dense(20, 1))
X = collect(range(0., 1., length=100))'   # 1×100 input matrix

u(x) = m(x)
# du/dx via Zygote's reverse-mode pullback
ux(x) = Zygote.pullback(u, x)[2](ones(size(x)))[1]

loss(x) = sum(abs2, ux(x))
ps = Flux.params(m)
dataset = [(X,)]

Flux.train!(loss, ps, dataset, ADAM())

The error message is:

Can't differentiate foreigncall expression

 [1] error(::String) at ./error.jl:33
 [2] get at ./abstractdict.jl:596 [inlined]
 [3] (::typeof(∂(get)))(::Nothing) at /home/tianbai/.julia/packages/Zygote/KNUTW/src/compiler/interface2.jl:0
 [4] accum_global at /home/tianbai/.julia/packages/Zygote/KNUTW/src/lib/lib.jl:59 [inlined]
 [5] (::typeof(∂(accum_global)))(::Nothing) at /home/tianbai/.julia/packages/Zygote/KNUTW/src/compiler/interface2.jl:0
 [6] #119 at /home/tianbai/.julia/packages/Zygote/KNUTW/src/lib/lib.jl:70 [inlined]
 [7] (::typeof(∂(λ)))(::Nothing) at /home/tianbai/.julia/packages/Zygote/KNUTW/src/compiler/interface2.jl:0
 [8] #214#back at /home/tianbai/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
 [9] (::typeof(∂(λ)))(::Nothing) at /home/tianbai/.julia/packages/Zygote/KNUTW/src/compiler/interface2.jl:0
 [10] gradtuple1 at /home/tianbai/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:12 [inlined]

Calculate the derivative inside the loss via Zygote, then calculate the gradient of the loss with respect to the parameters via ReverseDiff.
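A rough sketch of that mixed-mode suggestion, untested and with some assumptions: it uses Flux.destructure to flatten the model parameters into a single vector (so ReverseDiff.gradient can operate on them), keeps the inner du/dx computation in Zygote, and assumes ReverseDiff can trace through the Zygote pullback here.

```julia
using Flux, Zygote, ReverseDiff

m = Chain(Dense(1, 20, tanh), Dense(20, 1))
X = collect(range(0., 1., length=100))'

# Flatten parameters into a vector θ; re(θ) rebuilds the model.
θ, re = Flux.destructure(m)

function loss(θ)
    net = re(θ)                 # model rebuilt from the flat parameter vector
    u(x) = net(x)
    # Inner AD pass: du/dx via Zygote's reverse-mode pullback
    ux = Zygote.pullback(u, X)[2](ones(size(X)))[1]
    sum(abs2, ux)
end

# Outer AD pass: gradient of the loss w.r.t. θ via ReverseDiff
g = ReverseDiff.gradient(loss, θ)
```

The resulting g could then be fed to an optimiser manually (e.g. via Flux.Optimise.update!) instead of going through Flux.train!.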
