Compiler specialisation in automatic differentiation

I am working with finite element method code. In this context, it is (very!!!) useful to use forward automatic differentiation to differentiate the function that, from the degrees of freedom, computes the element’s contribution to the residual. Anyway: a function of the form

r = foo(x)

where x, r and everything in between are static arrays.

I started early: ForwardDiff.jl did not exist yet, so I have my homebrew, which does essentially the same - I’ll refer to it as Adiff.

At the start of a differentiation process, the partials of a dual number are populated with zeros and a single one (stored in an SVector in Adiff, in a Tuple in ForwardDiff), and some of these zeros can be propagated quite far into the computations, depending on the structure of foo. Things only get worse when differentiating to higher order. I don’t know about you, but I feel adding zeros and multiplying by ones is a waste of time :smiley: . So I messed around a little:

struct Nix end # structurally zero
const nix = Nix()
@inline Base.:*( ::Nix   ,b::Number) = nix
@inline Base.:+( ::Nix   ,b::Number) = b
#... you get the gist
Base.show(io::IO,::Nix) = print(io,"⋅")

# make a few "sparse static" vectors
a = (nix,nix,2,nix,nix,1,nix,nix,nix,nix,nix,nix)
b = (1,nix,nix,2,nix,3,nix,4,nix,5,nix,6)

e = sum(a.*b)
#...

Both @code_native and BenchmarkTools suggest that this technique reduces the number of operations and computing times, and there is no penalty when full vectors are involved in the computations.
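For readers who want to run the snippet above, here is a minimal self-contained sketch of the same idea, with the remaining argument orders filled in. The helper name `sparsedot` is made up for this illustration, and `foldl`/`map` are used in place of `sum` and broadcasting purely as a choice for the sketch.

```julia
# Sketch: a structural zero whose arithmetic is resolved by dispatch.
struct Nix end                      # structurally zero
const nix = Nix()

# multiplication: anything times a structural zero is a structural zero
Base.:*(::Nix, ::Number) = nix
Base.:*(::Number, ::Nix) = nix
Base.:*(::Nix, ::Nix)    = nix

# addition: a structural zero is the neutral element
Base.:+(::Nix, b::Number) = b
Base.:+(a::Number, ::Nix) = a
Base.:+(::Nix, ::Nix)     = nix

# the two "sparse static" vectors from the post
a = (nix, nix, 2, nix, nix, 1, nix, nix, nix, nix, nix, nix)
b = (1, nix, nix, 2, nix, 3, nix, 4, nix, 5, nix, 6)

# hypothetical helper: dot product of two tuples that may contain nix
sparsedot(u, v) = foldl(+, map(*, u, v))

e = sparsedot(a, b)   # only index 6 is non-zero in both tuples, so e is 1*3
```

Because `Nix` is a singleton type, which products vanish is decided from the tuple types alone, so each distinct pair of patterns gets its own specialised method instance.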

The idea now is that the compiler must produce specialised machine code for every combination of sparsity patterns in operations. This will obviously require longer compilation times.

But my concern and question is: faced with generating a large number of specialised instances, will the compiler end up throwing in the sponge and creating generic code - with much worse performance than by just leaving well alone in Adiff and ForwardDiff? Of course the best is just to try, but at least in Adiff, that would be some work, so I am checking whether we can tell outright if this is not a good idea.

Generally, no. The compiler will happily compile as many specific method instances as the user asks it to, until it runs out of memory.

However, I think you can likely do better than your current approach. Since you likely need to compile a lot of methods, your program is likely to be dominated by compilation times (but there’s a chance that it’s not - so measure!). As an alternative, you could create a single type that just stores the index of the only non-zero element. In fact that exists already in FillArrays.jl under the name OneElement. I don’t know how it behaves under AD though. With this approach you only need to compile once but can still save the multiplications.
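To make the "single type that stores the index" idea concrete, here is a rough sketch. The type `Seed` and the function `dot_seed` are hypothetical names invented for this example; this is not the FillArrays.jl `OneElement` API, only an illustration of the principle.

```julia
# Hypothetical type, for illustration only (not FillArrays.OneElement).
# A length-N vector with a single non-zero entry: the position is stored as
# data, so all seeds of a given length share one type and one compilation.
struct Seed{N,T}
    index::Int   # position of the only non-zero entry
    value::T     # its value
end
Seed{N}(index::Int, value::T) where {N,T} = Seed{N,T}(index, value)

Base.length(::Seed{N}) where {N} = N
Base.getindex(s::Seed{N,T}, i::Int) where {N,T} = i == s.index ? s.value : zero(T)

# scaling keeps the single-entry structure: one multiplication instead of N
Base.:*(c::Number, s::Seed{N}) where {N} = Seed{N}(s.index, c * s.value)

# dot product with a dense tuple: one multiplication, no additions
dot_seed(s::Seed{N}, v::NTuple{N,Any}) where {N} = s.value * v[s.index]

s = Seed{4}(2, 1.0)
d = dot_seed(s, (10.0, 20.0, 30.0, 40.0))   # picks out the second entry
t = 3 * s
```

The trade-off against the `Nix` approach is that the position is a run-time value, so the compiler cannot remove it, but a single method instance serves every seed.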

Alternatively, you could just try regular sparse arrays from the standard library.

If you can come up with a realistic benchmark, then feel free to post it here. That will likely give you the best feedback :slight_smile:

Thanks for the answer. Nice to know about “limitless specialisation”.

Remarks from my side.

  1. It has to be on the stack (so no standard-library sparse arrays): otherwise execution times will be atrocious, or I need to work on preallocated memory. But expressing maths in imperative programming, and pre-allocating… thanks but no thanks.
  2. It’s not just vectors with a single non-zero: that’s just what a “seed” looks like in forward automatic differentiation. As partials combine, partials will emerge with arbitrary sparsity patterns. Hence:
  3. Compilation times are going to be atrocious under my scheme. I have to learn about pre-compilation and caching of machine code.
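A small sketch of remark 2, showing how patterns combine. It assumes the `Nix` definitions from the first post; the helper `pattern` is a name made up here. Addition produces the union of the two patterns and multiplication the intersection, and each pattern is a distinct tuple type.

```julia
struct Nix end
const nix = Nix()
Base.:+(::Nix, ::Nix)     = nix
Base.:+(::Nix, b::Number) = b
Base.:+(a::Number, ::Nix) = a
Base.:*(::Nix, ::Nix)     = nix
Base.:*(::Nix, ::Number)  = nix
Base.:*(::Number, ::Nix)  = nix

# hypothetical helper: true where an entry is stored, false where it is nix
pattern(t::Tuple) = map(x -> !(x isa Nix), t)

u = (1.0, nix, nix, nix)        # seed for the first variable
v = (nix, 1.0, nix, nix)        # seed for the second variable

s = map(+, u, v)                # union of the two patterns
p = map(*, u, v)                # intersection of the two patterns

# each pattern is its own type, hence its own specialisation
distinct_types = typeof(u) != typeof(v)

# upper bound on the number of patterns for 12 partials
npatterns_max = 2^12
```

With 12 partials there are up to 4096 patterns per vector, and a binary operation can in principle be specialised for each pair of them, which is where the compilation cost comes from.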

I have an idea on how to create a compilation-time benchmark. I’ll come back with that in a few days.

:slight_smile:

Here comes the benchmark. Sorry about the long code, I’m sure I should have found a smart way to avoid the explicit coding.

using StaticArrays, BenchmarkTools, LinearAlgebra


# implement "nix"
struct Nix end 
const nix = Nix()
@inline Base.:*( ::Nix   ,b::Number) = nix
@inline Base.:*(a::Number, ::Nix   ) = nix
@inline Base.:*( ::Nix   , ::Nix   ) = nix
@inline Base.:/( ::Nix   ,b::Number) = nix
@inline Base.:/(a::Number, ::Nix   ) = Inf*sign(a)
@inline Base.:/( ::Nix   , ::Nix   ) = NaN
@inline Base.:+( ::Nix   ,b::Number) = b
@inline Base.:+(a::Number, ::Nix   ) = a
@inline Base.:+( ::Nix   , ::Nix   ) = nix
@inline Base.:-( ::Nix   ,b::Number) = -b
@inline Base.:-(a::Number, ::Nix   ) = a
@inline Base.:-( ::Nix   , ::Nix   ) = nix
Base.show(io::IO,::Nix) = print(io,"⋅")

# create a "vector" of "vectors" of partials.
# a: uses nixes.  We have npattern sparsity patterns to give the compiler a jog
# b: classic, full partials
npattern = 10 # hard coded, don't change this one
npartial = 12
sparsity = .3 # in [0.,1.]: probability that an entry is non-zero
a = () 
b = Vector{SVector{npartial,Float64}}(undef,npattern) # the same without sparsity
for ipattern = 1:npattern
    global a,b
    tmp         = @SVector [rand()<sparsity ? randn() : 0.  for ipartial= 1:npartial]
    b[ipattern] = tmp
    a           = (a...,ntuple(ipartial->tmp[ipartial]==0 ? nix : tmp[ipartial],npartial))
end

# We multiply all the combinations of vector of partial to force the compiler to
# create O(npattern²) multiplication and addition codes
function foo(a,x)
    for i = 1:10
        x = x .+ a[1].*a[1] 
        x = x .+ a[1].*a[2]
        x = x .+ a[1].*a[3]
        x = x .+ a[1].*a[4]
        x = x .+ a[1].*a[5]
        x = x .+ a[1].*a[6]
        x = x .+ a[1].*a[7]
        x = x .+ a[1].*a[8]
        x = x .+ a[1].*a[9]
        x = x .+ a[1].*a[10]
        x = x .+ a[2].*a[1]
        x = x .+ a[2].*a[2]
        x = x .+ a[2].*a[3]
        x = x .+ a[2].*a[4]
        x = x .+ a[2].*a[5]
        x = x .+ a[2].*a[6]
        x = x .+ a[2].*a[7]
# sorry, long code...        
        x = x .+ a[2].*a[8]
        x = x .+ a[2].*a[9]
        x = x .+ a[2].*a[10]
        x = x .+ a[3].*a[1]
        x = x .+ a[3].*a[2]
        x = x .+ a[3].*a[3]
        x = x .+ a[3].*a[4]
        x = x .+ a[3].*a[5]
        x = x .+ a[3].*a[6]
        x = x .+ a[3].*a[7]
        x = x .+ a[3].*a[8]
        x = x .+ a[3].*a[9]
        x = x .+ a[3].*a[10]
        x = x .+ a[4].*a[1]
        x = x .+ a[4].*a[2]
        x = x .+ a[4].*a[3]
        x = x .+ a[4].*a[4]
        x = x .+ a[4].*a[5]
        x = x .+ a[4].*a[6]
        x = x .+ a[4].*a[7]
        x = x .+ a[4].*a[8]
        x = x .+ a[4].*a[9]
        x = x .+ a[4].*a[10]
        x = x .+ a[5].*a[1]
        x = x .+ a[5].*a[2]
        x = x .+ a[5].*a[3]
        x = x .+ a[5].*a[4]
        x = x .+ a[5].*a[5]
        x = x .+ a[5].*a[6]
        x = x .+ a[5].*a[7]
        x = x .+ a[5].*a[8]
        x = x .+ a[5].*a[9]
        x = x .+ a[5].*a[10]
        x = x .+ a[6].*a[1]
        x = x .+ a[6].*a[2]
        x = x .+ a[6].*a[3]
        x = x .+ a[6].*a[4]
        x = x .+ a[6].*a[5]
        x = x .+ a[6].*a[6]
        x = x .+ a[6].*a[7]
        x = x .+ a[6].*a[8]
        x = x .+ a[6].*a[9]
        x = x .+ a[6].*a[10]
        x = x .+ a[7].*a[1]
        x = x .+ a[7].*a[2]
        x = x .+ a[7].*a[3]
        x = x .+ a[7].*a[4]
        x = x .+ a[7].*a[5]
        x = x .+ a[7].*a[6]
        x = x .+ a[7].*a[7]
        x = x .+ a[7].*a[8]
        x = x .+ a[7].*a[9]
        x = x .+ a[7].*a[10]
        x = x .+ a[8].*a[1]
        x = x .+ a[8].*a[2]
        x = x .+ a[8].*a[3]
        x = x .+ a[8].*a[4]
        x = x .+ a[8].*a[5]
        x = x .+ a[8].*a[6]
        x = x .+ a[8].*a[7]
        x = x .+ a[8].*a[8]
        x = x .+ a[8].*a[9]
        x = x .+ a[8].*a[10]
        x = x .+ a[9].*a[1]
        x = x .+ a[9].*a[2]
        x = x .+ a[9].*a[3]
        x = x .+ a[9].*a[4]
        x = x .+ a[9].*a[5]
        x = x .+ a[9].*a[6]
        x = x .+ a[9].*a[7]
        x = x .+ a[9].*a[8]
        x = x .+ a[9].*a[9]
        x = x .+ a[9].*a[10]
        x = x .+ a[10].*a[1]
        x = x .+ a[10].*a[2]
        x = x .+ a[10].*a[3]
        x = x .+ a[10].*a[4]
        x = x .+ a[10].*a[5]
        x = x .+ a[10].*a[6]
        x = x .+ a[10].*a[7]
        x = x .+ a[10].*a[8]
        x = x .+ a[10].*a[9]
        x = x .+ a[10].*a[10]
    end
    return x
end
# create two accumulators
x = SVector{npartial}(zeros(npartial))
y = SVector{npartial}(zeros(npartial))

# measure compilation times  (restart Julia...)
@time foo(a,x)
@time foo(b,x)

# measure execution times
@btime foo($a,$x)
@btime foo($b,$y)

@show all(foo(a,x) .== foo(b,y))
;

So the gain in execution time is great, and the compile time is not that bad.
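As an aside on the "smart way to avoid the explicit coding": one possibility is to generate the hundred update lines with metaprogramming. The sketch below is an assumption about how this could be done, not part of the benchmark above; `foo_unrolled` and `foo_loop` are names invented here, and the check uses plain tuples rather than the `nix` data.

```julia
npattern = 10

# one expression per (j, k) pair, in the same order as the hand-written lines
updates = [:(x = x .+ a[$j] .* a[$k]) for j in 1:npattern for k in 1:npattern]

# splice the generated lines into the function body; the indices are literals,
# so each a[j] keeps its own concrete type, as in the explicit version
@eval function foo_unrolled(a, x)
    for i = 1:10
        $(updates...)
    end
    return x
end

# reference implementation with ordinary loops (fine for homogeneous data,
# but run-time indexing would lose the element types of a heterogeneous tuple)
function foo_loop(a, x)
    for i = 1:10, j in eachindex(a), k in eachindex(a)
        x = x .+ a[j] .* a[k]
    end
    return x
end

# small check with plain tuples of Float64
a = ntuple(j -> (Float64(j), 0.0, 1.0), npattern)
x = (0.0, 0.0, 0.0)
r = foo_unrolled(a, x)
```

The generated function should behave like the explicit one, since the compiler sees the same sequence of statements.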

So, should we adopt this technology for forward automatic differentiation?

  1. Acceleration depends on the sparsity of the partials across the function to differentiate: good early in the function, typically bad after a few steps. That said, I have worked on another technique to accelerate differentiation that is great at doing just that; see McLaurin line 93 and “fast”, and BeamElement line 158 for an application. What it does is separately differentiate a piece of code, then compose the derivatives. Even without gain from sparsity, in BeamElement there is a gain from having shorter vectors of partials in selected parts of the code. Incidentally, this technique alone deserves implementation in ForwardDiff.
  2. we’d need to rewrite ForwardDiff or my Adiff. Some work, but I’m game. For ForwardDiff, I’d prefer to be roped in with a guide.
  3. Any vector or matrix in a function to be differentiated now has elements of different types. We need a StaticHeterogenousArray. That is more intimidating.
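The "differentiate a piece separately, then compose" idea from point 1 can be sketched with a toy dual number. This is only an illustration of the chain rule under assumed names (`Dual`, `seed`, `inner`); it is not the code referred to in McLaurin or BeamElement, nor ForwardDiff's implementation.

```julia
# Minimal forward-mode dual number with N partials (illustrative only)
struct Dual{N}
    val::Float64
    der::NTuple{N,Float64}
end
Base.:+(a::Dual{N}, b::Dual{N}) where {N} =
    Dual{N}(a.val + b.val, a.der .+ b.der)
Base.:*(a::Dual{N}, b::Dual{N}) where {N} =
    Dual{N}(a.val * b.val, a.val .* b.der .+ b.val .* a.der)
Base.sin(a::Dual{N}) where {N} = Dual{N}(sin(a.val), cos(a.val) .* a.der)

# seed: the i-th variable gets the i-th unit vector as partials
seed(x::NTuple{N,Float64}) where {N} =
    ntuple(i -> Dual{N}(x[i], ntuple(j -> Float64(i == j), N)), N)

# the "piece of code" that depends on only two intermediate quantities
inner(u, v) = sin(u * v) + u

x = (0.3, -1.2, 0.7, 2.0)
X = seed(x)
u = X[1] + X[2] + X[3] + X[4]      # intermediates carrying 4 partials
v = X[1] * X[4]

# (a) direct: push all 4 partials through inner
direct = inner(u, v)

# (b) composed: differentiate inner with only 2 partials, then apply the chain rule
U, V = seed((u.val, v.val))
loc = inner(U, V)
composed = loc.der[1] .* u.der .+ loc.der[2] .* v.der
```

In (b) the work inside `inner` is done on 2 partials instead of 4; the saving grows with the number of degrees of freedom and with the length of the separately differentiated piece.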

The technique would only apply to functions that would naturally be implemented using StaticArrays. Limited application? Well, imagine if Julia not only allowed one to create remarkably compact code for finite elements (or any form of discretisation in which small “elements” contribute to a “system” matrix), as is the case today, but also beat “hand written” codes on performance? That’d be another killer-app for Julia.

Anyone interested?

:slight_smile: