Compute the gradient of the gradient norm using Zygote

I’m new to Zygote.jl and I’d like to compute the gradient of the squared norm of a gradient, i.e.,

g(x) = gradient(f, x)

q = gradient(x -> g(x)' * g(x), x)
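A short derivation, for reference. This is standard calculus and assumes f is twice continuously differentiable, so its Hessian $H = \nabla^2 f$ is symmetric. Writing $g = \nabla f$, so that $\partial g_i / \partial x_j = H_{ij}$:

$$
\frac{\partial}{\partial x_j}\Big(\tfrac12\, g^\top g\Big) = \sum_i g_i \frac{\partial g_i}{\partial x_j} = \sum_i g_i H_{ij} = (H^\top g)_j = (H g)_j
\quad\Longrightarrow\quad
\nabla\Big(\tfrac12\, g^\top g\Big) = H g .
$$

Without the factor ½ (as in g'*g), the result is $2Hg$. Either way it is a Hessian-vector product, which does not require forming $H$ explicitly.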

I successfully did this using forward mode:

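using ForwardDiff

function grad_g_sqr(f, d, x)   # `d` is not used here
    # gs(x) = ½‖∇f(x)‖², with the inner gradient computed in forward mode
    function gs(x)
        a = ForwardDiff.gradient(f, x)
        return a' * a / 2
    end
    gg = ForwardDiff.gradient(gs, x)
    return gg
end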

I heard reverse mode is better, so I tried Zygote:

# change to reverse-mode AD
using Zygote

function hessz(f, d, x)
    gs(x) = sum(Zygote.gradient(f, x)[1] .^ 2)
    gg = Zygote.gradient(gs, x)[1]
    return gg
end

Then I got this exception:

ERROR: Can't differentiate foreigncall expression

Is there anything I missed?

Hi and welcome to the community!
Could you post a minimal working example so that it is easier to help you?

At first glance, it seems the problem comes from nested derivatives with Zygote: maybe this link can help

Nested reverse AD with Zygote is very limited in terms of what it supports and (as I understand it) is almost never what you want. If you want to compute a Hessian, check out Utilities · Zygote.
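If forming the full Hessian is acceptable (small x), the gradient of ½‖∇f‖² is simply the Hessian times the gradient, so nested reverse mode can be avoided entirely. A minimal sketch using Zygote.hessian; the toy objective ff is the one used later in this thread, and grad_half_gradnorm_sq_dense is a hypothetical helper name, not part of Zygote:

```julia
using Zygote

# Toy objective (same as the `ff` used later in the thread)
ff(x) = sum((x .- 1) .^ 2)

# Hypothetical helper: ∇(½‖∇f‖²) = H * ∇f, with H formed explicitly.
# Memory grows as length(x)^2, so this only makes sense for small problems.
function grad_half_gradnorm_sq_dense(f, x)
    g = Zygote.gradient(f, x)[1]   # ∇f(x), reverse mode
    H = Zygote.hessian(f, x)       # full Hessian
    return H * g
end

grad_half_gradnorm_sq_dense(ff, zeros(3))
```

For ff the Hessian is 2I and the gradient at zero is -2 in every entry, so every entry of the result is -4.0.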


I don’t think the OP wants a Hessian? (Correct me if I’m wrong!)

I actually run into the same situation quite often and fix it by writing wrappers for the chain rules, and then chain rules for those chain rules.

Sorry, this is a really short and confusing reply. I’ll try to put together a script later to demonstrate what I mean.

That would be beautiful, because we receive quite a number of complaints regarding AD and it would simplify things to have a common line of defense;)

I was going off of function hessz(...), but mixing forward + reverse modes is just as applicable to any kind of nested AD.
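A minimal sketch of forward-over-reverse for this particular problem, assuming ForwardDiff dual numbers can be pushed through Zygote for the f at hand (this works for simple array code; it has not been checked here against Flux layers). Because the gradient of ½‖∇f‖² is H∇f, it is enough to differentiate the Zygote gradient along the direction v = ∇f(x) with one ForwardDiff derivative, which never forms H. The names hvp and grad_half_gradnorm_sq are hypothetical helpers:

```julia
using ForwardDiff, Zygote

# Hypothetical helper: Hessian-vector product H(x) * v by forward-over-reverse.
# d/dt ∇f(x + t v) at t = 0 equals H(x) * v; the inner gradient is reverse mode
# (Zygote), the outer directional derivative is forward mode (ForwardDiff).
hvp(f, x, v) = ForwardDiff.derivative(t -> Zygote.gradient(f, x .+ t .* v)[1], 0.0)

# Hypothetical helper: ∇(½‖∇f‖²) = H * ∇f, computed without forming H.
function grad_half_gradnorm_sq(f, x)
    g = Zygote.gradient(f, x)[1]
    return hvp(f, x, g)
end

ff(x) = sum((x .- 1) .^ 2)
grad_half_gradnorm_sq(ff, zeros(3))
```

The extra cost is roughly one more gradient evaluation carried out on dual numbers, independent of how many entries x has.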

Thanks gdalle, I finally put together an example that works. Here’s what I did:

julia> ff(x) = sum((x .- 1).^2)
ff (generic function with 1 method)

julia> using Zygote

julia> function hessz(f, x)
           gs(x) = sum(Zygote.gradient(f, x)[1] .^ 2 / 2)
           gg = Zygote.gradient(gs, x)[1]
           return gg
       end
hessz (generic function with 1 method)

julia> hessz(ff, zeros(3))
3-element Vector{Float64}:
 -4.0
 -4.0
 -4.0

In a more complex example, I realized that the problem comes from using Flux.logitcrossentropy, something like:

using Flux
using MLDatasets
using Flux: logitcrossentropy, normalise, onecold, onehotbatch
using Statistics: mean
using Zygote
using Parameters: @with_kw

@with_kw mutable struct Args
    lr::Float64 = 0.5
    repeat::Int = 110
end

function get_processed_data(args)
    labels = MLDatasets.Iris.labels()
    features = MLDatasets.Iris.features()

    # Subtract mean, divide by std dev for normed mean of 0 and std dev of 1.
    normed_features = normalise(features, dims=2)

    klasses = sort(unique(labels))
    onehot_labels = onehotbatch(labels, klasses)

    # Split into training and test sets, 2/3 for training, 1/3 for test.
    train_indices = [1:3:150; 2:3:150]

    X_train = normed_features[:, train_indices]
    y_train = onehot_labels[:, train_indices]

    X_test = normed_features[:, 3:3:150]
    y_test = onehot_labels[:, 3:3:150]

    # Repeat the data `args.repeat` times

    train_data_iter = Iterators.repeated((X_train, y_train), args.repeat)
    train_data = (X_train, y_train)
    test_data = (X_test, y_test)

    return train_data, train_data_iter, test_data
end

# Initialize hyperparameter arguments
args = Args(; lr=0.1)

# Load processed data
train_data, train_data_iter, test_data = get_processed_data(args)
x_train, yc_train = train_data
x_test, yc_test = test_data
function logit_model(wbv, x)
    wb = reshape(wbv, 3, :)
    return wb[:, 1:end-1] * x .+ wb[:, end]
end

loss_train(wb) = Flux.logitcrossentropy(logit_model(wb, x_train), yc_train)

w0 = ones(15)

Then I proceeded:

julia> function hessz(f, x)
           gs(x) = sum(Zygote.gradient(f, x)[1] .^ 2 / 2)
           gg = Zygote.gradient(gs, x)[1]
           return gg
       end
hessz (generic function with 1 method)

julia> hessz(loss_train, w0)
ERROR: Can't differentiate foreigncall expression
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] Pullback
    @ ./iddict.jl:102 [inlined]
  [3] (::typeof(∂(get)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
  [4] Pullback
    @ ~/.julia/packages/Zygote/H6vD3/src/lib/lib.jl:68 [inlined]
  [5] (::typeof(∂(accum_global)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
  [6] Pullback
    @ ~/.julia/packages/Zygote/H6vD3/src/lib/lib.jl:79 [inlined]
  [7] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
  [9] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
 [10] getindex
    @ ./tuple.jl:29 [inlined]
 [11] map
    @ ./tuple.jl:222 [inlined]
 [12] unthunk_tangent
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:36 [inlined]
 [13] #1630#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [14] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
 [15] Pullback
    @ ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:41 [inlined]
 [16] (::typeof(∂(λ)))(Δ::Tuple{Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
 [17] Pullback
    @ ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:76 [inlined]
 [18] (::typeof(∂(gradient)))(Δ::Tuple{Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
 [19] Pullback
    @ ./REPL[1]:2 [inlined]
 [20] (::typeof(∂(gs)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
 [21] (::Zygote.var"#56#57"{typeof(∂(gs))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:41
 [22] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:76
 [23] hessz(f::Function, x::Vector{Float64})
    @ Main ./REPL[1]:3
 [24] top-level scope
    @ REPL[2]:1

Thanks for the comment. I was actually trying to avoid computing the full Hessian. Do you have an example of how to use this mixed mode?

Looking forward!

Have a gander at Gradient of Gradient in Zygote - #3 by ChrisRackauckas.


I really think the whole discussion of gradient-of-gradient vs. Jacobian and mixing reverse and forward modes doesn’t apply here. To compute the gradient of f we of course use reverse mode. Then g'*g is again a scalar, so we should of course use reverse mode again. The canonical situation, which is why I’m interested in this (and I assume it might be the case here too), is when you train a model on gradients of its output.

So I looked back into my own code and I have to admit I misremembered a few details. The reason we can do this is that we implement our own rrules for all first derivatives. If you do that, then you can take two Zygote gradients. But, so far at least, I haven’t managed to produce a simple example where I take two derivatives without any intervention of this kind. I’ll keep trying and report back here if I find something.

Here is my toy example; I fully appreciate this is not what the OP asked. The reason this works really well for us is that we are more than happy implementing the gradients ourselves (in fact we have to, for performance reasons) and then just letting Zygote differentiate the loss.

using Zygote, ChainRules
import ChainRules: rrule, NoTangent 

f(x) = sum(x[i]*x[i+1] for i = 1:length(x)-1)

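# Hand-written pullback for f; Zygote can then differentiate through it again.
function rrule(::typeof(f), x)
  # ∇f(x) scaled by the cotangent w
  _pb_f(x, w::Number) = w * [ [x[2]]; [ x[i-1]+x[i+1] for i = 2:length(x)-1]; [x[end-1]] ]
  _pb_f(x, w) = (@show w; error("no pb for this"))
  return f(x), w -> (NoTangent(), _pb_f(x, w))
end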

grad_f(x) = Zygote.gradient(f, x)[1] 
L(x) = sum( grad_f(x).^2 )

x = rand(10)
@show L(x)
@show Zygote.gradient(L, x)[1]
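Because this toy objective is quadratic, the nested result can be checked against a closed form. Here f(x) = ½ xᵀHx, where H is the symmetric tridiagonal matrix with zero diagonal and ones off the diagonal, so ∇f = Hx and the gradient of L(x) = ‖∇f(x)‖² is 2H(Hx). A sketch of that check using only ForwardDiff and LinearAlgebra (no custom rules); grad_f_manual, L_manual and H are names introduced here for illustration, and 2 * H * (H * x) is the value Zygote.gradient(L, x)[1] in the example above should reproduce:

```julia
using ForwardDiff, LinearAlgebra

# Toy objective from the example: f(x) = Σ x[i] * x[i+1]
f(x) = sum(x[i] * x[i+1] for i = 1:length(x)-1)

# Same formula as the hand-written pullback, with cotangent w = 1
grad_f_manual(x) = [x[2]; [x[i-1] + x[i+1] for i = 2:length(x)-1]; x[end-1]]

# L(x) = ‖∇f(x)‖², built on the hand-written gradient
L_manual(x) = sum(grad_f_manual(x) .^ 2)

n = 10
H = SymTridiagonal(zeros(n), ones(n - 1))  # ∇²f: zero diagonal, ones off it
x = rand(n)

grad_f_manual(x) ≈ H * x                             # ∇f = H x
ForwardDiff.gradient(L_manual, x) ≈ 2 * H * (H * x)  # ∇L = 2 H (H x)
```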

You are right about my motivations here. Our interest in the gradient of g'g is basically that we are trying to find an alternative to ADAM. Thanks for your example and efforts here; I will keep you informed if I find anything useful, but I doubt it, since you obviously have more expertise on this :)

Not necessarily; I usually try things until they work and then move on :). I’d be grateful to hear if you learn more about this.

What does Zygote.hessian(f, x) actually do? Is it implemented by mixing forward and reverse mode? Does it take the reverse-mode derivative of f to get g, and then the forward-mode derivative of g to get h?


Yeah. Think of it like this: Zygote generates some code that computes the gradient, then you push an array of ForwardDiff Duals through that code to get the Jacobian (the Jacobian of the gradient being the Hessian).
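The construction just described can be written out by hand as a ForwardDiff Jacobian of the Zygote gradient. A sketch, using a test function h chosen only because its Hessian is known in closed form (for h(x) = Σ xᵢ³/6 the Hessian is Diagonal(x)); hessian_fwd_over_rev is a hypothetical helper name, and the result is compared against Zygote.hessian:

```julia
using ForwardDiff, Zygote, LinearAlgebra

# Test function with a known Hessian: ∇h = x.^2 / 2, ∇²h = Diagonal(x)
h(x) = sum(x .^ 3) / 6

# Forward over reverse: Jacobian (ForwardDiff) of the gradient (Zygote)
hessian_fwd_over_rev(f, x) = ForwardDiff.jacobian(y -> Zygote.gradient(f, y)[1], x)

x = [1.0, 2.0, 3.0]
H1 = hessian_fwd_over_rev(h, x)
H2 = Zygote.hessian(h, x)
```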

It should also be pretty much equivalent to this:

using AbstractDifferentiation, LinearAlgebra, Zygote, ForwardDiff
AD.jacobian(AD.ForwardDiffBackend(), x -> AD.gradient(AD.ZygoteBackend(), x -> norm(x), x)[1], [1,2,3])[1]

which can also be written as:

AD.hessian(AD.HigherOrderBackend((AD.ForwardDiffBackend(), AD.ZygoteBackend())), norm, [1,2,3])[1]

(I’ve recently been playing more with AbstractDifferentiation.jl, which IMO is shaping up really nicely and is an easy way to quickly swap out these different backends or try different combinations and see what works / is fast.)


I wonder if we should have documentation somewhere showing how to mix AD libraries. E.g., for the MWE in this thread, maybe Zygote over Tracker could work. That would require people to test out some combinations (i.e. non-existent extra maintainer time), so if anyone is interested, let me know and I can help you get started.


Why doesn’t this work? What’s the interface for taking the second derivative of a function of one variable?

AD.hessian(AD.HigherOrderBackend((AD.ForwardDiffBackend(), AD.ZygoteBackend())), x->x^3, 1)[1]

There is none, you’ll need to make it a length-1 vector and wrap/unwrap it:

AD.hessian(AD.HigherOrderBackend((AD.ForwardDiffBackend(), AD.ZygoteBackend())), x->x[1]^3, [1])[1][1]
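For a scalar-to-scalar function one can also nest the packages directly instead of going through AbstractDifferentiation. A minimal sketch, assuming f maps a real number to a real number: take the reverse-mode derivative with Zygote and differentiate that with ForwardDiff.derivative. second_derivative is a hypothetical helper name:

```julia
using ForwardDiff, Zygote

# Forward-over-reverse second derivative of a scalar function f: ℝ → ℝ
second_derivative(f, x::Real) = ForwardDiff.derivative(t -> Zygote.gradient(f, t)[1], x)

second_derivative(x -> x^3, 1.0)   # d²/dx² x³ = 6x
```

At x = 1.0 this gives 6.0, and no wrapping into a length-1 vector is needed.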

I actually asked a related (yet-unanswered) question here: Whats the reason for the derivative / gradient difference? · Issue #61 · JuliaDiff/AbstractDifferentiation.jl · GitHub