AD of LoopVectorization with Zygote?

I’m writing a package to perform some physics calculations on large 3D arrays. To do so I need to compute a few quantities, in particular differential operators of the fields, such as the gradient with respect to position (the three dimensions of the array correspond to the x, y and z coordinates). I compute these in Fourier space: I take the FFT of the field and multiply it by a precomputed array containing the wavevectors.
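As an illustration of the Fourier-space derivative described above, here is a minimal sketch of the x-derivative of a real, periodic 3D field. It assumes a cubic grid with N points per side, box length L, x along the first array dimension, and FFTW's `rfft`/`irfft`. The name `ddx` is hypothetical, and the wavevectors are computed inline for brevity rather than precomputed.

```julia
using FFTW

# Hypothetical helper: ∂f/∂x of a real, periodic field on an N×N×N grid with box length L.
# Assumes the first array dimension is x.
function ddx(f::Array{T,3}, L::Real) where {T<:Real}
    N = size(f, 1)
    fk = rfft(f)                              # size (N÷2+1, N, N): the halved dimension is x
    kx = T(2π / L) .* T.(rfftfreq(N, N))      # x-wavevectors 0, 2π/L, …, (N÷2)·2π/L
    # Multiply by i·kx (broadcast along dim 1) and transform back to real space.
    # For even N the Nyquist mode is often zeroed as well; that refinement is omitted here.
    return irfft(im .* kx .* fk, N)
end
```

For a single Fourier mode such as sin(2πx/L), this reproduces (2π/L)·cos(2πx/L) to floating-point precision.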

Since the arrays are large and I need to perform many such operations, I defined mutable structs ScalarFieldCore and VectorFieldCore, which hold the fields as arrays in their S and V fields respectively, together with methods for basic operations such as Base.:* that use LoopVectorization for a significant performance boost.
Here’s a relevant example of one of the Base.:* methods I implemented:

```julia
function Base.:*(
    a::ScalarFieldCore{Complex{T}},
    b::Array{T, 3}
)::ScalarFieldCore{Complex{T}} where T<:Real
    A = StructArray(a.S)
    return ScalarFieldCore{Complex{T}}(
        a.L,
        a.Ng,
        Array(@tturbo StructArray{Complex{T}}(@. (A.re*b, A.im*b)))
    )
end
Base.:*(a::Array{T, 3}, b::ScalarFieldCore{Complex{T}}) where T<:Real = Base.:*(b, a)
```

This works really well and gave me large performance gains. The issue is that I now want to make my code differentiable: the initial field I feed into my code depends on some parameters, and I want to compute the gradient with respect to those.

I found that Zygote works out of the box with generic structs, so I built a simple example with my code: a function that takes the parameters and returns a single number, the sum of the final field over all x, y and z positions. However, I got the following error:

```
MethodError: no method matching vmaterialize(::Array{Float32, 3}, ::Val{:GridSPTCore}, ::Val{(true, 0, 0, 0, true, 0, 32, 15, 64, 0x0000000000000010, 1, true)})
The function `vmaterialize` exists, but no method is defined for this combination of argument types.
```

The stacktrace points to the Base.:* method I pasted above.[1] I found this issue: Hook into Zygote.jl? · Issue #108 · JuliaSIMD/LoopVectorization.jl · GitHub, which looks identical to what I’m seeing, but it’s still open after 5 years…

Unfortunately I’m really a beginner when it comes to AD, and what I find online is too obscure for me to get a good grip on how to make LoopVectorization and Zygote work together. On the other hand, I’m really reluctant to let go of the performance improvement I got from LoopVectorization. How hard do you think it is to write an “adjoint of vmaterialize”? Is it worth the effort, or should I adopt some other performance-enhancement method?


  1. Just to be sure, I tried redefining Base.:* so that it just does a normal broadcast a.S .* b, and that works fine with Zygote.

The low-tech way is to provide a gradient rule and accelerate both the forward and backward passes with LV. This is most easily done on plain arrays; then call this numerical function from the method involving your structs:

```julia
using ChainRulesCore

mul(A::AbstractArray{<:Complex}, B::AbstractArray{<:Real}) = A .* B  # replace with @tturbo implementation of same

function ChainRulesCore.rrule(::typeof(mul), A::AbstractArray{<:Complex}, B::AbstractArray{<:Real})
  C = mul(A, B)
  function mul_back(dCraw)
     dC = unthunk(dCraw)
     dA = dC .* conj.(B)          # these might need reduction, depending on sizes!
     dB = real.(dC .* conj.(A))   # B is real, so its cotangent must be real too; can similarly use @tturbo here
     (NoTangent(), dA, dB)
  end
  C, mul_back
end

function Base.:*(a::ScalarFieldCore{Complex{T}}, b::Array{T, 3}) where T<:Real
    # call mul() to do the work, then re-build ScalarFieldCore
    return ScalarFieldCore{Complex{T}}(a.L, a.Ng, mul(a.S, b))
end
```
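As a hypothetical stand-in for the "@tturbo implementation" mentioned in the comment, one option is to keep the real and imaginary parts as separate real arrays (here the components of a StructArray) and run a plain `@tturbo` loop over them, since LoopVectorization works on real element types. `mul_lv` is a made-up name, and the sketch assumes the StructArray components are ordinary `Array`s; the same loop structure could be reused for `dA` inside the pullback.

```julia
using LoopVectorization, StructArrays

# Hypothetical LV kernel for C = A .* B with complex A (stored as a StructArray) and real B.
function mul_lv(A::StructArray{Complex{T}}, B::Array{T}) where {T<:Real}
    @assert size(A) == size(B)
    Cre = similar(B)
    Cim = similar(B)
    # vec() of an Array shares memory, so these are flat views onto the same data
    are, aim, b = vec(A.re), vec(A.im), vec(B)
    cre, cim = vec(Cre), vec(Cim)
    @tturbo for i in eachindex(b)
        cre[i] = are[i] * b[i]
        cim[i] = aim[i] * b[i]
    end
    return StructArray{Complex{T}}((Cre, Cim))
end
```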

What sizes are A and B here? What I wrote is probably right if they are the same size, but if broadcasting extends dimensions, you will need to sum the gradients back down to the original size.
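To make the "sum them back down" point concrete, here is a hedged sketch for one such case: a complex vector field of size (3, n...) multiplied by a real scalar field of size (n...), with the scalar field broadcast along the component dimension. `mul_vec` is a hypothetical name; the key step is that the cotangent of B is summed over the dimension that broadcasting added.

```julia
using ChainRulesCore

# Hypothetical: complex vector field A of size (3, n...) times real scalar field B of size (n...).
mul_vec(A::AbstractArray{<:Complex}, B::AbstractArray{<:Real}) = A .* reshape(B, 1, size(B)...)

function ChainRulesCore.rrule(::typeof(mul_vec), A::AbstractArray{<:Complex}, B::AbstractArray{<:Real})
    Bx = reshape(B, 1, size(B)...)   # B broadcast along the first (component) dimension
    C = A .* Bx
    function mul_vec_back(dCraw)
        dC = unthunk(dCraw)
        dA = dC .* conj.(Bx)                                        # same size as A: no reduction
        dB = dropdims(sum(real.(dC .* conj.(A)); dims=1); dims=1)   # sum out the broadcast dimension
        return (NoTangent(), dA, dB)
    end
    return C, mul_vec_back
end
```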

As an aside, I think the way you are using StructArray creates more copies than ideal. You could consider storing the fields as StructArrays the whole time, or just reinterpreting the Array{Complex{T}} as an Array{T}, multiplying by b, and reinterpreting back.
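The reinterpret idea can be sketched as follows, assuming A and b have the same size. `reinterpret(reshape, T, A)` views the complex array as a real array with an extra leading dimension of length 2 (real and imaginary parts) without copying, so b only needs a singleton leading dimension to broadcast. `mul_reinterpret` is a hypothetical name, and this uses plain Base broadcasting; whether `@tturbo` accepts the reinterpreted arrays directly is worth checking separately.

```julia
# Hypothetical helper: C = A .* B for complex A and real B, viewing A as interleaved (re, im) reals.
function mul_reinterpret(A::Array{Complex{T}}, B::Array{T}) where {T<:Real}
    Ar = reinterpret(reshape, T, A)              # size (2, size(A)...), shares memory with A
    Cr = Ar .* reshape(B, 1, size(B)...)         # scales re and im parts by B; only Cr is allocated
    return reinterpret(reshape, Complex{T}, Cr)  # view the (2, ...) real result as complex again
end
```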


Thank you! Your example prompted me to check out the ChainRules documentation, which was very illuminating (I admit it took me a good afternoon to figure out why the conj was there…). To make it work with Zygote I had to explicitly add using ChainRulesCore. Looking at the documentation, I might decide to implement rules directly for my structs, since it doesn’t look too difficult.
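A rule written directly for the struct method might look like the sketch below. The struct definition is a hypothetical minimal stand-in for `ScalarFieldCore` (the field types of `L` and `Ng` are assumptions), and the forward pass uses the plain broadcast from the footnote where the LV kernel would go. The cotangent of the output arrives as a `Tangent` whose `S` field carries the array cotangent; `L` and `Ng` receive no tangent.

```julia
using ChainRulesCore

# Hypothetical minimal stand-in for the field struct from the question (field types assumed).
mutable struct ScalarFieldCore{T}
    L::Float64
    Ng::Int
    S::Array{T,3}
end

function Base.:*(a::ScalarFieldCore{Complex{T}}, b::Array{T,3}) where {T<:Real}
    return ScalarFieldCore{Complex{T}}(a.L, a.Ng, a.S .* b)  # swap in the LV kernel here
end

function ChainRulesCore.rrule(::typeof(*), a::ScalarFieldCore{Complex{T}}, b::Array{T,3}) where {T<:Real}
    c = a * b
    function times_back(dcraw)
        dc = unthunk(dcraw)
        dS = unthunk(dc.S)                               # cotangent of the output's S field
        da = Tangent{typeof(a)}(; S = dS .* conj.(b))    # only S gets a tangent
        db = real.(dS .* conj.(a.S))                     # b is real, so its cotangent is real
        return (NoTangent(), da, db)
    end
    return c, times_back
end
```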

A and B always have the same size, so there are no reduction problems.

About StructArray, I’m not sure I can improve it much further. My code solves a perturbation-theory problem: starting from the linear-order fields, it iteratively computes all the non-linear orders, using the previously computed orders. I ran some benchmarks starting from a Float32 array of size (220, 220, 220) and ending up with 6 fields of size (220, 220, 220) and 6 fields of size (3, 220, 220, 220); these are the results:

```
BenchmarkTools.Trial: 7 samples with 1 evaluation per sample.
 Range (min … max):  9.069 s …    9.705 s  ┊ GC (min … max): 4.23% … 6.64%
 Time  (median):     9.182 s               ┊ GC (median):    3.94%
 Time  (mean ± σ):   9.227 s ± 218.405 ms  ┊ GC (mean ± σ):  4.41% ± 1.00%

  █      ▁  █  ▁                                           ▁  
  █▁▁▁▁▁▁█▁▁█▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  9.07 s         Histogram: frequency by time         9.71 s <

 Memory estimate: 15.92 GiB, allocs estimate: 37608019.
```

The idea of reinterpreting the fields as real arrays is interesting; I’ll give it a shot. Thanks for the tip!

Now that the gradient problem is solved, I get an unrelated error:

```
type ScaledPlan has no field region
```

I’m doing a planned irfft from the FFTW package. I’m not sure why it complains, since this should have been fixed according to ScaledPlan: region subfield missing (needed for AD) · Issue #182 · JuliaMath/FFTW.jl · GitHub (my dependencies are also up to date). Since I’m performing FFTs and iFFTs in sequence to get derivatives, I think I can get away with using @ignore_derivatives on them. What do you think?
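One point worth checking before using @ignore_derivatives: it makes the wrapped expression a constant for AD, so its contribution to the gradient is dropped rather than approximated. Whether that is acceptable depends on whether the FFT inputs depend on the parameters being differentiated. A minimal illustration, assuming Zygote and ChainRulesCore's `@ignore_derivatives`; `g` is a hypothetical toy function:

```julia
using Zygote, ChainRulesCore

# The second term is evaluated normally but treated as a constant by AD.
g(x) = x^2 + ChainRulesCore.@ignore_derivatives(3x)

g_grad = Zygote.gradient(g, 2.0)[1]   # only the x^2 term contributes: 2x = 4.0, not 7.0
```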