How to use the gradient of a neural network in the loss function?

Dear all,

I want to include the gradient of a neural network (i.e., the gradient of the NN output with respect to its input) in the loss function. However, it seems that Zygote cannot differentiate through that. Here is my code:

using Flux
using Zygote
using Random

Xtrain = rand(Float32,(4,100));
Ytrain = rand(Float32,(1,100));

#******************* NN ********************
# Only the first two rows of x are fed to the network; rows 3 and 4 are
# kept as targets for the input-gradient term in the loss below.
model = Chain(x -> x[1:2,:],
              Dense(2, 50, tanh),
              Dense(50, 50, tanh),
              Dense(50, 1))
θ = params(model);

#******************* Loss function ********************
function loss(x,y)
    # Inner gradient: differentiate model(x)[1] with respect to the input x.
    ps = Flux.Params([x]);
    g = Flux.gradient(ps) do
        model(x)[1]
    end
    # Match (part of) the input gradient against rows 3-4 of x, plus the
    # usual data-fit term.
    return ( Flux.mse(g[x][1:2,:][1] , x[3:4,:]) + Flux.mse(model(x),y) )
end

#******************* Training ********************
train_loader = Flux.Data.DataLoader(Xtrain, Ytrain, batchsize=1, shuffle=true);

Nepochs = 10;
opt = ADAM(0.001)

for epoch in 1:Nepochs
  @time Flux.train!(loss, θ, train_loader, opt)
end

However, the code fails to run. This is the error message:

ERROR: Can't differentiate foreigncall expression
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] get at ./iddict.jl:87 [inlined]
 [3] (::typeof(∂(get)))(::Nothing) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [4] accum_global at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/lib/lib.jl:56 [inlined]
 [5] (::typeof(∂(accum_global)))(::Nothing) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [6] #89 at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/lib/lib.jl:67 [inlined]
 [7] (::typeof(∂(λ)))(::Nothing) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [8] #1550#back at /home/wuxxx184/ma000311/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
 [9] (::typeof(∂(λ)))(::Nothing) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [10] getindex at ./tuple.jl:24 [inlined]
 [11] gradindex at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/reverse.jl:12 [inlined]
 [12] #3 at ./REPL[8]:5 [inlined]
 [13] (::typeof(∂(λ)))(::Nothing) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [14] #54 at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:177 [inlined]
 [15] (::typeof(∂(λ)))(::Nothing) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [16] gradient at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:54 [inlined]
 [17] (::typeof(∂(gradient)))(::Nothing) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [18] loss at ./REPL[8]:4 [inlined]
 [19] (::typeof(∂(loss)))(::Float32) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [20] #145 at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/lib/lib.jl:175 [inlined]
 [21] #1681#back at /home/wuxxx184/ma000311/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
 [22] #15 at /home/wuxxx184/ma000311/.julia/packages/Flux/05b38/src/optimise/train.jl:83 [inlined]
 [23] (::typeof(∂(λ)))(::Float32) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [24] (::Zygote.var"#54#55"{Params,Zygote.Context,typeof(∂(λ))})(::Float32) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:177
 [25] gradient(::Function, ::Params) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:54
 [26] macro expansion at /home/wuxxx184/ma000311/.julia/packages/Flux/05b38/src/optimise/train.jl:82 [inlined]
 [27] macro expansion at /home/wuxxx184/ma000311/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
 [28] train!(::Function, ::Params, ::Flux.Data.DataLoader{Tuple{Array{Float32,2},Array{Float32,2}}}, ::ADAM; cb::Flux.Optimise.var"#16#22") at /home/wuxxx184/ma000311/.julia/packages/Flux/05b38/src/optimise/train.jl:80
 [29] train!(::Function, ::Params, ::Flux.Data.DataLoader{Tuple{Array{Float32,2},Array{Float32,2}}}, ::ADAM) at /home/wuxxx184/ma000311/.julia/packages/Flux/05b38/src/optimise/train.jl:78
 [30] macro expansion at ./timing.jl:174 [inlined]
 [31] top-level scope at ./REPL[12]:2


Does anyone know how to fix this?

Thanks a lot!
Xiaodong

I am doing something similar and I have made it work, but it is not trivial and I would not recommend it to beginners. You need to ensure that all the gradients (a.k.a. pullbacks) in your inner gradient are themselves differentiable, i.e. everything must be twice differentiable. I would recommend making the problem example smaller (replace train! with a plain gradient call, etc.). I am not sure that getindex, for example, is twice differentiable, so maybe replacing model(x)[1] with sum(model(x)) will do the trick.
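For concreteness, a minimal sketch of that kind of test (placeholder model and sizes, not your actual problem): drop the DataLoader and train!, keep the nested gradient, and call the outer gradient once so the stack trace points straight at whichever pullback cannot be differentiated a second time.

using Flux, Zygote

model = Chain(Dense(2, 16, tanh), Dense(16, 1))   # small placeholder model
θ = params(model)
x = rand(Float32, 2, 1)
y = rand(Float32, 1, 1)

# Same nested structure as above, but with sum(model(x)) inside the
# inner gradient instead of model(x)[1].
function loss(x, y)
    ps = Flux.Params([x])
    g = Flux.gradient(ps) do
        sum(model(x))
    end
    return sum(abs2, g[x]) + Flux.mse(model(x), y)
end

# A single outer gradient call in place of train!.
gs = Flux.gradient(() -> loss(x, y), θ)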

Thank you very much for your suggestions, Tomas. I tried removing all the indexing in the loss function:

function loss(x,y)
    ps = Flux.Params([x]);
    g = Flux.gradient(ps) do
        sum(model(x))          # no indexing inside the inner gradient
    end
    return sum(abs2, g[x])     # reduce the input gradient to a scalar, again without indexing
end

But the same error occurs.
I will replace train! with a simpler gradient call to see if I can trace the pullbacks.
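For reference, here is the kind of simpler variant I have in mind (placeholder code, and I am not sure yet whether it avoids the problem): the get at ./iddict.jl frame in the trace suggests the outer gradient is trying to differentiate through the IdDict that backs Params/Grads, so this variant takes the inner gradient with respect to an explicit argument instead of Flux.Params([x]).

function loss(x, y)
    # Inner gradient w.r.t. an explicit argument, so the outer differentiation
    # never sees the Params/Grads IdDict. Whether the resulting pullbacks are
    # twice differentiable still depends on the layers involved.
    gx = Zygote.gradient(z -> sum(model(z)), x)[1]
    return sum(abs2, gx) + Flux.mse(model(x), y)
end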

Xiaodong

If I don't forget, I will send you tomorrow a hacked ZygoteRules that traces which pullback is called. It is a bit low-level, but it helped me debug NaNs and similar issues.
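In the meantime, Zygote's built-in inspection utilities might already get you part of the way (assuming your Zygote version has them; this is not the hacked ZygoteRules I mean): Zygote.@showgrad prints the gradient about to accumulate to a value during the backward pass, for example:

using Flux, Zygote

model = Chain(Dense(2, 16, tanh), Dense(16, 1))   # placeholder model
x = rand(Float32, 2, 1)

# During the backward pass this prints the gradient reaching h, which can
# help spot where a `nothing` or NaN first shows up.
g = Zygote.gradient(params(model)) do
    h = model(x)
    sum(Zygote.@showgrad(h))
end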