Gradient of Gradient in Zygote

The gradient of the gradient is not necessarily something that is well-defined: a gradient is only defined for a function with a single (scalar) output, and the gradient itself is vector-valued. As @stevengj mentions, what you're looking for is the Jacobian of a gradient. The rows of a Jacobian are the gradients of each output with respect to the inputs. So you're really looking for the Jacobian of the gradient, which is the Jacobian of the Jacobian of a function with a single output, which is also known as the Hessian.
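
To make the shapes concrete, here's a minimal sketch (with a made-up scalar function f, using ForwardDiff just to illustrate): the gradient of an R^3 -> R function is a length-3 vector, and the Jacobian of that gradient is the 3×3 Hessian.

using ForwardDiff

f(x) = sum(x.^2) + x[1]*x[2]   # made-up scalar function, R^3 -> R
x0 = [1.0, 2.0, 3.0]

g = ForwardDiff.gradient(f, x0)                                # length-3 vector
H = ForwardDiff.jacobian(x -> ForwardDiff.gradient(f, x), x0)  # 3×3 matrix
H ≈ ForwardDiff.hessian(f, x0)                                 # true: "Jacobian of the gradient" is the Hessian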

I couldn’t figure out in my course why so many students were confused about this until I heard “but TensorFlow can do it”. How can TensorFlow do the impossible? Good question. It just hides what’s really going on.

https://www.tensorflow.org/api_docs/python/tf/gradients

Constructs symbolic derivatives of sum of ys w.r.t. x in xs.

So tf.gradients on a function f is actually taking the gradient of sum(f(x)), not of f(x) itself. Since the sum of all outputs is always a single output, that makes the gradient well-defined. But honestly, this is just confusing a whole generation of ML practitioners, and from what I can tell most people doing tf.gradients(tf.gradients(...)) aren't aware of this assumption.
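
In Julia terms, that TensorFlow convention corresponds to something like this (a sketch with a made-up vector-valued f):

using Zygote

f(x) = [x[1]^2 + x[2], x[2]*x[3]]   # made-up function, R^3 -> R^2
x0 = [1.0, 2.0, 3.0]

# What tf.gradients effectively gives you: the gradient of the scalarized sum(f(x))
Zygote.gradient(x -> sum(f(x)), x0)[1]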

In terms of reverse-mode automatic differentiation, what's going on here is that you run the forward pass of f at the value x to build a pullback, and then you pull back some vector. In Zygote, this is:

out, back = Zygote.pullback(f, x)  # out == f(x); back is the pullback closure
back(y)                            # pulls back the cotangent y (a vector-Jacobian product)

where y is the pullback value. The size of y should match size(f(x)), i.e. the size of the output of the function. If you have a scalar function then y=1 calculates the gradient, i.e.

function gradient(f, x)
  out, back = Zygote.pullback(f, x)
  back(1)  # pull back the scalar seed 1 to get the gradient
end

is essentially how it's implemented. If you have multiple outputs, then there isn't a single, uniquely "trivial" value to pull back. However, note that using y = e_i, i.e. the ith basis vector e_i = [0,0,0,...,1,...,0,0] where the only non-zero entry is the ith one, makes

  out, back = Zygote.pullback(f, x)
  back(e_i)  # e_i is the ith standard basis vector

compute the ith row of the Jacobian, since that is the gradient of the ith output. So, when linearized, pulling back the identity matrix (one basis vector at a time) gives the Jacobian. But what back(v) is actually computing is J'v, a vector-Jacobian product, which means that pulling back the vector of all ones is equivalent to the gradient of sum(f(x)), which is what TensorFlow just decides to do (because… ? :thinking: yeah, I don't like it).
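
Concretely, with the same kind of made-up f as above (again just a sketch):

using Zygote

f(x) = [x[1]^2 + x[2], x[2]*x[3]]   # made-up function, R^3 -> R^2
x0 = [1.0, 2.0, 3.0]

out, back = Zygote.pullback(f, x0)

# Pull back the standard basis vectors to get the rows of the Jacobian
row1 = back([1.0, 0.0])[1]
row2 = back([0.0, 1.0])[1]
J = [row1'; row2']                  # 2×3 Jacobian

# Pulling back all ones gives J'*ones, i.e. the gradient of sum(f(x)), the TensorFlow convention
back([1.0, 1.0])[1] ≈ Zygote.gradient(x -> sum(f(x)), x0)[1]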

So once you do one gradient, you have a function with multiple outputs, and you can calculate its Jacobian, which is what Zygote.hessian does. But since your Hessian is going to be symmetric (and the gradient has as many outputs as inputs), it probably doesn't make sense to use reverse mode to calculate that Jacobian. What makes sense instead is to use forward-mode automatic differentiation to calculate the Jacobian of Zygote's gradient function, which is known as forward-over-reverse. The code for this is:

using ForwardDiff, Zygote

function forward_over_reverse_hessian(f, θ)
  ForwardDiff.jacobian(θ) do x
    Zygote.gradient(f, x)[1]  # reverse-mode gradient, differentiated by ForwardDiff
  end
end
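
For example (made-up f and θ0, and assuming the ForwardDiff-in-Zygote adjoints shown below are loaded), this should agree with a plain ForwardDiff Hessian:

f(x) = sum(x.^2) + x[1]*x[2]   # made-up scalar function
θ0 = [1.0, 2.0, 3.0]

H = forward_over_reverse_hessian(f, θ0)
H ≈ ForwardDiff.hessian(f, θ0)  # true: same Hessian, computed forward-over-reverse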

Note you need the following adjoints to connect ForwardDiff into Zygote:

# ForwardDiff integration
using ZygoteRules, ForwardDiff

ZygoteRules.@adjoint function ForwardDiff.Dual{T}(x, ẋ::Tuple) where T
  @assert length(ẋ) == 1
  ForwardDiff.Dual{T}(x, ẋ), ḋ -> (ḋ.partials[1], (ḋ.value,))
end

ZygoteRules.@adjoint ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:partials}) where T =
  d.partials, ṗ -> (ForwardDiff.Dual{T}(ṗ[1], 0),)

ZygoteRules.@adjoint ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:value}) where T =
  d.value, ẋ -> (ForwardDiff.Dual{T}(0, ẋ),)

(@dhairyagandhi96 this should be added to Zygote IMO, but it also points out limitations in the derivative rule handling on types which is where this issue and this issue come from).

Note that you don't always need to compute the full Hessian. If your algorithm only needs H*v, you can pre-seed a direction for the forward mode (kind of like choosing y=v) in order to compute the Hessian-vector product directly, without ever building the full Hessian. You can do this with:

function autoback_hesvec(f, x, v)
    g = x -> first(Zygote.gradient(f, x))  # reverse-mode gradient of f
    # Seed forward mode with the direction v via dual numbers, then extract the
    # directional derivative of the gradient, i.e. H*v
    ForwardDiff.partials.(g(ForwardDiff.Dual{Nothing}.(x, v)), 1)
end

as implemented in:
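
As a quick sanity check (made-up f, x, and v), the result should match multiplying the full Hessian by v:

f(x) = sum(x.^2) + x[1]*x[2]
x0 = [1.0, 2.0, 3.0]
v  = [0.5, -1.0, 2.0]

autoback_hesvec(f, x0, v) ≈ ForwardDiff.hessian(f, x0) * v  # true, without ever forming the full Hessian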

Hopefully that's pretty comprehensive on the double differentiation. If you want to know more about the topic, check out the 18.337 lecture notes, specifically the part on reverse-mode AD:

https://mitmath.github.io/18337/lecture10/estimation_identification

One final thing to note is that there is a new AD that’s coming and will improve nested AD performance. It’s called Diffractor.jl and you can hear more about it here:

That said, even that won’t make gradients of gradients a real thing :wink:.
