How to: High-performance differentiable programming with broad AD-library support?

Hi all,

I am currently struggling with three (competing?) goals in code design. I want code that (a) is fast, (b) is differentiable, and (c) is differentiable by current (and future) AD libraries (at least ForwardDiff.jl and ReverseDiff.jl). And last but not least, easy maintainability (which is somewhat connected to readable code) would be nice :slight_smile:
I am facing some issues when thinking about how to write good code with these goals in mind. I am just sharing my thoughts and would be happy about corrections, comments, or updates on them :slight_smile:

(1) Using Buffers
One of the core design patterns for fast code (independent of Julia) is using buffers that are allocated once and then reused, e.g. during computations inside a loop. Often, this doesn't add much complexity to the code. In Julia, I want statically typed buffers (for performance), so I need to preallocate them with a given type (like Float64). However, in an AD run, I need the corresponding AD primitives (e.g. ForwardDiff.Dual or ReverseDiff.TrackedReal). What is the “nice” way of doing this?
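To make the problem concrete, here is a minimal sketch, assuming ForwardDiff.jl (the names `BUFFER` and `g` are hypothetical), of what goes wrong when a buffer is preallocated as `Float64` and the function is then differentiated: ForwardDiff passes `Dual` numbers, which cannot be stored in a `Vector{Float64}`, so the call throws.

```julia
using ForwardDiff

const BUFFER = zeros(3)        # hypothetical buffer, preallocated once with eltype Float64

function g(x)
    BUFFER .= x .^ 2           # writes into the Float64 buffer
    return sum(BUFFER)
end

g([1.0, 2.0, 3.0])             # plain evaluation works → 14.0

# Under ForwardDiff, x holds Dual numbers, which cannot be converted to Float64 when stored
failed = try
    ForwardDiff.gradient(g, [1.0, 2.0, 3.0])
    false
catch
    true
end
```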

(1a) I know PreallocationTools.jl, but as I understand it, it's for ForwardDiff only. However, ReverseDiff is now an optional dependency… is ReverseDiff support planned, maybe?
EDIT: ReverseDiff.jl works with PreallocationTools.jl!

(1b) Buffers seem complicated with Zygote.jl, because it only allows non-mutating array operations. There is the new Zygote.Buffer, but this feels like additional, Zygote-specific code is necessary.
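For reference, a minimal sketch of the `Zygote.Buffer` pattern, assuming Zygote.jl is installed (`squares` is a hypothetical example function): the buffer is constructed like `similar`, filled by indexing, and turned back into a regular array with `copy`. This is the Zygote-specific extra code referred to above.

```julia
using Zygote

# Hypothetical example: square every element using a Zygote-aware buffer
function squares(x)
    buf = Zygote.Buffer(x)       # constructed like similar(x), but may be mutated inside Zygote
    for i in eachindex(x)
        buf[i] = x[i]^2
    end
    return copy(buf)             # turn the buffer back into a regular array
end

Zygote.gradient(x -> sum(squares(x)), [1.0, 2.0, 3.0])[1]   # → [2.0, 4.0, 6.0]
```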

(2) Common AD-interface: ChainRules.jl
I really like the idea of ChainRules.jl: defining one forward and one backward rule (frule and rrule) is mathematically enough to provide an interface to further AD “backends”. Custom rules are not just needed for non-differentiable foreign calls, but also for performance (e.g. AD shortcuts over iterative procedures). However, I see multiple (big!) libraries that don't use ChainRules.jl and implement dedicated dispatches for AD primitives instead. Is this because ChainRules adds overhead compared to a pure AD dispatch (like for ForwardDiff.Dual)?
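As an illustration of the “AD shortcut over an iterative procedure” idea, here is a minimal sketch using ChainRulesCore.jl. `newton_sqrt` is a hypothetical iterative function; its `frule` and `rrule` skip the Newton loop and use the analytic derivative of √a instead. A ChainRules-aware backend (e.g. Zygote.jl for the `rrule`) could then pick these rules up, which is not exercised here.

```julia
using ChainRulesCore

# Hypothetical iterative procedure: square root of a > 0 via Newton's method
function newton_sqrt(a::Real; iters::Int = 20)
    y = one(a)
    for _ in 1:iters
        y = (y + a / y) / 2
    end
    return y
end

# Forward rule: skip the loop and use d/da sqrt(a) = 1 / (2 sqrt(a))
function ChainRulesCore.frule((_, ȧ), ::typeof(newton_sqrt), a::Real; kwargs...)
    y = newton_sqrt(a; kwargs...)
    return y, ȧ / (2y)
end

# Reverse rule: the same shortcut, expressed as a pullback
function ChainRulesCore.rrule(::typeof(newton_sqrt), a::Real; kwargs...)
    y = newton_sqrt(a; kwargs...)
    newton_sqrt_pullback(ȳ) = (NoTangent(), unthunk(ȳ) / (2y))
    return y, newton_sqrt_pullback
end

y, ẏ = ChainRulesCore.frule((NoTangent(), 1.0), newton_sqrt, 4.0)   # → (2.0, 0.25)
```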

Finally, it might help to discuss this with an example:

# a struct that "lives" for some time during the application
# and stores values and buffers that are reused (it's mutable)
mutable struct LongLifeStruct
    # the type of `a` will change during the application
    # `AbstractArray{<:Real}` is bad for performance (abstract field type), but works with AD,
    # because AD primitives are subtypes of Real. I assume PreallocationTools is the way to go here?
    a::AbstractArray{<:Real}
end

function doSomething(str::LongLifeStruct)
    str.a[:] = ... # whatever
end

# a struct that is allocated for a specific calculation and
# freed afterwards (immutable)
struct ShortLifeStruct{T}
    # the type of `a` will not change during its "lifespan",
    # therefore we can allocate it "typed"
    # is this correct?
    a::AbstractArray{T}
end
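Note that `a::AbstractArray{T}` is still an abstract field type, because only the element type is fixed, not the container type. One common way to get a concretely typed field is sketched below, assuming ForwardDiff.jl for the check: make the whole array type a type parameter (here `A`) and derive the buffer from the input, so that its element type follows the input. `TypedShortLifeStruct`, `make_short_life` and `compute` are hypothetical names.

```julia
using ForwardDiff

# The concrete array type is a type parameter, so the field `a` is concretely typed
struct TypedShortLifeStruct{A<:AbstractArray{<:Real}}
    a::A
end

# Hypothetical helper: build the buffer from the input, so it holds Duals under ForwardDiff
make_short_life(x::AbstractArray{<:Real}) = TypedShortLifeStruct(similar(x))

function compute(x)
    s = make_short_life(x)
    s.a .= x .^ 2
    return sum(s.a)
end

ForwardDiff.gradient(compute, [1.0, 2.0])   # → [2.0, 4.0]
```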

Setting special cases aside, what would be the “correct” (i.e. fast) way of implementing AD support (ForwardDiff, ReverseDiff, Zygote, …) for the small code example above?

Thank you all!


(1) Using Buffers

You don't need to preallocate with a fixed type; you just need the type to be statically inferrable.
For instance, if you have a function f!(buffer, x), you can allocate the buffer with the same eltype as x, and this will make it compatible with automatic differentiation in ForwardDiff.jl.
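A minimal sketch of this, assuming ForwardDiff.jl (`f!` and `loss` are hypothetical): the buffer is allocated once per call with `similar(x)`, so its element type is `Float64` for normal evaluation and `ForwardDiff.Dual` during differentiation, and it is then reused inside the loop.

```julia
using ForwardDiff

# In-place kernel writing into a caller-provided buffer
function f!(buffer, x)
    buffer .= x .^ 2
    return sum(buffer)
end

function loss(x)
    buffer = similar(x)            # eltype follows x: Float64 normally, Dual under ForwardDiff
    total = zero(eltype(x))
    for _ in 1:10                  # the buffer is allocated once and reused in every iteration
        total += f!(buffer, x)
    end
    return total
end

ForwardDiff.gradient(loss, [1.0, 2.0, 3.0])   # → [20.0, 40.0, 60.0]
```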

(2) Common AD-interface: ChainRules.jl

Unfortunately, as you have noticed, not all AD backends have the same rule definition mechanism. If you want to write differentiable code, here’s what you should care about most:

  • Ensure type generality for ForwardDiff.jl and ReverseDiff.jl (see the previous point about buffers and eltype for an example)
  • Ensure no mutation / define custom chain rules for Zygote.jl and other ChainRules.jl-compatible backends
  • Ensure type stability / define custom Enzyme rules for Enzyme.jl, although it should be able to autodiff through a lot of code already without fixes
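For the first point, a minimal sketch of the difference between an over-restrictive and a generic signature, assuming ForwardDiff.jl (`f_strict` and `f_generic` are hypothetical):

```julia
using ForwardDiff

# Too restrictive: ForwardDiff calls this with a Vector of Duals, which is not a Vector{Float64}
f_strict(x::Vector{Float64}) = sum(abs2, x)

# Generic: accepts any real element type, including ForwardDiff.Dual
f_generic(x::AbstractVector{<:Real}) = sum(abs2, x)

ForwardDiff.gradient(f_generic, [1.0, 2.0])   # → [2.0, 4.0]
# ForwardDiff.gradient(f_strict, [1.0, 2.0]) throws a MethodError
```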

If you’re too lazy for that, there are bridges between the rule systems:

  • ForwardDiffChainRules.jl, which you wrote ^^
  • DifferentiationInterface.jl, especially the new DifferentiateWith(f, backend) mechanism. Essentially, this defines chain rules that call another AD backend under the hood. So, for instance, you can tell Zygote.jl to “differentiate this function f with Enzyme.jl instead, because it mutates under the hood”
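A minimal sketch of the `DifferentiateWith` idea, assuming DifferentiationInterface.jl, ForwardDiff.jl and Zygote.jl are installed. Here ForwardDiff (via `AutoForwardDiff()`) stands in for Enzyme as the inner backend, and `f_mutating` is a hypothetical function that mutates an internal buffer.

```julia
using DifferentiationInterface   # provides DifferentiateWith and re-exports AutoForwardDiff
import ForwardDiff, Zygote

# Mutates an internal buffer, which Zygote cannot differentiate through on its own
function f_mutating(x)
    buffer = similar(x)
    buffer .= x .^ 2
    return sum(buffer)
end

# Wrap f so that, when Zygote encounters it, the derivative is computed with ForwardDiff
f_wrapped = DifferentiateWith(f_mutating, AutoForwardDiff())

Zygote.gradient(f_wrapped, [1.0, 2.0, 3.0])[1]
```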

(3) Maintenance and testing

Once you have written this beautiful, AD-universal source code you dream of, take a look at DifferentiationInterfaceTest.jl for your test suite. It will allow you to easily test several AD engines at once, compare their results against a reference, and even benchmark them against each other.
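DifferentiationInterfaceTest.jl automates this kind of comparison (and benchmarking). The hand-rolled sketch below is not its API, just an illustration of the underlying idea using DifferentiationInterface's `gradient(f, backend, x)` with two backends, assuming ForwardDiff.jl and ReverseDiff.jl are installed.

```julia
using DifferentiationInterface   # provides gradient and re-exports AutoForwardDiff, AutoReverseDiff
import ForwardDiff, ReverseDiff  # load the backends so the corresponding extensions are active

f(x) = sum(abs2, x)
x = [1.0, 2.0, 3.0]
reference = 2 .* x               # analytic gradient of sum(abs2, x)

backends = [AutoForwardDiff(), AutoReverseDiff()]
for backend in backends
    g = gradient(f, backend, x)
    println(backend, " matches reference: ", isapprox(g, reference))
end
```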


Thanks for your replies!
I will check the provided links …

The current state feels a little unsatisfactory… fast and broadly applicable AD is quite challenging to implement. It feels like there should at least be something like “best practices” or “common coding patterns” for fast multi-library AD (each library has its own coding examples, but they always focus on supporting only that specific library). Because these coding patterns depend on the AD libraries one wants to support, there would be multiple cases to distinguish (only ForwardDiff.jl, ForwardDiff.jl and ReverseDiff.jl, and so on).

Having such a coding pattern could also be the first step towards a macro for it (like a more powerful version of the current ForwardDiffChainRules.@ForwardDiff_frule and ReverseDiff.@grad_from_chainrules).
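For context, a minimal sketch of what `ReverseDiff.@grad_from_chainrules` does, assuming ChainRulesCore.jl and ReverseDiff.jl: an `rrule` written once for ChainRules is reused by ReverseDiff whenever the function receives a tracked array. `mysum_sq` is a hypothetical function.

```julia
using ChainRulesCore, ReverseDiff

# Hypothetical function with a hand-written reverse rule
mysum_sq(x) = sum(abs2, x)

function ChainRulesCore.rrule(::typeof(mysum_sq), x::AbstractVector)
    y = mysum_sq(x)
    mysum_sq_pullback(ȳ) = (NoTangent(), 2 .* x .* unthunk(ȳ))
    return y, mysum_sq_pullback
end

# Reuse the ChainRules rrule in ReverseDiff whenever mysum_sq receives a tracked array
ReverseDiff.@grad_from_chainrules mysum_sq(x::ReverseDiff.TrackedArray)

ReverseDiff.gradient(mysum_sq, [1.0, 2.0, 3.0])
```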

Or are there technical restrictions, and the hard-to-swallow pill is: if you want it fast, you need to implement and optimize it separately for every AD library?

It's fine for ReverseDiff. It has been for years. Why do you think it's ForwardDiff only? I don't think anything in the package says that. ReverseDiff, Tracker, Symbolics, etc.: type-based ADs in general should be fine. If it says it's for ForwardDiff anywhere on that page, we should update it, but I searched its docs and I don't see where this is said.

No, that’s misunderstanding the problem and solution. That’s only for differentiated values, not cache buffers.

That's true; this was a wrong assumption on my part. The README does, however, seem very focused on ForwardDiff.jl (although it doesn't explicitly say it's ForwardDiff-only). Maybe one could add a line to the README listing the AD frameworks the buffers can be used with?

Any type-based is fine. It can be extended to do Zygote as well. It needs a hook for Enzyme that I’m looking into.
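As a related, type-agnostic option, PreallocationTools.jl also provides `LazyBufferCache`, which, as I understand it, allocates a `similar` buffer on first use for each input type and size and reuses it afterwards. A minimal sketch, assuming ForwardDiff.jl for the check; `sumsq` is a hypothetical function.

```julia
using PreallocationTools, ForwardDiff

const lbc = LazyBufferCache()   # hands out a `similar` buffer per input type and size, created lazily

# Hypothetical function that needs a temporary array
function sumsq(x)
    tmp = lbc[x]                # Float64 buffer for Float64 input, Dual buffer under ForwardDiff
    tmp .= x .^ 2
    return sum(tmp)
end

sumsq([1.0, 2.0, 3.0])                        # → 14.0
ForwardDiff.gradient(sumsq, [1.0, 2.0, 3.0])  # → [2.0, 4.0, 6.0]
```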

Oh right, my bad, so we’re closer to the preparation mechanism in DifferentiationInterface.jl?


No, DI doesn’t handle this or touch these variables.

I'm sorry, I'm struggling to understand without a concrete example.

Take the tutorial from PreallocationTools.jl:

using ForwardDiff, PreallocationTools
randmat = rand(5, 3)
sto = similar(randmat)
stod = DiffCache(sto)   # wraps the Float64 buffer and reserves space for a Dual version

function claytonsample(sto, τ, α; randmat = randmat)
    # Float64 buffer for Float64 τ, Dual buffer when τ is a ForwardDiff.Dual
    sto = get_tmp(sto, τ)
    sto .= randmat
    τ == 0 && return sto

    n = size(sto, 1)
    for i in 1:n
        v = sto[i, 2]
        u = sto[i, 1]
        sto[i, 1] = (1 - u^(-τ) + u^(-τ) * v^(-(τ / (1 + τ))))^(-1 / τ) * α
        sto[i, 2] = (1 - u^(-τ) + u^(-τ) * v^(-(τ / (1 + τ))))^(-1 / τ)
    end
    return sto
end

# τ becomes a Dual inside derivative, so get_tmp hands back the Dual buffer
ForwardDiff.derivative(τ -> claytonsample(stod, τ, 0.0), 0.3)

Or this might be more clear:

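using ForwardDiff, PreallocationTools, LinearAlgebra

u = ones(5, 5)
du = ones(5, 5)
# chunk size matching the one ForwardDiff.jacobian picks by default for u
chunk_size = ForwardDiff.pickchunksize(length(u))
p = (ones(5, 5), DiffCache(zeros(5, 5), chunk_size))
function foo!(du, u, (A, tmp))
    tmp = get_tmp(tmp, u)   # Float64 or Dual buffer, depending on eltype(u)
    mul!(tmp, A, u)
    @. du = u + tmp
    nothing
end
# in-place form: ForwardDiff.jacobian(f!, y, x) calls f!(y, x) with Dual arrays
ForwardDiff.jacobian((dx, x) -> foo!(dx, x, p), du, u)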

It’s the tmp handling.
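To see the `tmp` handling in isolation, a minimal sketch assuming PreallocationTools.jl and ForwardDiff.jl: `get_tmp` returns the plain `Float64` buffer for a `Float64` argument and a `Dual`-typed buffer of the same shape for a `Dual` argument.

```julia
using PreallocationTools, ForwardDiff

cache = DiffCache(zeros(3))   # holds a Float64 buffer plus storage for a Dual buffer

# Plain Float64 argument: the ordinary Float64 buffer
tmp_plain = get_tmp(cache, 1.0)

# Dual argument (as seen inside ForwardDiff): a buffer of Duals
d = ForwardDiff.Dual(1.0, 1.0)
tmp_dual = get_tmp(cache, d)

println(eltype(tmp_plain))    # Float64
println(eltype(tmp_dual))     # a ForwardDiff.Dual type
```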


Thanks! When I have more time, I'll dive into that and think about whether this would make sense in DI.