Hi there,
I am wondering whether there is a convenient way to use derivatives computed by AD (here, the derivative of the network output with respect to its input) directly inside a loss function.
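Concretely, writing $u_\theta$ for the network with parameters $\theta$ and $x_1, \dots, x_{100}$ for the points in `X`, the loss computed by `loss` in the code is:

$$
L(\theta) = \sum_{i=1}^{100} \left( \frac{\partial u_\theta}{\partial x}(x_i) \right)^2
$$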
I tried the following code, but training fails inside Zygote:
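```julia
using Flux, Zygote

m = Chain(Dense(1, 20, tanh), Dense(20, 1))
X = (range(0., 1., length=100) |> collect)'   # 1×100 input: one feature, 100 samples

u(x) = m(x)
# du/dx via a pullback seeded with ones; each column is an independent sample,
# so this gives the derivative at every point of x
ux(x) = Zygote.pullback(u, x)[2](ones(size(x)))[1]
loss(x) = sum(abs2, ux(x))

ps = Flux.params(m)
dataset = [(X,)]   # assumed: a single batch, so train! calls loss(X) once
Flux.train!(loss, ps, dataset, ADAM())
```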
The error message is:
```
Can't differentiate foreigncall expression
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] get at ./abstractdict.jl:596 [inlined]
[3] (::typeof(∂(get)))(::Nothing) at /home/tianbai/.julia/packages/Zygote/KNUTW/src/compiler/interface2.jl:0
[4] accum_global at /home/tianbai/.julia/packages/Zygote/KNUTW/src/lib/lib.jl:59 [inlined]
[5] (::typeof(∂(accum_global)))(::Nothing) at /home/tianbai/.julia/packages/Zygote/KNUTW/src/compiler/interface2.jl:0
[6] #119 at /home/tianbai/.julia/packages/Zygote/KNUTW/src/lib/lib.jl:70 [inlined]
[7] (::typeof(∂(λ)))(::Nothing) at /home/tianbai/.julia/packages/Zygote/KNUTW/src/compiler/interface2.jl:0
[8] #214#back at /home/tianbai/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
[9] (::typeof(∂(λ)))(::Nothing) at /home/tianbai/.julia/packages/Zygote/KNUTW/src/compiler/interface2.jl:0
[10] gradtuple1 at /home/tianbai/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:12 [inlined]
```
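One possible workaround, offered as a sketch rather than a confirmed fix, is to avoid nesting AD altogether: approximate du/dx with a central finite difference so that Zygote only has to differentiate once, with respect to the model. This swaps a finite-difference approximation in for the exact AD derivative. The sketch assumes Flux 0.14 or later (`Dense(in => out, σ)`, `Adam`, `Flux.setup`, `Flux.update!`, and explicit-model `Flux.gradient`); the helper name `ux_fd` and the step size `h` are hypothetical choices for illustration.

```julia
using Flux  # assumes Flux ≥ 0.14 (explicit-model gradients and Flux.setup)

# Same architecture as in the question, written with the newer Dense(in => out) syntax
m = Chain(Dense(1 => 20, tanh), Dense(20 => 1))
X = collect(range(0.0f0, 1.0f0, length=100))'   # 1×100 Float32 input row

# Hypothetical helper: central-difference approximation of du/dx at each column of x.
# The step size h is an illustrative assumption, not a tuned value.
ux_fd(model, x; h=1f-2) = (model(x .+ h) .- model(x .- h)) ./ (2h)

# Same loss as in the question, but built on the finite-difference derivative
loss(model, x) = sum(abs2, ux_fd(model, x))

# Only one level of AD: Zygote differentiates the loss with respect to the model
grads = Flux.gradient(model -> loss(model, X), m)

# One optimiser step using the explicit (Optimisers.jl-style) API
opt_state = Flux.setup(Adam(1f-3), m)
Flux.update!(opt_state, m, grads[1])
```

The trade-off is accuracy: the derivative is only approximate (error of order `h^2` plus Float32 round-off), but the training loop no longer asks Zygote to differentiate through its own pullback.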