How to use the gradient of a neural network in the loss function?

Dear all,

I want to put the gradient of a neural network (i.e., the gradient of the NN output w.r.t. its input) in the loss function. However, it seems that Zygote cannot differentiate that. Here is my code:

using Flux
using Zygote
using Random

Xtrain = rand(Float32,(4,100));
Ytrain = rand(Float32,(1,100));

#******************* NN ********************
model = Chain(x -> x[1:2, :],
              Dense(2, 50, tanh),
              Dense(50, 50, tanh),
              Dense(50, 1))
θ = params(model);

#******************* Loss function ********************
function loss(x,y)
    ps = Flux.Params([x]);
    g = Flux.gradient(ps) do
        model(x)[1]
    end
    return ( Flux.mse(g[x][1:2,:][1] , x[3:4,:]) + Flux.mse(model(x),y) )
end

#******************* Training ********************
train_loader = Flux.Data.DataLoader(Xtrain,Ytrain, batchsize=1,shuffle=true);

Nepochs = 10;
opt = ADAM(0.001)

for epoch in 1:Nepochs
  @time Flux.train!(loss, θ, train_loader, opt)
end

However, the code fails. This is the error message:

ERROR: Can't differentiate foreigncall expression
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] get at ./iddict.jl:87 [inlined]
 [3] (::typeof(∂(get)))(::Nothing) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [4] accum_global at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/lib/lib.jl:56 [inlined]
 [5] (::typeof(∂(accum_global)))(::Nothing) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [6] #89 at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/lib/lib.jl:67 [inlined]
 [7] (::typeof(∂(λ)))(::Nothing) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [8] #1550#back at /home/wuxxx184/ma000311/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
 [9] (::typeof(∂(λ)))(::Nothing) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [10] getindex at ./tuple.jl:24 [inlined]
 [11] gradindex at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/reverse.jl:12 [inlined]
 [12] #3 at ./REPL[8]:5 [inlined]
 [13] (::typeof(∂(λ)))(::Nothing) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [14] #54 at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:177 [inlined]
 [15] (::typeof(∂(λ)))(::Nothing) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [16] gradient at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:54 [inlined]
 [17] (::typeof(∂(gradient)))(::Nothing) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [18] loss at ./REPL[8]:4 [inlined]
 [19] (::typeof(∂(loss)))(::Float32) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [20] #145 at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/lib/lib.jl:175 [inlined]
 [21] #1681#back at /home/wuxxx184/ma000311/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
 [22] #15 at /home/wuxxx184/ma000311/.julia/packages/Flux/05b38/src/optimise/train.jl:83 [inlined]
 [23] (::typeof(∂(λ)))(::Float32) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [24] (::Zygote.var"#54#55"{Params,Zygote.Context,typeof(∂(λ))})(::Float32) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:177
 [25] gradient(::Function, ::Params) at /home/wuxxx184/ma000311/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:54
 [26] macro expansion at /home/wuxxx184/ma000311/.julia/packages/Flux/05b38/src/optimise/train.jl:82 [inlined]
 [27] macro expansion at /home/wuxxx184/ma000311/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
 [28] train!(::Function, ::Params, ::Flux.Data.DataLoader{Tuple{Array{Float32,2},Array{Float32,2}}}, ::ADAM; cb::Flux.Optimise.var"#16#22") at /home/wuxxx184/ma000311/.julia/packages/Flux/05b38/src/optimise/train.jl:80
 [29] train!(::Function, ::Params, ::Flux.Data.DataLoader{Tuple{Array{Float32,2},Array{Float32,2}}}, ::ADAM) at /home/wuxxx184/ma000311/.julia/packages/Flux/05b38/src/optimise/train.jl:78
 [30] macro expansion at ./timing.jl:174 [inlined]
 [31] top-level scope at ./REPL[12]:2


Does anyone know how to fix this?

Thanks a lot!
Xiaodong

I am doing something similar to you and I have made it work, but it is not trivial and I would not recommend it to beginners. You need to ensure that all gradients (a.k.a. pullbacks) in your inner gradient are twice differentiable. I would recommend making the problem example smaller (replace `train!` by a plain gradient call), etc. I am not sure that `getindex`, for example, is twice differentiable, so replacing `model(x)[1]` with `sum(model(x))` might do the trick.
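For intuition, nested Zygote does work when every pullback along the way is itself differentiable. A minimal sketch with a plain scalar function (no Flux model involved; the names here are just illustrative):

```julia
using Zygote

f(x) = sin(x)

# First derivative via Zygote; gradient returns a tuple, so take [1].
df(x) = Zygote.gradient(f, x)[1]

# Second derivative: Zygote differentiating its own pullback. This works
# because the pullbacks of sin/cos are themselves differentiable; the Flux
# example above fails once a non-twice-differentiable pullback enters the chain.
d2f(x) = Zygote.gradient(df, x)[1]

d2f(1.0)  # ≈ -sin(1.0)
```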

Thank you very much for your suggestions, Tomas. I tried to remove all the indexing in the loss function:

function loss(x,y)
    ps = Flux.Params([x]);
    g = Flux.gradient(ps) do
        sum(model(x))
    end
    return ( sum(Flux.mse(g[x])) )
end

But the same error occurs.
I will replace `train!` with a simpler gradient call to see if I can trace the pullbacks.

Xiaodong

If I don't forget, I will send you tomorrow a hacked ZygoteRules that traces which pullback is called. It is a bit low-level, but it helped me debug NaNs and all these things.

Hi!
I have the same issue and have not found a fix. @Tomas_Pevny, could you share your solution here?

Can you post an MWE?


Here’s what I’ve been trying to do:

using Zygote, Flux
NN = Chain(Conv((3, 3), 1 => 1, pad = (1, 1), relu), flatten, Dense(784, 2)) |> gpu
loss(x) = sum(abs2, Zygote.forward_jacobian(NN, x)[2]) # the second element gives the Jacobian matrix
x = rand(28, 28, 1, 1) |> gpu
loss(x) # works
ps = Flux.params(NN)
res, back = Zygote.pullback(ps) do
    loss(x)
end

This gives me a `CuArray only supports bits types` error.

Replacing `forward_jacobian` with `Zygote.jacobian` (a new function on the Zygote master branch) in `loss`, i.e.

loss(x) = sum(abs2,Zygote.jacobian(NN,x)[1])

gives the following error upon training:

ERROR: this intrinsic must be compiled to be called
Stacktrace:
 [1] macro expansion at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0 [inlined]
 [2] _pullback(::Zygote.Context, ::Core.IntrinsicFunction, ::String, ::Type{UInt64}, ::Type{Tuple{Ptr{UInt64}}}, ::Ptr{UInt64}) at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:9
 [3] getindex at ./atomics.jl:358 [inlined]
 [4] _pullback(::Zygote.Context, ::typeof(getindex), ::Base.Threads.Atomic{UInt64}) at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [5] macro expansion at /home/kaliam/.julia/packages/CUDA/wTQsK/lib/utils/call.jl:37 [inlined]
 [6] cudnnGetVersion at /home/kaliam/.julia/packages/CUDA/wTQsK/lib/cudnn/libcudnn.jl:5 [inlined]
 [7] _pullback(::Zygote.Context, ::typeof(CUDA.CUDNN.cudnnGetVersion)) at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [8] version at /home/kaliam/.julia/packages/CUDA/wTQsK/lib/cudnn/base.jl:21 [inlined]
 [9] _pullback(::Zygote.Context, ::typeof(CUDA.CUDNN.version)) at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [10] #∇conv_filter!#586 at /home/kaliam/.julia/packages/CUDA/wTQsK/lib/cudnn/nnlib.jl:185 [inlined]
 [11] _pullback(::Zygote.Context, ::CUDA.CUDNN.var"##∇conv_filter!#586", ::Int64, ::Int64, ::typeof(∇conv_filter!), ::CuArray{Float32,4}, ::CuArray{Float32,4}, ::CuArray{Float32,4}, ::DenseConvDims{2,(3, 3),1,1,(1, 1),(1, 1, 1, 1),(1, 1),false}) at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [12] ∇conv_filter! at /home/kaliam/.julia/packages/CUDA/wTQsK/lib/cudnn/nnlib.jl:185 [inlined]
 [13] _pullback(::Zygote.Context, ::typeof(∇conv_filter!), ::CuArray{Float32,4}, ::CuArray{Float32,4}, ::CuArray{Float32,4}, ::DenseConvDims{2,(3, 3),1,1,(1, 1),(1, 1, 1, 1),(1, 1),false}) at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [14] #∇conv_filter#89 at /home/kaliam/.julia/packages/NNlib/uX1eA/src/conv.jl:116 [inlined]
 [15] _pullback(::Zygote.Context, ::NNlib.var"##∇conv_filter#89", ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(∇conv_filter), ::CuArray{Float32,4}, ::CuArray{Float32,4}, ::DenseConvDims{2,(3, 3),1,1,(1, 1),(1, 1, 1, 1),(1, 1),false}) at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [16] ∇conv_filter at /home/kaliam/.julia/packages/NNlib/uX1eA/src/conv.jl:114 [inlined]
 [17] _pullback(::Zygote.Context, ::typeof(∇conv_filter), ::CuArray{Float32,4}, ::CuArray{Float32,4}, ::DenseConvDims{2,(3, 3),1,1,(1, 1),(1, 1, 1, 1),(1, 1),false}) at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [18] adjoint at /home/kaliam/.julia/packages/Zygote/IsBxF/src/lib/lib.jl:188 [inlined]
 [19] _pullback at /home/kaliam/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [20] #176 at /home/kaliam/.julia/packages/NNlib/uX1eA/src/conv.jl:229 [inlined]
 [21] _pullback(::Zygote.Context, ::NNlib.var"#176#179"{Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},CuArray{Float32,4},DenseConvDims{2,(3, 3),1,1,(1, 1),(1, 1, 1, 1),(1, 1),false}}) at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [22] Thunk at /home/kaliam/.julia/packages/ChainRulesCore/7d1hl/src/differentials/thunks.jl:98 [inlined]
 [23] _pullback(::Zygote.Context, ::ChainRulesCore.Thunk{NNlib.var"#176#179"{Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},CuArray{Float32,4},DenseConvDims{2,(3, 3),1,1,(1, 1),(1, 1, 1, 1),(1, 1),false}}}) at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [24] unthunk at /home/kaliam/.julia/packages/ChainRulesCore/7d1hl/src/differentials/thunks.jl:99 [inlined]
 [25] wrap_chainrules_output at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/chainrules.jl:41 [inlined]
 [26] (::Zygote.var"#534#538"{Zygote.Context,typeof(Zygote.wrap_chainrules_output)})(::ChainRulesCore.Thunk{NNlib.var"#176#179"{Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},CuArray{Float32,4},DenseConvDims{2,(3, 3),1,1,(1, 1),(1, 1, 1, 1),(1, 1),false}}}) at /home/kaliam/.julia/packages/Zygote/IsBxF/src/lib/array.jl:181
 [27] map at ./tuple.jl:159 [inlined]
 [28] map at ./tuple.jl:160 [inlined]
 [29] ∇map at /home/kaliam/.julia/packages/Zygote/IsBxF/src/lib/array.jl:181 [inlined]
 [30] adjoint at /home/kaliam/.julia/packages/Zygote/IsBxF/src/lib/array.jl:197 [inlined]
 [31] _pullback at /home/kaliam/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [32] wrap_chainrules_output at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/chainrules.jl:42 [inlined]
 [33] ZBack at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/chainrules.jl:77 [inlined]
 [34] _pullback(::Zygote.Context, ::Zygote.ZBack{NNlib.var"#conv_pullback#177"{Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},CuArray{Float32,4},CuArray{Float32,4},DenseConvDims{2,(3, 3),1,1,(1, 1),(1, 1, 1, 1),(1, 1),false}}}, ::CuArray{Float32,4}) at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [35] Conv at /home/kaliam/.julia/packages/Flux/goUGu/src/layers/conv.jl:147 [inlined]
 [36] _pullback(::Zygote.Context, ::typeof(∂(λ)), ::CuArray{Float32,4}) at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [37] applychain at /home/kaliam/.julia/packages/Flux/goUGu/src/layers/basic.jl:36 [inlined]
 [38] _pullback(::Zygote.Context, ::typeof(∂(applychain)), ::CuArray{Float32,2}) at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [39] Chain at /home/kaliam/.julia/packages/Flux/goUGu/src/layers/basic.jl:38 [inlined]
 [40] _pullback(::Zygote.Context, ::typeof(∂(λ)), ::CuArray{Float32,2}) at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [41] #178 at /home/kaliam/.julia/packages/Zygote/IsBxF/src/lib/lib.jl:191 [inlined]
 [42] _pullback(::Zygote.Context, ::Zygote.var"#178#179"{typeof(∂(λ)),Tuple{Tuple{Nothing}}}, ::CuArray{Float32,2}) at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [43] #1698#back at /home/kaliam/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [44] _pullback(::Zygote.Context, ::Zygote.var"#1698#back#180"{Zygote.var"#178#179"{typeof(∂(λ)),Tuple{Tuple{Nothing}}}}, ::CuArray{Float32,2}) at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [45] #62 at ./operators.jl:875 [inlined]
 [46] _pullback(::Zygote.Context, ::typeof(∂(λ)), ::CuArray{Float32,1}) at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [47] #41 at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface.jl:41 [inlined]
 [48] _pullback(::Zygote.Context, ::Zygote.var"#41#42"{typeof(∂(λ))}, ::CuArray{Float32,1}) at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [49] jacobian at /home/kaliam/.julia/packages/Zygote/IsBxF/src/lib/grad.jl:148 [inlined]
 [50] _pullback(::Zygote.Context, ::typeof(jacobian), ::Chain{Tuple{Conv{2,2,typeof(relu),CuArray{Float32,4},CuArray{Float32,1}},typeof(flatten),Dense{typeof(identity),CuArray{Float32,2},CuArray{Float32,1}}}}, ::CuArray{Float32,4}) at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [51] loss at ./REPL[34]:1 [inlined]
 [52] _pullback(::Zygote.Context, ::typeof(loss), ::CuArray{Float32,4}) at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [53] #13 at ./REPL[36]:2 [inlined]
 [54] _pullback(::Zygote.Context, ::var"#13#14") at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [55] pullback(::Function, ::Params) at /home/kaliam/.julia/packages/Zygote/IsBxF/src/compiler/interface.jl:247

Thanks for helping out!

Also see this related Flux issue.

Nested AD is not yet fully supported in Zygote. Your best bet is to set it up the way DiffEqFlux does, using Zygote over ForwardDiff. For example, I need Jacobian-vector products, i.e. directional derivatives, in my loss function, and I set it up this way:

jvp(f,x,v) = ForwardDiff.partials.(f(ForwardDiff.Dual.(x, v)), 1)
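To see what this computes, here is the same `jvp` applied to a plain function (a toy sketch; `f`, `x`, and `v` below are made up for illustration):

```julia
using ForwardDiff

# J_f(x) * v in a single forward pass: seed each input with its directional
# component as the dual's partial, then read the partials back out.
jvp(f, x, v) = ForwardDiff.partials.(f(ForwardDiff.Dual.(x, v)), 1)

f(x) = x .^ 2                  # Jacobian is diag(2x), so J*v = 2 .* x .* v
x = [1.0, 2.0, 3.0]
v = [1.0, 0.0, 1.0]

jvp(f, x, v)  # == [2.0, 0.0, 6.0]
```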

You will additionally need to de-dualize the training gradients before you update the neural network weights. One way to do this:

import Flux.update!
function update!(opt, x, x̄::AbstractArray{<:ForwardDiff.Dual})
    x̄ = getindex.(ForwardDiff.partials.(x̄), 1)
    Flux.update!(opt, x, x̄)
end
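The de-dualization itself is just extracting the first partial from each entry. A stand-alone toy illustration (the Dual array below is a made-up stand-in for a gradient coming out of the nested setup):

```julia
using ForwardDiff

# A Dual-valued "gradient" as it might come back from the nested Zygote-over-
# ForwardDiff setup; with the adjoints above, the cotangent of interest ends
# up in the partials slot of each entry.
x̄ = ForwardDiff.Dual.([0.5, -1.0], [2.0, 3.0])

# De-dualize: keep the first partial of each entry, matching the update! hack
dedual(x̄) = getindex.(ForwardDiff.partials.(x̄), 1)

dedual(x̄)  # == [2.0, 3.0]
```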

Very cool! This works perfectly for various Dense and Conv neural networks on the CPU, albeit very slowly with ForwardDiff types in Conv. I assume you still need to use Zygote.pullback to get the gradients. However, on the GPU I get different errors depending on the network I use, which tells me I need to define more adjoints:

Using Chain(Dense(10,10)) gives the error:

ERROR: ArgumentError: invalid index: Dual{Nothing}(1,0,1) of type ForwardDiff.Dual{Nothing,Int64,2}

and using Chain(Dense(10,10,elu)) gives the error:

ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{Nothing,Float32,2})

Here’s the full code:

using Zygote, Flux, ZygoteRules, ForwardDiff

# From DiffEqFlux
ZygoteRules.@adjoint function ForwardDiff.Dual{T}(x, ẋ::Tuple) where T
  @assert length(ẋ) == 1
  ForwardDiff.Dual{T}(x, ẋ), ḋ -> (ḋ.partials[1], (ḋ.value,))
end

ZygoteRules.@adjoint ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:partials}) where T =
  d.partials, ṗ -> (ForwardDiff.Dual{T}(ṗ[1], 0),)

ZygoteRules.@adjoint ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:value}) where T =
  d.value, ẋ -> (ForwardDiff.Dual{T}(0, ẋ),)

# Define neural nets, and vecs

f = Chain(Dense(100, 10)) |> gpu
x = rand(100) |> gpu
v = rand(100) |> gpu

jvp(f,x,v) = ForwardDiff.partials.(f(ForwardDiff.Dual.(x, v)), 1)

function update!(opt, x, x̄::AbstractArray{<:ForwardDiff.Dual})
    x̄ = getindex.(ForwardDiff.partials.(x̄), 1)
    Flux.update!(opt, x, x̄)
end

loss(x,v) = sum(abs2,jvp(f,x,v))

loss(x,v) # Works

ps = Flux.params(f)
res, back = Zygote.pullback(ps) do
    loss(x, v)
end # Throws error here

update!(ADAM(0.01),ps,back(1f0))

Perhaps @ChrisRackauckas could help out here?

It’s a Zygote issue. Use ReverseDiff over Zygote until Diffractor is released.


So the code shown above will look a lot different. I observed that ReverseDiff does not have `partials`, and using `ReverseDiff.jacobian` instead leads to significantly more allocations (obviously) and a `this intrinsic must be compiled to be called` error. Could you perhaps give a more explicit answer? Thanks!

You don’t need to modify much. There are examples in the DiffEqFlux repo where the RHS of the ODE uses Zygote internally and ReverseDiff for the VJP. It has to be in that order though; the other way around fails.