Flux loss: Gradient wrt input leads to empty gradient wrt parameters or to "can't differentiate foreigncall"

Hi Everyone,

I am stuck with the following problem: I need the gradient (w.r.t. the input variables) of my NN in the loss function and cannot get it working. I have seen a couple of posts on this, but none with a solution I could make sense of. As suggested there, I tried Zygote, ForwardDiff and ReverseDiff, all unsuccessfully. Since there are a lot of great minds out there, I hope you’ll have a couple of minutes to look at this minimal example:

using Flux, Zygote, ReverseDiff, ForwardDiff, LinearAlgebra
using Flux: train!

hidden_states = [2, 3, 1]               # Layer widths.
L = length(hidden_states)               # Depth of the network.
layers = [Dense(hidden_states[i], hidden_states[i+1], tanh, bias = false) for i = 1:L-1]
model = Chain(layers[1], layers[2])
ps = params( model )

# Function needed for cost computation.
fparams = Dict("a" => -1, "b" => 2, "c1" => 4, "c2" => 2)
f(y) = [ 
    -( fparams["c1"]*fparams["a"]*y[1]*(y[1]^2 + y[2]^2) + fparams["c2"]*fparams["b"]*y[1] ), 
    -( fparams["c1"]*fparams["a"]*y[2]*(y[1]^2 + y[2]^2) + fparams["c2"]*fparams["b"]*y[2] )
    ]

x_test = [0.9, 0.9]

# 1st option runs but grads is empty. 2nd and 3rd options give: "Can't differentiate foreigncall expression".
loss(xdata) = dot( ForwardDiff.gradient( (x_) -> sum(model(x_)),  xdata ), f(xdata) )
# loss(xdata) = dot( ReverseDiff.gradient( (x_) -> sum(model(x_)),  xdata ), f(xdata) )
# loss(xdata) = dot( Zygote.gradient( (x_) -> sum(model(x_)),  xdata )[1], f(xdata) )

println( loss( x_test ) )
println( ps )

opt = RADAM()
for epoch in 1:100
    grads = gradient(ps) do
        loss(x_test)
    end
    if epoch % 20 == 0
        println( grads )
    end
    Flux.update!(opt, ps, grads)
end

println( loss( x_test ) )
println( ps )

I would be very thankful if any of you can help!

Unfortunately, second derivatives (one AD nested inside another) are tricky, with the exception of ForwardDiff over anything.

I don’t know why Zygote over Zygote gives an error here. Accessing a dictionary inside the gradient doesn’t seem to be the cause. It’s a bug, and if you can isolate it, maybe someone can fix it.

As for Zygote over ForwardDiff: the reason you get zero is that Zygote.forwarddiff(f, x), which is what Zygote calls here to avoid an error, does not store the gradient with respect to f, and so nothing is recorded for the model parameters that f closes over. This is a landmine which should probably be removed.
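
For this particular two-layer tanh model, one way to sidestep nested AD altogether is to write the input gradient out by hand with the chain rule, so that the loss becomes an ordinary function of the weights. The following is only a sketch under assumptions: `net`, `input_gradient`, `W1` and `W2` are hypothetical names standing in for the Flux model and its weight matrices, the weights are arbitrary example values, and the constants of `f` are written as plain numbers instead of a Dict.

```julia
using LinearAlgebra

# Hypothetical hand-written stand-in for
# Chain(Dense(2, 3, tanh; bias = false), Dense(3, 1, tanh; bias = false)).
net(W1, W2, x) = tanh.(W2 * tanh.(W1 * x))

# Closed-form gradient of sum(net(W1, W2, x)) with respect to x,
# using the chain rule and tanh'(z) = 1 - tanh(z)^2.
function input_gradient(W1, W2, x)
    h  = tanh.(W1 * x)                  # hidden activations
    y  = tanh.(W2 * h)                  # network output
    δ2 = 1 .- y .^ 2                    # d sum(y) / d (W2 * h)
    δ1 = (W2' * δ2) .* (1 .- h .^ 2)    # d sum(y) / d (W1 * x)
    return W1' * δ1
end

# Same vector field as f in the question, constants written out directly.
a, b, c1, c2 = -1, 2, 4, 2
f(y) = [-(c1 * a * y[1] * (y[1]^2 + y[2]^2) + c2 * b * y[1]),
        -(c1 * a * y[2] * (y[1]^2 + y[2]^2) + c2 * b * y[2])]

# The loss depends on W1 and W2 through plain array operations only.
loss(W1, W2, x) = dot(input_gradient(W1, W2, x), f(x))

W1 = [0.1 -0.2; 0.3 0.4; -0.5 0.6]      # 3×2, arbitrary example weights
W2 = [0.7 -0.8 0.9]                     # 1×3, arbitrary example weights
x_test = [0.9, 0.9]
println(loss(W1, W2, x_test))
```

Since `loss` contains no inner AD call, a single first-order call such as `Zygote.gradient((A, B) -> loss(A, B, x_test), W1, W2)` should then be enough to get gradients with respect to the weights. The obvious downside is that the hand-written gradient has to be redone whenever the architecture changes.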


Thanks a lot for your answer! So the summary would be: I should wait for Diffractor.jl if I want to compute this with AD, right? (Or isolate the error, but I feel like I am quite far from it…)

I figured out a “dirty” solution which is enough for my purposes: computing d(NN)/dx with finite differences. I am not proud of it but it does the job hahaha.

Anyway, if anyone has a solution for AD though, I’d still be highly interested!
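
For readers wondering what the finite-difference workaround might look like, here is a minimal sketch, not the code actually used above. `fd_gradient` and `unit_vector` are hypothetical helper names, the step size `h` is an assumption (with Float32 model weights a larger step may be needed), and `g` is a small stand-in for `x_ -> sum(model(x_))`. It is written without array mutation, because Zygote does not support mutating arrays and the outer gradient with respect to the parameters still has to pass through this function.

```julia
using LinearAlgebra

# i-th standard basis vector of length n, built without mutation.
unit_vector(i, n) = [j == i ? 1.0 : 0.0 for j in 1:n]

# Central finite-difference approximation of the gradient of a scalar
# function g at the point x. The step size h is an assumed default.
function fd_gradient(g, x; h = 1e-6)
    n = length(x)
    return [(g(x .+ h .* unit_vector(i, n)) - g(x .- h .* unit_vector(i, n))) / (2h)
            for i in 1:n]
end

# Stand-in for x_ -> sum(model(x_)): one tanh layer with fixed example weights.
W = [0.1 -0.2; 0.3 0.4]
g(x) = sum(tanh.(W * x))

x_test = [0.9, 0.9]
println(fd_gradient(g, x_test))
```

In the original example the loss would then read something like `loss(xdata) = dot(fd_gradient(x_ -> sum(model(x_)), xdata), f(xdata))`, at the cost of two model evaluations per input dimension and a truncation error that depends on `h`.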

You’re not that far from it. You just have to fool around with deleting things from the original function until the error goes away, or adding things to a simpler version until it starts giving the error:

julia> loss(xdata) = sum(Zygote.gradient( (x_) -> sum(x_ .* x_),  xdata )[1]);

julia> gradient(x -> loss(x), x_test)
([2.0, 2.0],)

julia> loss(xdata) = sum(Zygote.gradient( (x_) -> sum(model(x_)),  xdata )[1]);

julia> gradient(x -> loss(x), x_test)
ERROR: Can't differentiate foreigncall expression
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] Pullback
    @ ./iddict.jl:102 [inlined]
  [3] (::typeof(∂(get)))(Δ::Nothing)
    @ Zygote ./compiler/interface2.jl:0
  [4] Pullback
    @ ~/.julia/dev/Zygote/src/lib/lib.jl:68 [inlined]
  [5] (::typeof(∂(accum_global)))(Δ::Nothing)
    @ Zygote ./compiler/interface2.jl:0
  [6] Pullback
    @ ~/.julia/dev/Zygote/src/lib/lib.jl:79 [inlined]
  [7] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ./compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
  [9] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ./compiler/interface2.jl:0
 [10] getindex
    @ ./tuple.jl:29 [inlined]
 [11] map
    @ ./tuple.jl:222 [inlined]
 [12] unthunk_tangent
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:36 [inlined]
 [13] #1639#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [14] (::typeof(∂(λ)))(Δ::Tuple{Nothing, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ./compiler/interface2.jl:0
 [15] Pullback
    @ ~/.julia/dev/Zygote/src/compiler/interface.jl:41 [inlined]
 [16] (::typeof(∂(λ)))(Δ::Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ./compiler/interface2.jl:0
 [17] Pullback
    @ ~/.julia/dev/Zygote/src/compiler/interface.jl:76 [inlined]
 [18] (::typeof(∂(gradient)))(Δ::Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ./compiler/interface2.jl:0
 [19] Pullback
    @ ./REPL[51]:1 [inlined]
 [20] (::typeof(∂(loss)))(Δ::Float64)
    @ Zygote ./compiler/interface2.jl:0
 [21] Pullback
    @ ./REPL[52]:1 [inlined]
 [22] (::Zygote.var"#56#57"{typeof(∂(#32))})(Δ::Float64)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:41
 [23] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:76
 [24] top-level scope
    @ REPL[52]:1
 [25] top-level scope
    @ ~/.julia/packages/CUDA/5jdFl/src/initialization.jl:52