Zygote differentiation very slow

I have an inner function that I apply as shown below (a simple, reproducible example). The inputs here are 5×5, but they are usually much larger. I did not run a benchmark because the performance is nowhere near acceptable; it usually crashes for larger, realistic inputs.
One issue that I think causes this is having to differentiate with respect to all the inputs when I only need the derivative with respect to the first one. I did not find a way to do that other than computing the derivatives with respect to all inputs and indexing the result afterwards.

EDIT: Code fixed to reflect the question

```julia
using Zygote

function inner_func(p1, p2, p3)
    return sum(p1) + sum(p2) + sum(p3)
end

# Define the function with the sample inner function
function coefficients(x, y, z)
    num_vectors = size(x, 1)
    K = [inner_func(x[I, :],
                    y[J, :],
                    z[J, :]) for J in 1:num_vectors, I in 1:num_vectors]
    return K
end

x = rand(5, 5)
y = rand(5, 5)
z = rand(5, 5)

inp_vector = [x, y, z]
# Jacobian of vec(K) with respect to vec(x), the first input, using Zygote;
# jacobian returns one matrix per input, and [1] picks the one for x
Zygote.jacobian(coefficients, inp_vector...)[1]
```
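If only the Jacobian with respect to `x` is needed, one option is to close over `y` and `z` so they are treated as constants; the tuple returned by `Zygote.jacobian` then holds a single matrix. This is a sketch reusing the example's `inner_func` and `coefficients`. It removes the need to compute and discard the other Jacobians in the result, but it is not guaranteed to be faster, because the comprehension still slices rows inside a loop.

```julia
using Zygote

inner_func(p1, p2, p3) = sum(p1) + sum(p2) + sum(p3)

function coefficients(x, y, z)
    n = size(x, 1)
    return [inner_func(x[I, :], y[J, :], z[J, :]) for J in 1:n, I in 1:n]
end

x, y, z = rand(5, 5), rand(5, 5), rand(5, 5)

# y and z are captured by the anonymous function and treated as constants,
# so jacobian returns a 1-tuple holding d vec(K) / d vec(x) (25×25 here).
Jx = only(Zygote.jacobian(x -> coefficients(x, y, z), x))
```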

I want to wrap the loop in a function and have AD differentiate through it, rather than differentiating the inner function and building the matrices with an outer loop myself. What is the recommended approach with Zygote? Building the matrices myself every time I want higher-order derivatives is inconvenient, as shown below.

```julia
dK_dx = [Zygote.jacobian(inner_func, x[I,:], y[J,:], z[J,:])[1] for J in 1:size(x,1), I in 1:size(x,1)]
```

I suspect the issue is that Zygote doesn’t like scalar indexing like inner_func(x[I], y[J], z[J], x[J]); it is optimized to work well on vectorized code.

What do you mean by higher-order derivatives?

In this case you could try Enzyme.jl, which has an interface for marking inputs as differentiable or constant. If I understand correctly, you only need the Jacobian of K with respect to x?


Might be worth noting that inner_func is called with four numbers, but it looks like you expect arrays. Here x[1] isa Float64, although only the first 50 of 2500 elements are ever used:

Such scalar indexing in a loop is indeed Zygote’s worst nightmare, as Guillaume says. But you might intend to be indexing eachcol(x) instead, which (aside from giving completely different answers) won’t be as bad for Zygote:

```
julia> begin
       x = rand(3,3)  # smaller example
       y = rand(3,3)
       z = rand(3,3)

       inp_vector = [x,y,z]
       end;

julia> Zygote.jacobian(coefficients, inp_vector...)[1]  # only first 3 entries of x are used
9×9 Matrix{Float64}:
 2.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 1.0  1.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 1.0  0.0  1.0  0.0  0.0  0.0  0.0  0.0  0.0
 1.0  1.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  2.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  1.0  1.0  0.0  0.0  0.0  0.0  0.0  0.0
 1.0  0.0  1.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  1.0  1.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  2.0  0.0  0.0  0.0  0.0  0.0  0.0

julia> function coefficients_2(x,y,z)
           xcols = eachcol(x)
           ycols = eachcol(y)
           zcols = eachcol(z)
           inds = eachindex(xcols)
           K = [inner_func(xcol,   # this calls inner_func with 4 vectors
                           ycols[j],
                           zcols[j],
                           xcols[j]) for j in inds, xcol in xcols]
       end
coefficients_2 (generic function with 1 method)

julia> jacobian(x -> coefficients_2(x, y, z), x)[1]  # jacobian w.r.t. x alone
9×9 Matrix{Float64}:
 2.0  2.0  2.0  0.0  0.0  0.0  0.0  0.0  0.0
 1.0  1.0  1.0  1.0  1.0  1.0  0.0  0.0  0.0
 1.0  1.0  1.0  0.0  0.0  0.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0  1.0  0.0  0.0  0.0
 0.0  0.0  0.0  2.0  2.0  2.0  0.0  0.0  0.0
 0.0  0.0  0.0  1.0  1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  0.0  0.0  0.0  1.0  1.0  1.0
 0.0  0.0  0.0  1.0  1.0  1.0  1.0  1.0  1.0
 0.0  0.0  0.0  0.0  0.0  0.0  2.0  2.0  2.0
```

Thank you for prompt help!

Right, the indexing could be one of the issues. By higher order, I mean the second- and third-order derivatives of K with respect to x, the first input. I actually only need the derivative of the first element of K with respect to the first element of x. Right now I am probably computing the derivative of the first element of K with respect to all elements of x, most of which should be zero. I am not sure how to do that in Zygote with a setup like this.

Is there a utility function in Zygote for the derivative of a matrix element `K_{ij}` with respect to the elements of another matrix `x_{ij}`?

I had other issues with Enzyme and wrote most of the code to use Zygote. I might have to switch back to Enzyme then.

@KapilKhanal could you maybe fix the example so that it does exactly what you intend it to do? Some aspects of it are weird at the moment. Perhaps you expect x[i] to be the i-th row of the matrix? But it is not, it is actually the i-th coefficient of the flattened matrix.

Oh, gotcha: that was not indexing the way I thought, but I think the issue is still the same in this case.

I have fixed the code. I still have to think about how to avoid scalar indexing with Zygote, so that change is not reflected in the code yet.

Thank you for pointing out the error in indexing. I thought it was giving me the first row. Is there a way to vectorize this properly? It’s very weird that Julia itself does not care about vectorization but Zygote does.
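For the toy `inner_func` in the question, each entry of K is a row sum of `x` plus row sums of `y` and `z`, so the comprehension can be replaced by `sum(...; dims = 2)` and a broadcast, the kind of whole-array code Zygote handles well. This is only a sketch that relies on that additive structure; the name `coefficients_vec` is hypothetical, and a general `inner_func` will not vectorize this way.

```julia
using Zygote

inner_func(p1, p2, p3) = sum(p1) + sum(p2) + sum(p3)

# Loop version from the question: K[J, I] = inner_func(x[I, :], y[J, :], z[J, :])
function coefficients(x, y, z)
    n = size(x, 1)
    return [inner_func(x[I, :], y[J, :], z[J, :]) for J in 1:n, I in 1:n]
end

# Vectorized version, valid only because this inner_func is a sum of separate
# row sums: K[J, I] = rowsum(y)[J] + rowsum(z)[J] + rowsum(x)[I]
function coefficients_vec(x, y, z)
    syz = sum(y; dims = 2) .+ sum(z; dims = 2)   # n×1 column, indexed by J
    sx = reshape(sum(x; dims = 2), 1, :)         # 1×n row, indexed by I
    return syz .+ sx                             # broadcasts to an n×n matrix
end

x, y, z = rand(5, 5), rand(5, 5), rand(5, 5)
Jloop = only(Zygote.jacobian(x -> coefficients(x, y, z), x))
Jvec = only(Zygote.jacobian(x -> coefficients_vec(x, y, z), x))
```

Both versions give the same K and the same Jacobian; the difference is that the vectorized one gives Zygote a handful of array operations to record instead of one slice per matrix entry.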

Why do you compute all of K and use all of x then? You are actually looking for a scalar-to-scalar derivative?

Yes, an element-by-element derivative. I have added code that does that to the example above, but I want to take the derivative of the function with the loop in it instead of differentiating the inner function and building the matrix myself. I have other code that calculates the 2nd- and 3rd-order derivatives, and I want it to be able to differentiate through the matrix construction as well.

I’m sorry, I still don’t understand that part.

I need the gradient of every element of K with respect to the element of x at the corresponding index. I want to write it so that it is easy to understand: the user just asks for dK/dx or d²K/dx² and the function computes it, instead of being given only the derivative of the inner function and asked to build the matrices themselves. That way the code would also be end-to-end differentiable by itself, with no need to assemble the matrices manually. For example, if I were going to use this in an adjoint equation, the manual approach would make sense, but I am hoping to do reverse-mode differentiation automatically rather than writing out the equations, getting the required partials via AD, and doing the adjoint derivation by hand.

Maybe I am confusing you because there’s something I am not understanding here. Hopefully that makes sense; let me know if that’s not the right way to think about this.

Let’s ignore y and z, which are constant for differentiation purposes. If I understand correctly, you have a function

$$x \in \mathbb{R}^{n \times m} \longmapsto K \in \mathbb{R}^{n \times m}$$

and you are only interested in “diagonal” partial derivatives like

$$\frac{\partial K_{ij}}{\partial x_{ij}} \quad \text{and} \quad \frac{\partial^2 K_{ij}}{\partial x_{ij}^2} \quad \text{and} \quad \frac{\partial^3 K_{ij}}{\partial x_{ij}^3}$$

but not in “non-diagonal” partial derivatives like

$$\frac{\partial K_{ij}}{\partial x_{kl}}$$

for $(i,j) \neq (k,l)$. Does that sound right?
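In terms of the full Jacobian of vec(K) with respect to vec(x), these "diagonal" partials are exactly its diagonal entries, because both arrays are flattened in the same column-major order. For small inputs that gives a simple, if wasteful (it builds the whole (nm)×(nm) matrix), reference to check faster methods against. A sketch with Zygote, using `K(x) = x * x` as a stand-in and a hypothetical helper name `diagonal_partials_bruteforce`:

```julia
using Zygote, LinearAlgebra

K(x) = x * x   # toy stand-in; any K with size(K(x)) == size(x) works here

# Brute force: build the full Jacobian of vec(K(x)) with respect to vec(x)
# and keep its diagonal. Both are flattened column-major, so diagonal entry k
# pairs K[i, j] with x[i, j] for the same (i, j).
function diagonal_partials_bruteforce(K, x)
    J = only(Zygote.jacobian(K, x))
    return reshape(diag(J), size(x))
end

x = rand(3, 3)
D = diagonal_partials_bruteforce(K, x)   # D[i, j] = ∂K[i, j] / ∂x[i, j]
```

For this toy K, $(x^2)_{ij} = \sum_k x_{ik} x_{kj}$, so $\partial K_{ij} / \partial x_{ij} = x_{ii} + x_{jj}$, which makes the result easy to check by hand.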

If the above is correct, then this code computes one of the quantities you are interested in:

```
julia> using ForwardDiff, FillArrays

julia> K(x) = x * x
K (generic function with 1 method)

julia> x = rand(3, 3);

julia> dx(x, i, j) = OneElement(one(eltype(x)), (i, j), axes(x))
dx (generic function with 1 method)

julia> function diagonal_derivative(K, x, i, j)
           step(t) = K(x + t * dx(x, i, j))
           full_derivative = ForwardDiff.derivative(step, 0)
           return full_derivative[i, j]
       end
diagonal_derivative (generic function with 1 method)

julia> diagonal_derivative(K, x, 1, 2)
0.9432869747177048
```
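The `diagonal_derivative` function relies on the fact that a diagonal partial is a single entry of a directional derivative along one coordinate direction. Writing $E^{(ij)}$ for the matrix with a 1 at position $(i,j)$ and zeros elsewhere (what `dx(x, i, j)` builds with `OneElement`):

$$
\frac{\partial^p K_{ij}}{\partial x_{ij}^p}(x) = \left[ \frac{d^p}{dt^p} K\big(x + t\,E^{(ij)}\big) \Big|_{t=0} \right]_{ij}, \qquad p = 1, 2, 3, \ldots
$$

The other entries $(k,l)$ of the same directional derivative are the off-diagonal partials $\partial^p K_{kl} / \partial x_{ij}^p$, which get computed along the way and then discarded.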

You don’t gain anything by computing those derivatives in reverse mode: in this case you need $n^2$ function calls either way, and forward mode is usually more efficient.
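The second- and third-order derivatives asked about earlier can be obtained with the same construction by nesting `ForwardDiff.derivative`, which ForwardDiff supports. A sketch, assuming the direction is built as a plain dense matrix rather than a `OneElement` (so FillArrays is not needed); the function name `diagonal_derivatives` is hypothetical. It still needs one call per entry $(i,j)$.

```julia
using ForwardDiff

K(x) = x * x   # toy function from the answer above

# Returns (∂K[i,j]/∂x[i,j], ∂²K[i,j]/∂x[i,j]², ∂³K[i,j]/∂x[i,j]³) by nesting
# forward-mode derivatives along t -> x + t * E, where E is zero except E[i, j] = 1.
function diagonal_derivatives(K, x, i, j)
    E = zero(x)
    E[i, j] = one(eltype(x))
    f(t) = K(x .+ t .* E)                  # K restricted to the line through x
    d1(t) = ForwardDiff.derivative(f, t)   # array-valued first derivative
    d2(t) = ForwardDiff.derivative(d1, t)  # nested: second derivative
    d3(t) = ForwardDiff.derivative(d2, t)  # nested: third derivative
    t0 = zero(eltype(x))
    return d1(t0)[i, j], d2(t0)[i, j], d3(t0)[i, j]
end

x = rand(3, 3)
dK1, dK2, dK3 = diagonal_derivatives(K, x, 1, 2)
```

For this toy K and $i \neq j$, $K_{ij}$ is linear in $x_{ij}$, so the first derivative is $x_{ii} + x_{jj}$ and the higher ones vanish; on the diagonal, $K_{ii}$ is quadratic in $x_{ii}$, so the second derivative is 2.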

If you need even more efficiency, you should define your function K in a non-allocating way, e.g. K!(x_dest, x).
