Batched gradients and Hessians with Flux

Hi,

I am new to Flux, so I am not sure how to do the following.

Say I have some model(params, x), where the model returns an A-dimensional output, params is a vector of B parameters, and x is a C-dimensional input.

I have D inputs, such that the i^{\rm th} input x_{i} (with i ranging from 1 to D) is stored as the i^{\rm th} column of a C \times D matrix X.

Writing Y_{a, d} for the a^{\rm th} output component for the d^{\rm th} input, and \theta for the vector of B parameters, I would like to be able to calculate

  1. \partial Y_{a, d} / \partial X_{c, d} and \partial^{2} Y_{a, d} / \partial X_{c_{1}, d} \partial X_{c_{2}, d}, and
  2. \partial Y_{a, d} / \partial \theta_{b}

for all D inputs.

From the documentation, I understand how to compute all of these quantities using a for loop over each of the D inputs. But is there a faster way to do this?
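For concreteness, here is a minimal toy setup with these shapes (the Dense layer is just a stand-in for my actual model):

using Flux

A, C, D = 2, 4, 8
model = Dense(C => A, tanh)      # maps a C-dimensional input to an A-dimensional output
X = randn(Float32, C, D)         # D inputs stored as columns
Y = model(X)                     # A x D matrix; Y[:, d] is the output for input X[:, d]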


Since you are new to Flux, I presume the AD backend you are using is Zygote. In that case, the solution for the gradients is the Zygote.jacobian function; it is pretty smart and handles the case you are referring to well:

using Zygote

f(x) = [x; 2*x; 3*x]        # vector input, vector output
x = [1, 2, 3, 4]
Zygote.jacobian(f, x)[1]    # 12 x 4 Jacobian matrix

The output is a Matrix whose columns correspond to the entries of the input; the elements of a column are the partial derivatives of each output entry with respect to that input entry, in order.
For Hessians, however, you are out of luck, because Zygote.hessian only handles the case of scalar output.
I do find this counter-intuitive, given that Flux and Lux are formulated in a way that naturally leads you to think of columns as individual inputs and outputs, but this line of thinking largely does not extend to Hessian matrices (even though it does for Jacobians). In the past, when I needed Hessian diagonals, I either resorted to numeric derivatives or implemented explicit Hessians myself.
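For reference, a minimal example of the scalar-output case that Zygote.hessian does handle:

using Zygote

# Hessian of a scalar-valued function; here the result is 2I (3 x 3)
Zygote.hessian(x -> sum(abs2, x), [1.0, 2.0, 3.0])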

That’s true, but you can always compute the Jacobian of the Jacobian.

I want to pass a vector input and get a vector output. When I use jacobian from Zygote.jl and pass the matrix of inputs X, it flattens all of X into a one-dimensional vector and then calculates the Jacobian with respect to that. But this is essentially wrong for my problem, as effectively only the Jacobian with respect to the first input (stored as the first column of X) has been evaluated. So maybe the way to go at this stage is to simply loop over all my samples, as sketched below?
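Something like this per-column loop is what I have in mind (a sketch; model here is any function taking a C-vector to an A-vector, and X is the C x D input matrix):

using Zygote

# one A x C Jacobian per input column; Js[d][a, c] = ∂Y[a, d] / ∂X[c, d]
Js = [Zygote.jacobian(model, X[:, d])[1] for d in 1:D]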

This would calculate the full Jacobian of the Jacobian, right? I forgot to mention this in my original post, but like @Hareruya I only need the diagonal part of the Hessian. Any suggestions for getting just that?

Correct me if I am wrong, but I believe the internals of Zygote.jacobian perform mutation (it calls in-place functions), which Zygote cannot differentiate through, so this wouldn’t actually work. It fails for the example I wrote above, for instance.

A way to calculate only the Jacobians and Hessians you want quickly, without scalar indexing, repeated calculations, or mutation, might be the dual-number machinery of ForwardDiff.jl. The catch is that taking a Zygote gradient of the result still does not work, for reasons unrelated to mutation, so if you need derivatives of those derivatives with respect to the parameters (say, for something like a PINN), you are out of luck there.

using Zygote, ForwardDiff, Enzyme
using ForwardDiff: Dual, partials

C = 4   # input dimension
D = 8   # number of inputs (columns)
A = 2   # output dimension
T = Float32   # element type, also used as the Dual tag below
model(M, x) = tanh.(M * x)
M = randn(T, A, C)

v = randn(T, C, D)   # the C x D matrix of inputs
# one unit-vector seed per input coordinate, so each dual carries all C partials
Partial_indicators = [Tuple([n == m ? 1 : 0 for m = 1:C]) for n = 1:C]
dv = Dual{T}.(v, Partial_indicators)     # first-order duals
d2v = Dual{T}.(dv, Partial_indicators)   # nested duals for second derivatives

# Gradients in terms of each output dimension, of shape C x A x D
grads = getindex.(reshape(partials.(model(M, dv)), 1, A, D), 1:C)
# Hessians, of shape C x C x A x D
hess = getindex.(reshape(partials.(getindex.(reshape(partials.(model(M, d2v)), 1, A, D), 1:C)), 1, C, A, D), 1:C)
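As a quick sanity check (purely illustrative, reusing the definitions above), the first slice of grads matches a per-column Zygote Jacobian:

# grads[:, :, 1] is C x A; Zygote gives the A x C Jacobian for column 1
J1 = Zygote.jacobian(x -> model(M, x), v[:, 1])[1]
grads[:, :, 1] ≈ permutedims(J1)   # true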

However, I presume you want to train with these afterwards, since you are using Flux, and in that case the fact that these Jacobians and Hessians are not differentiable with Zygote could be troublesome. Doing a for loop and calculating these values one by one is a non-starter, since there is mutation inside Zygote’s own gradient code, so you can’t differentiate a Zygote gradient with Zygote again. I’ve found that Enzyme still works, though:

function loss(M, d2v)
    # same nested-dual Hessian extraction as above, reduced to a scalar
    return sum(getindex.(reshape(partials.(getindex.(reshape(partials.(model(M, d2v)), 1, A, D), 1:C)), 1, C, A, D), 1:C))
end

Enzyme.gradient(Reverse, M -> loss(M, d2v), M)

So if you are interested in training afterwards, one idea could be to use ForwardDiff Duals to calculate the partial derivatives you are interested in, and then do the training with Enzyme as the AD backend.
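A single update step could then look something like this (a rough sketch: it assumes an Enzyme version where gradient returns one entry per argument, and applies Optimisers.jl, which Flux uses under the hood, directly to M):

using Optimisers

# set up an Adam state for the weight matrix M (a plain array works as a leaf)
opt_state = Optimisers.setup(Optimisers.Adam(1f-3), M)
dM = Enzyme.gradient(Reverse, m -> loss(m, d2v), M)[1]   # gradient w.r.t. M only
opt_state, M = Optimisers.update(opt_state, M, dM)       # one optimiser step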
EDIT: Also, if you only want the diagonals of the Hessians, you could split d2v into C different seedings where the only nonzero partial in the duals is the coordinate of interest (see the sketch below), maybe? Alternatively, you could avoid all of this hacky nonsense I cobbled together by using numerical derivatives instead of exact ones.
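Something along these lines, reusing model, M, v, C, D, A and T from the block above (one forward pass per coordinate, each dual carrying a single partial):

using ForwardDiff: Dual, partials

hess_diag = zeros(T, C, A, D)   # hess_diag[c, a, d] = ∂²Y[a, d] / ∂X[c, d]²
for c in 1:C
    # seed only coordinate c, at both dual levels
    dv_c  = [Dual{T}(v[i, j], (i == c ? one(T) : zero(T),)) for i in 1:C, j in 1:D]
    d2v_c = [Dual{T}(dv_c[i, j], (i == c ? one(T) : zero(T),)) for i in 1:C, j in 1:D]
    Y = model(M, d2v_c)   # A x D matrix of nested duals
    # outer partial = first derivative (as a dual); its partial = second derivative
    hess_diag[c, :, :] .= getindex.(partials.(getindex.(partials.(Y), 1)), 1)
end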

If your Jacobian only has a diagonal part, you can sum the output and take the gradient, no?
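A minimal sketch of that idea with an elementwise function (so that its Jacobian really is diagonal):

using Zygote

g(x) = tanh.(x)                                  # elementwise, so the Jacobian is diagonal
x = randn(Float32, 4)
# the gradient of sum(g(x)) is the column sums of the Jacobian, i.e. its diagonal here
diagJ = Zygote.gradient(v -> sum(g(v)), x)[1]    # equals 1 .- tanh.(x).^2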

As a quick note, I believe Enzyme supports multi-argument gradients (which will generally be more performant and work on more code than creating closures).

function loss(M, d2v)
    return sum(getindex.(reshape(partials.(getindex.(reshape(partials.(model(M, d2v)), 1, A, D), 1:C)), 1, C, A, D), 1:C))
end

Enzyme.gradient(Reverse, loss, M, Const(d2v))

Hi, thank you for this. For now, I am going to try the solution you have proposed. I don’t understand partials very well, and I suspect it’s because I don’t understand dual numbers well; I will try to read up on these.

I was planning on updating the parameters of my Flux model directly, as my loss function is estimated using Monte Carlo and involves complex numbers at some stages.

@yolhan_mannes My Jacobian is not diagonal, but I only need the diagonal part of the Hessian.

@Hareruya @wsmoses Thank you for the suggestion regarding Enzyme.


Don’t mention it!

I was planning on updating the parameters of my Flux model directly, as my loss function is estimated using Monte Carlo and involves complex numbers at some stages.

I am not sure how relevant this is to you, but both Zygote and Enzyme can differentiate through losses that use complex numbers, as long as both ends of the process are real, i.e. as long as the parameters and the loss are all real numbers, it will figure it all out.
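For instance, something like this (a toy sketch, not your actual loss) differentiates fine even though the intermediate values are complex:

using Zygote

# real parameters in, real loss out, complex numbers only in between
realloss(θ) = abs2(sum(exp.(im .* θ)))
Zygote.gradient(realloss, randn(Float32, 3))[1]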