Compute gradient of gradient norm using Zygote

I’m new to Zygote.jl and I’d like to compute the gradient of the squared norm of a gradient, i.e.,

g(x) = gradient(f, x),

q = gradient(x -> g(x)' * g(x), x)

I successfully did this using forward mode:

function grad_g_sqr(f, d, x)
    function gs(x)
        a = ForwardDiff.gradient(f, x)
        return a' * a / 2
    end
    gg = ForwardDiff.gradient(gs, x)
    return gg
end

I heard that backward (reverse) mode is better, so I tried to use Zygote:

# change to backward ad.
function hessz(f, d, x)
    # squared norm of the gradient of f (d is unused here)
    gs(x) = sum(Zygote.gradient(f, x)[1] .^ 2)
    gg = Zygote.gradient(gs, x)[1]
    return gg
end

Then I got this exception:

ERROR: Can't differentiate foreigncall expression

Is there anything I missed?

Hi and welcome to the community!
Could you post a minimal working example so that it is easier to help you?

At first glance, it seems the problem comes from nested derivatives with Zygote: maybe this link can help

Nested reverse AD with Zygote is very limited in terms of what it supports and (as I understand it) is almost never what you want. If you want to compute a Hessian, check out Utilities · Zygote.
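As a minimal sketch of the Hessian route, assuming the goal is the gradient of ½‖∇f‖², which equals the Hessian times the gradient: `Zygote.hessian` builds the full n×n matrix (forward over reverse), so this is only sensible for small inputs. The test function `f` below is an illustrative choice, not part of the original question.

```julia
using Zygote

# Illustrative test function; its gradient is 2(x .- 1) and its Hessian is 2I.
f(x) = sum((x .- 1) .^ 2)
x = zeros(3)

H = Zygote.hessian(f, x)       # full 3×3 Hessian
g = Zygote.gradient(f, x)[1]   # gradient of f at x
q = H * g                      # gradient of ½‖∇f(x)‖²
```

For this `f` the product works out to `4 .* (x .- 1)`, so every entry of `q` is -4 at `x = zeros(3)`.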


I don’t think the OP wants a Hessian? (Correct me if I’m wrong!)

I actually have the same situation quite often and fix it by writing wrappers for the chain rules. And then chainrules for the chainrules.

Sorry this is a really short and confusing reply. I’ll try to put together a script later to demonstrate what I mean.

That would be beautiful, because we receive quite a number of complaints regarding AD and it would simplify things to have a common line of defense;)

I was going off of function hessz(...), but mixing forward + reverse modes is just as applicable to any kind of nested AD.
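For reference, the quantity in the original question can be written as a Hessian-vector product, so no full Hessian has to be formed. Writing g(x) = ∇f(x) and H(x) = ∇²f(x) (symbols introduced here for illustration), and using the symmetry of H:

```latex
\nabla_x \left( \tfrac{1}{2}\, g(x)^\top g(x) \right) = H(x)\, g(x),
\qquad
\nabla_x \left( g(x)^\top g(x) \right) = 2\, H(x)\, g(x).
```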

Thanks gdalle, I finally wrapped up an example that is working. What I did:

julia> ff(x) = sum((x .- 1).^2)
ff (generic function with 1 method)
julia> using Zygote
julia> function hessz(f, x)
           gs(x) = sum(Zygote.gradient(f, x)[1] .^ 2 / 2)
           gg = Zygote.gradient(gs, x)[1]
           return gg
       end
hessz (generic function with 1 method)

julia> hessz(ff, zeros(3))
3-element Vector{Float64}:
 -4.0
 -4.0
 -4.0

In a more complex example, I realized that the problem comes from using Flux.logitcrossentropy, for example:

using Flux
using MLDatasets
using Flux: logitcrossentropy, normalise, onecold, onehotbatch
using Statistics: mean
using Zygote
using Parameters: @with_kw

@with_kw mutable struct Args
    lr::Float64 = 0.5
    repeat::Int = 110
end

function get_processed_data(args)
    labels = MLDatasets.Iris.labels()
    features = MLDatasets.Iris.features()

    # Subtract the mean and divide by the std dev, giving mean 0 and std dev 1.
    normed_features = normalise(features, dims=2)

    klasses = sort(unique(labels))
    onehot_labels = onehotbatch(labels, klasses)

    # Split into training and test sets, 2/3 for training, 1/3 for test.
    train_indices = [1:3:150; 2:3:150]

    X_train = normed_features[:, train_indices]
    y_train = onehot_labels[:, train_indices]

    X_test = normed_features[:, 3:3:150]
    y_test = onehot_labels[:, 3:3:150]

    #repeat the data `args.repeat` times

    train_data_iter = Iterators.repeated((X_train, y_train), args.repeat)
    train_data = (X_train, y_train)
    test_data = (X_test, y_test)

    return train_data, train_data_iter, test_data
end



# Initialize hyperparameter arguments
args = Args(; lr=0.1)

#Loading processed data
train_data, train_data_iter, test_data = get_processed_data(args)
x_train, yc_train = train_data
x_test, yc_test = test_data
function logit_model(wbv, x)
    wb = reshape(wbv, 3, :)
    return wb[:, 1:end-1] * x .+ wb[:, end]
end

loss_train(wb) = Flux.logitcrossentropy(logit_model(wb, x_train), yc_train)

w0 = ones(15)

Then proceed,

julia> function hessz(f, x)
           gs(x) = sum(Zygote.gradient(f, x)[1] .^ 2 / 2)
           gg = Zygote.gradient(gs, x)[1]
           return gg
       end
hessz (generic function with 1 method)

julia> hessz(loss_train, w0)
ERROR: Can't differentiate foreigncall expression
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] Pullback
    @ ./iddict.jl:102 [inlined]
  [3] (::typeof(∂(get)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
  [4] Pullback
    @ ~/.julia/packages/Zygote/H6vD3/src/lib/lib.jl:68 [inlined]
  [5] (::typeof(∂(accum_global)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
  [6] Pullback
    @ ~/.julia/packages/Zygote/H6vD3/src/lib/lib.jl:79 [inlined]
  [7] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
  [9] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
 [10] getindex
    @ ./tuple.jl:29 [inlined]
 [11] map
    @ ./tuple.jl:222 [inlined]
 [12] unthunk_tangent
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:36 [inlined]
 [13] #1630#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [14] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
 [15] Pullback
    @ ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:41 [inlined]
 [16] (::typeof(∂(λ)))(Δ::Tuple{Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
 [17] Pullback
    @ ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:76 [inlined]
 [18] (::typeof(∂(gradient)))(Δ::Tuple{Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
 [19] Pullback
    @ ./REPL[1]:2 [inlined]
 [20] (::typeof(∂(gs)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
 [21] (::Zygote.var"#56#57"{typeof(∂(gs))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:41
 [22] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:76
 [23] hessz(f::Function, x::Vector{Float64})
    @ Main ./REPL[1]:3
 [24] top-level scope
    @ REPL[2]:1

Thanks for the comment. I was actually trying to avoid computing the full Hessian. Do you have an example of how to mix the two modes?

Looking forward!
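A minimal sketch of what mixing the modes could look like, assuming the target is the gradient of ½‖∇f‖²: take the gradient with Zygote (reverse mode), then differentiate that gradient along its own direction with ForwardDiff (forward mode), which gives the Hessian-vector product without building the Hessian. The helper name `grad_half_sqnorm` is hypothetical. Whether this avoids the foreigncall error for the logitcrossentropy loss has not been verified here.

```julia
using Zygote, ForwardDiff

# Hypothetical helper: returns the gradient of ½‖∇f(x)‖², i.e. H(x) * ∇f(x).
function grad_half_sqnorm(f, x)
    g = Zygote.gradient(f, x)[1]                        # reverse mode: ∇f(x)
    grad_along(t) = Zygote.gradient(f, x .+ t .* g)[1]  # ∇f along the line x + t*g
    # Forward mode over the scalar t: d/dt ∇f(x + t*g) at t = 0 equals H(x) * g.
    return ForwardDiff.derivative(grad_along, 0.0)
end

ff(x) = sum((x .- 1) .^ 2)
result = grad_half_sqnorm(ff, zeros(3))
```

For `ff` this should agree with the nested Zygote result shown earlier in the thread (all entries equal to -4.0), at the cost of one forward pass over a reverse pass.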

Have a gander at Gradient of Gradient in Zygote - #3 by ChrisRackauckas.


I really think the whole discussion of gradient of gradient vs. Jacobian and of mixing backward and forward modes doesn’t apply here. To compute the gradient of f we of course use backward mode. Then g'*g is again a scalar, so we should of course use backward mode again. The canonical situation, which is why I’m interested in this and which I assume might be the case here too, is when you train a model on gradients of the output.

So I looked back into my own code and I have to admit I misremembered a few details. The reason we can do this is that we implement our own rrules for all first derivatives. If you do that, then you can take two Zygote gradients. But, so far at least, I haven’t managed to produce a simple example where I take two derivatives without any intervention of this kind. I’ll keep trying and report back here if I find something.

Here is my toy example. I fully appreciate that this is not what the OP asked for. The reason this works really well for us is that we are more than happy to implement the gradients ourselves (in fact we have to, for performance reasons) and then just let Zygote differentiate the loss.

using Zygote, ChainRules
import ChainRules: rrule, NoTangent 

f(x) = sum(x[i]*x[i+1] for i = 1:length(x)-1)

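# Hand-written reverse rule for f; the gradient of f is [x[2], x[1]+x[3], ..., x[end-1]].
function rrule(::typeof(f), x)
  # pullback for a scalar cotangent w: w * ∇f(x)
  _pb_f(x, w::Number) = w * [ [x[2]]; [ x[i-1]+x[i+1] for i = 2:length(x)-1]; [x[end-1]] ]
  # any other cotangent type is unexpected, so fail loudly
  _pb_f(x, w) = (@show w; error("no pb for this"))
  return f(x), w -> (NoTangent(), _pb_f(x, w))
end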

grad_f(x) = Zygote.gradient(f, x)[1] 
L(x) = sum( grad_f(x).^2 )

x = rand(10)
@show L(x)
@show Zygote.gradient(L, x)[1]
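
One way to sanity-check a hand-written rule like this is to compare it against forward mode. The sketch below is self-contained and does not use the rrule itself: it re-implements the same gradient formula as a plain function (`hand_grad`, a name introduced here) and uses nested ForwardDiff as a reference for the gradient of the loss (`L_ref`, also a name introduced here).

```julia
using ForwardDiff

f(x) = sum(x[i] * x[i+1] for i in 1:length(x)-1)

# Same formula as in the hand-written pullback above.
hand_grad(x) = [[x[2]]; [x[i-1] + x[i+1] for i in 2:length(x)-1]; [x[end-1]]]

# Reference loss built from forward-mode gradients only.
L_ref(x) = sum(abs2, ForwardDiff.gradient(f, x))

x = rand(10)
@assert hand_grad(x) ≈ ForwardDiff.gradient(f, x)

# Reference value to compare with Zygote.gradient(L, x)[1].
ref = ForwardDiff.gradient(L_ref, x)
```

Because this f is quadratic, its Hessian is constant and applying it to a vector v is the same as `hand_grad(v)`, so `ref` should also match `2 .* hand_grad(hand_grad(x))`.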

You are right about my motivations here. Our interest in the gradient of g'g is that we are trying to find an alternative to ADAM. Thanks for your example and your efforts here. I will keep you informed if I find anything useful, but I doubt I will, since you obviously have more expertise on this : )

Not necessarily - I usually try until it works and then move on :). I’d be grateful to hear if you learn more about this.