Zygote and StaticArrays

I am currently evaluating Zygote. My use case is maybe a bit different from what has driven Zygote’s development: I want to take the gradient of scalar-valued functions of 10 to 100 inputs, but I will compute millions of such gradients.

So my functions must not allocate (on the heap), and I am a keen user of StaticArrays. Here is a little benchmark:

using Zygote,StaticArrays,BenchmarkTools

struct MyType
    a  :: Float64
    b  :: SVector{2,Float64}
    c  :: Float64
    d  :: SVector{2,Int64}
end

function foo(o::MyType, δX,X,U,A,  t)
    δW       = (δX[o.d]' * o.b) * (o.a + A[1]) 
    δW      += δX[3  ] * -(o.c+A[2])
    return δW
end

el       = MyType(3e3,SVector(10.,0), -1e4,SVector(1,2))
δX,X,U,A = SVector(1.,1,1), SVector(1.,2,3), SVector(), SVector(0.,0)

@btime foo($el, $δX,$X,$U,$A,0.)
>    1.300 ns (0 allocations: 0 bytes)

Lδx,Lx,Lu,La = @btime gradient(δX,X,U,A) do δX,X,U,A
    foo($el, δX,X,U,A,0.) 
end
>  2.444 μs (48 allocations: 7.27 KiB)

@show typeof.((Lδx,Lx,Lu,La))
> typeof.((Lδx, Lx, Lu, La)) = (SizedVector{3, Float64, Vector{Float64}}, Nothing, Nothing, SizedVector{2, Float64, Vector{Float64}})

I have several questions, quite possibly related.

  1. In a simple example (not included) I do manage to get SVector gradients out of Zygote, but here I get SizedVectors, which, as I understand it, go on the heap. (The Nothing results are fine; the test includes edge cases.) Why does gradient(foo) not return static arrays?

  2. Indeed, the gradient computation allocates quite a bit and is quite slow. Unfortunately, I find the same in my simpler example (not included). Is that a consequence of 1), or (I fear) just part of the deal with Zygote, which is efficient for larger numbers of parameters?


Do the timings change when you interpolate the inputs in the btime macro, i.e. add the $ sign in front of all the variables, in particular the el inside the do block?


Indeed! And I should have remembered. This certainly solves the issue of foo apparently allocating.

I have edited the code and questions accordingly. This only makes the performance gap between foo and its gradient more apparent: more than a factor of 1000 (1.3 ns vs 2.4 μs), while the number of flops only goes up by something of the order of 10.
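For a sense of scale, the gradient of foo can be written out by hand. The sketch below is plain Julia with no AD; MyTypeT, foo_t and foo_grad are hypothetical names, and NTuples stand in for the SVectors so that it runs without StaticArrays. It needs only a handful of multiplications and additions and allocates nothing, which is roughly the cost one would hope an AD tool gets close to.

```julia
# NTuple stand-in for MyType (hypothetical; SVector fields replaced by NTuples)
struct MyTypeT
    a::Float64
    b::NTuple{2,Float64}
    c::Float64
    d::NTuple{2,Int}
end

# Same arithmetic as foo(o, δX, X, U, A, t); X, U and t do not enter the result
foo_t(o::MyTypeT, δX, A) = (δX[o.d[1]] * o.b[1] + δX[o.d[2]] * o.b[2]) * (o.a + A[1]) -
                           δX[3] * (o.c + A[2])

# Hand-derived gradient with respect to δX and A
function foo_grad(o::MyTypeT, δX::NTuple{3,Float64}, A::NTuple{2,Float64})
    s = o.a + A[1]
    # ∂/∂δX: o.b[k]*s lands on index o.d[k] (accumulated if indices repeat),
    # and -(o.c + A[2]) lands on index 3
    gδX = ntuple(Val(3)) do j
        g = 0.0
        for k in 1:2
            if o.d[k] == j
                g += o.b[k] * s
            end
        end
        j == 3 ? g - (o.c + A[2]) : g
    end
    # ∂/∂A: A[1] multiplies the dot product, A[2] multiplies -δX[3]
    gA = (δX[o.d[1]] * o.b[1] + δX[o.d[2]] * o.b[2], -δX[3])
    return gδX, gA
end

el = MyTypeT(3e3, (10.0, 0.0), -1e4, (1, 2))
δX, A = (1.0, 1.0, 1.0), (0.0, 0.0)
foo_grad(el, δX, A)   # ((30000.0, 0.0, 10000.0), (10.0, -1.0))
```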

In general, Zygote tends to “lose” StaticArray gradients because none of its internals are aware of them. That means something as pedestrian as gradient(x -> x[1] + sum(x), SA[1,2,3]) could (and does) end up on an allocating path and returning a SizedArray. Have you considered alternative AD libraries like ReverseDiff?


Interesting. I will try ReverseDiff.jl.

I have tried ReverseDiff, and while that does not prove anything, I at least get very poor performance out of it. As for Zygote, it’s a mixed bag for my use case. Here are my tests:

using Zygote,StaticArrays,BenchmarkTools,LinearAlgebra

const i = SMatrix{2,3}(1,0,0,1,0,0)
const j = SMatrix{2,3}(0,0,1,0,0,1)
const A  = SVector(1.,2.,3.)
const B  = SVector(2.,2.,2.)

foo(A,B) = dot(A,B) 
bar(A,B) = dot(view(A,SVector(1,2)),view(A,SVector(2,3))) 
baz(A,B) = dot(i*A,j*B) 

@btime foo($A,$B)
@btime gradient($foo,$A,$B)
∇a,∇b = gradient(foo,A,B)
@show typeof.((∇a,∇b))

>  0.800 ns (0 allocations: 0 bytes)
>  0.800 ns (0 allocations: 0 bytes)
> typeof.((∇a, ∇b)) = (SVector{3, Float64}, SVector{3, Float64})

@btime bar($A,$B)
@btime gradient($bar,$A,$B)
∇a,∇b = gradient(bar,A,B)
@show typeof.((∇a,∇b))

>  1.000 ns (0 allocations: 0 bytes)
>  72.500 μs (489 allocations: 23.67 KiB)               OUCH!
> typeof.((∇a, ∇b)) = (MVector{3, Float64}, Nothing)

@btime baz($A,$B)
@btime gradient($baz,$A,$B)
∇a,∇b = gradient(baz,A,B)
@show typeof.((∇a,∇b))

>   2.000 ns (0 allocations: 0 bytes)
>  55.589 ns (1 allocation: 16 bytes)
> typeof.((∇a, ∇b)) = (SVector{3, Float64}, SVector{3, Float64})


  1. Whole-array operations perform well: no allocations, and the gradients are SVectors.
  2. If I index into arrays or take views of them, performance is a disaster, and I get MVectors, which go on the heap.
  3. As a workaround, I replace the views by multiplications with the 0/1 matrices i and j. The direct code is a bit slower, and the gradient code performs reasonably. I get SVectors, but still one allocation.
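The selection-matrix workaround can be checked by hand: multiplying by a 0/1 matrix picks out entries, and the reverse-mode rule for y = S*x sends the incoming gradient back as S'*∂y, which is itself just a matrix product. That is presumably why the baz gradients stay SVectors. The sketch below uses plain Arrays and the LinearAlgebra standard library only; Si and Sj are hypothetical stand-ins for i and j.

```julia
using LinearAlgebra

# Plain-matrix stand-ins for i = SMatrix{2,3}(1,0,0,1,0,0) and
# j = SMatrix{2,3}(0,0,1,0,0,1); SMatrix fills column by column.
Si = [1.0 0.0 0.0;
      0.0 1.0 0.0]
Sj = [0.0 1.0 0.0;
      0.0 0.0 1.0]

A = [1.0, 2.0, 3.0]
B = [2.0, 2.0, 2.0]

baz(A, B) = dot(Si * A, Sj * B)   # same value as dot(A[1:2], B[2:3])

# For dot(u, v): ∂u = v and ∂v = u, with u = Si*A and v = Sj*B.
# Pulling back through a matrix product: ∂x = S' * ∂y.
∇A = Si' * (Sj * B)   # [2.0, 2.0, 0.0]
∇B = Sj' * (Si * A)   # [0.0, 1.0, 2.0]
```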




I have seen

@Zygote.adjoint (T::Type{<:SVector})(x::Number...     ) = T(x...), dv -> (nothing, dv...)
@Zygote.adjoint (T::Type{<:SVector})(x::AbstractVector) = T(x   ), dv -> (nothing, dv   )
# https://github.com/FluxML/Zygote.jl/issues/570

which I only vaguely understand. It’s maybe not directly relevant here, but it makes me wonder: can I, as a user, make Zygote aware of SVectors? Could I somehow specify that the adjoint of indexing into an immutable array creates a new immutable array from an old one, adding values at some of the indices?

I am not looking for a quick fix, but I currently need to make a software design choice, which hinges on the ability to efficiently evaluate Hessians of “many to one” functions that operate on static arrays for performance. Performance can come later (and I can work my way up to fixing it myself), but I need to know whether it will be possible to index into static arrays and still get fast gradients.

Zygote + SVector is going to be a bit of an adventure. As ToucheSir points out, it’s not designed in, and quite a few rules will do something like call similar and then write into the result. Depending on what your function does, you may be able to work around these things.
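One plausible route to the MVector seen for bar: when a rule calls similar on an SVector, StaticArrays returns a mutable MVector (for isbits element types such as Float64), since in-place writes need a mutable container. A minimal sketch of the "similar, then write" pattern; it is an illustration, not code taken from any actual Zygote rule.

```julia
using StaticArrays

x = SVector(1.0, 2.0, 3.0)

# "similar, then write", as a generic rule might do for a gradient buffer
∂x = similar(x)      # mutable MVector{3, Float64}, contents uninitialised
fill!(∂x, 0.0)
∂x[2] += 1.0         # in-place write: only possible on a mutable container

# Converting back to the immutable type copies the values into an SVector
∂x_static = SVector{3,Float64}(∂x)
```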

If your arrays are small enough to use SVectors, then you might be better off with ForwardDiff than Zygote. Especially for Hessians:

julia> @btime ForwardDiff.gradient(A -> bar(A,$B), $A)  # was 111.500 μs with Zygote
  min 2.000 ns, mean 2.082 ns (0 allocations)
3-element SVector{3, Float64} with indices SOneTo(3):

julia> @btime ForwardDiff.hessian(A -> bar(A,$B), $A) isa SMatrix
  min 45.795 ns, mean 46.482 ns (0 allocations)

julia> @btime Zygote.hessian(A -> dot(A,$B), $A) isa Matrix  # ForwardDiff over Zygote
  min 651.235 ns, mean 701.743 ns (7 allocations, 432 bytes)

This is certainly the least adventurous solution. You could also investigate Enzyme; I’m not certain it will work, but someone may know.

Note that the view here just returns a new SVector; StaticArrays overloads it. You could start providing gradients for such overloaded functions, but there are likely to be quite a few. (And I think this one is a Zygote @adjoint, which takes precedence over an rrule, so @opt_out may not work.)

julia> view(A,SVector(1,2))
2-element SVector{2, Float64} with indices SOneTo(2):

It is possible to fix this globally by writing a stricter rule for ProjectTo(::SArray). This won’t fix the use of MVectors etc. inside hand-written rules, but it should be able to convert their results back to SVectors before they are passed to the next rule. How much this will help, I’m not certain.


Do you get the types that you want? If so, I think the performance part can be addressed by compiling the gradient tape (I couldn’t find a good docs reference, but there are examples of this on Discourse). If not, you already know some of the nitty-gritty behind how Zygote does (not) handle StaticArrays.

“a bit of an adventure”, aye, if not an epic. :smiley:

Having read a little about adjoints, I believe the adjoint of

y = x[i]

is (in essence, forgive the Matlab notation)

dx[ i] = dy 
dx[~i] = 0
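The notation dx[i] = dy holds when the indices in i are distinct; if an index repeats, its contributions add up, so strictly it is dx[i] += dy starting from dx = 0. (The foo(j,i,∂y) helper in the next post returns only the first match, so it would miss that case.) A plain-Julia sketch, with NTuples standing in for SVectors (an SVector can be built from an NTuple) and getindex_adjoint a hypothetical name:

```julia
# Adjoint of y = x[idx] for an immutable length-N vector:
# ∂x[j] is the sum of ∂y[k] over every k with idx[k] == j, and zero
# for entries of x that were never read.
function getindex_adjoint(::Val{N}, idx::NTuple{M,Int}, ∂y::NTuple{M,T}) where {N,M,T<:Real}
    ntuple(Val(N)) do j
        acc = zero(T)
        for k in 1:M
            if idx[k] == j
                acc += ∂y[k]   # "+=": a repeated index collects several contributions
            end
        end
        acc
    end
end

getindex_adjoint(Val(3), (1, 3), (10.0, 20.0))   # (10.0, 0.0, 20.0)
getindex_adjoint(Val(3), (2, 2), (1.0, 1.0))     # (0.0, 2.0, 0.0)
```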

As a warm-up, I attempted to code that for immutable vectors only, without allocating:

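# Gradient entry for x[j]: ∂y[k] for the first k with j == i[k], zero if x[j] was never read
function foo(j,i,∂y)
    for k ∈ eachindex(i)
        if j==i[k]
            return ∂y[k]
        end
    end
    return zero(eltype(∂y))
end
@Zygote.adjoint Base.getindex(x::SVector{Nx},i) where{Nx} = getindex(x,i), ∂y->ntuple(j->foo(j,i,∂y),Nx) # ∂x[i]=∂y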

Note that in this code the gradient is a Tuple. If I make the pullback return an SVector instead and run my tests above, I get an error:

ERROR: Gradient [0, 1, 0, 0, 0] should be a tuple
 [1] error(s::String)
   @ Base .\error.jl:33
 [2] gradtuple1(x::SVector{5, Int64})
   @ ZygoteRules C:\Users\philippem\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:24
 [3] (::var"#16#back#9"{var"#5#7"{5, SVector{2, Int64}}})(Δ::Vector{Int64})
   @ Main C:\Users\philippem\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
 [4] (::Zygote.var"#60#61"{var"#16#back#9"{var"#5#7"{5, SVector{2, Int64}}}})(Δ::Vector{Int64})
   @ Zygote C:\Users\philippem\.julia\packages\Zygote\xGkZ5\src\compiler\interface.jl:45
 [5] top-level scope
   @ show.jl:1047

which my brain cannot parse.
With my adjoint as coded above, there is no change in performance and I get an MVector. I know too little of Zygote…
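Reading the stack trace, the message comes from ZygoteRules’ gradtuple1, which insists that a pullback return a Tuple. As I understand the @adjoint convention, that Tuple holds one gradient per argument of the primal call, so for getindex(x, i) it should be (∂x, ∂i) rather than ∂x on its own; a length-Nx tuple is then presumably read as Nx separate argument gradients. The sketch below mimics that shape in plain Julia, with NTuples standing in for SVectors; getindex_with_pullback is a hypothetical name, and the commented Zygote line is an untested guess.

```julia
# Return the primal value plus a pullback whose result is a Tuple with one
# entry per argument of the primal call: (∂x, ∂i).
function getindex_with_pullback(x::NTuple{N,Float64}, i::NTuple{M,Int}) where {N,M}
    y = map(k -> x[k], i)                      # forward: y = x[i]
    function back_getindex(∂y)
        ∂x = ntuple(j -> sum((∂y[k] for k in 1:M if i[k] == j); init = 0.0), Val(N))
        return (∂x, nothing)                   # ∂x for x, nothing for the indices
    end
    return y, back_getindex
end

# In Zygote terms the adjoint above would presumably become (untested):
# @Zygote.adjoint Base.getindex(x::SVector{Nx}, i) where {Nx} =
#     getindex(x, i), ∂y -> (SVector(ntuple(j -> foo(j, i, ∂y), Nx)), nothing)

y, back = getindex_with_pullback((1.0, 2.0, 3.0, 4.0, 5.0), (2, 4))
back((1.0, 1.0))   # ((0.0, 1.0, 0.0, 1.0, 0.0), nothing)
```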

So yes… ForwardDiff. For length(x)==3, your tests show excellent Hessian performance with ForwardDiff. I will have lengths up to 100, and there the penalty will be more significant.
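The scaling concern can be made concrete with a toy forward-mode sketch (a hypothetical Dual type, not ForwardDiff’s implementation): each forward pass carries one directional derivative, so a full gradient takes one pass per input. ForwardDiff propagates a chunk of several partials per pass, which cuts the number of passes, but the total work for a gradient still grows with the input length, and a Hessian computed forward-over-forward grows faster still. The function g below has the same arithmetic as bar.

```julia
# Toy dual number: a value plus one directional derivative (illustration only)
struct Dual
    v::Float64   # value
    d::Float64   # derivative along the current seed direction
end

Base.:+(a::Dual, b::Dual) = Dual(a.v + b.v, a.d + b.d)
Base.:*(a::Dual, b::Dual) = Dual(a.v * b.v, a.d * b.v + a.v * b.d)   # product rule

# One forward pass per input: seed input k with derivative 1, all others with 0
function fd_gradient(f, x::NTuple{N,Float64}) where {N}
    ntuple(Val(N)) do k
        xd = ntuple(j -> Dual(x[j], j == k ? 1.0 : 0.0), Val(N))
        f(xd).d
    end
end

# Same arithmetic as bar(A, B) above: A[1]*A[2] + A[2]*A[3]
g(x) = x[1] * x[2] + x[2] * x[3]

fd_gradient(g, (1.0, 2.0, 3.0))   # (2.0, 4.0, 2.0)
```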

I’ll also have a look at Enzyme.

I tested: Enzyme passes my tests! :grin: Incredible performance straight out of the box.
The only restriction I have identified is that gradient differentiates with respect to a single vector. But if you can extract multiple vectors from that single vector, that’s no issue. And anyway, I’ll start really studying the API now; maybe multiple vectors are somehow supported.

using Enzyme,StaticArrays,BenchmarkTools,LinearAlgebra

const A  = SVector(1.,2.,3.)
const B  = SVector(2.,2.,2.)

foo(A) = dot(A,B) 
bar(A) = dot(A[SVector(1,2)],B[SVector(2,3)]) 

mode = Reverse
@btime foo($A)
∇a = @btime gradient(mode,$foo,$A)
@show typeof(∇a)

@btime bar($A)
∇a = @btime gradient(mode,$bar,$A) 
@show typeof(∇a)


  1.000 ns (0 allocations: 0 bytes)
  4.500 ns (0 allocations: 0 bytes)
typeof(∇a) = SVector{3, Float64}
  0.800 ns (0 allocations: 0 bytes)
  4.700 ns (0 allocations: 0 bytes)
typeof(∇a) = SVector{3, Float64}

And what’s more, it’s Friday afternoon and the sun is shining! :sunglasses: :sunny: :parasol_on_ground:


Nice! The multi-input API is a little more low-level; have a look at autodiff and how it’s used in the docs.

Here’s an example using the autodiff API to get all inputs:

A  = SVector(1.,2.,3.); B  = SVector(2.,2.,2.)
autodiff(Reverse, dot, Active, Active(A), Active(B))
# ([2.0, 2.0, 2.0], [1.0, 2.0, 3.0])

Great, I’ll study the docs knowing that there are very useful things in the API beyond the little I have seen.

For non-technical reasons, I must put this project on the back burner for a few weeks, but I do so knowing I have the sketch of a solution.

Thank you everyone!