On simple reduction tasks Zygote seems to perform exceptionally poorly, and I would be grateful to hear why, and what alternative patterns I should use to avoid these problems. I’ve started implementing custom adjoints via rrule, but if I have to do this for every little piece of code then I might as well implement all the gradients myself and forget about Zygote.
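For context, the kind of custom adjoint I mean looks roughly like this — `f` and `df` are toy placeholders standing in for one of my cheap inner functions and its hand-written derivative, and `mysum` is just a named wrapper so the rule has something to attach to:

```julia
using ChainRulesCore

f(x) = x^2 + sin(x)      # toy stand-in for a cheap inner function
df(x) = 2x + cos(x)      # its hand-written derivative

mysum(x) = sum(f, x)

# Custom reverse rule: the pullback scales the elementwise derivative
# by the incoming cotangent, so Zygote never has to differentiate
# through sum(f, x) itself.
function ChainRulesCore.rrule(::typeof(mysum), x)
    y = mysum(x)
    mysum_pullback(ȳ) = (NoTangent(), ȳ .* df.(x))
    return y, mysum_pullback
end
```

Writing one of these per inner function is exactly the boilerplate I’d like to avoid.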
Are you saying that Zygote is to be used in the ms and not ns regime? That’s Ok!
But what if my code is sufficiently complex that the models evaluate in the ms-to-s regime, yet it goes many layers deep into inner functions that evaluate in the 100 ns regime? That is where I currently lose a factor of 1000 or more in performance.
I’m working a lot with custom adjoints. Here, I really just want to understand what is going on.
Back to the topic at hand, the long and short of it seems to be that there aren’t optimized adjoints in place for mapreduce-like functions (including sum(f, ...)) yet. This becomes pretty clear from looking at the currently defined adjoints. Higher-order functions in general are not well optimized for AD right now, so it’s no surprise that they perform terribly. You can verify this for yourself by pulling up Cthulhu.jl (https://github.com/JuliaDebug/Cthulhu.jl) and stepping into the pullback.
On the bright side, things are slowly changing for the better. https://github.com/JuliaDiff/ChainRules.jl/pull/441 is the first step towards better AD for higher-order functions. I don’t remember if Diffractor (the next-gen AD system) will help with this, but it’s worth looking out for that as well. In the meantime, your options are to (a) avoid higher-order functions, or (b) look into AD systems other than Zygote (e.g. ForwardDiff).
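To make option (a) concrete, here is a minimal sketch (the toy `f` stands in for your cheap inner function; whether the broadcast form actually wins is worth benchmarking on your real model):

```julia
using Zygote

f(x) = x^2 + sin(x)      # toy stand-in for a cheap inner function

# Higher-order form: hits the slow generic mapreduce pullback.
g1(x) = sum(f, x)

# Broadcast form: goes through Zygote's broadcast machinery instead,
# which is typically much faster for cheap elementwise f.
g2(x) = sum(f.(x))

x = randn(1000)
Zygote.gradient(g1, x)   # slow path
Zygote.gradient(g2, x)   # same result, usually much faster
```

Both compute the same gradient; only the pullback they dispatch to differs.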
Maybe I should explain why I’m hunting for performance in the first place: a collaborator has implemented a very similar model to mine (nothing to do with the model above) in TensorFlow, and claims his AD gradients are only a factor of 1.5 slower than my hand-optimised ones, which I believe are within a factor of 2–3 of optimal.
I figured combining Zygote + ForwardDiff was my best chance at getting something like that in Julia. But maybe his claim is just bogus and I should test that myself.
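Concretely, the combination I had in mind looks roughly like this (the toy `f` is a placeholder; `Zygote.forwarddiff` pushes the wrapped closure through ForwardDiff while Zygote differentiates the rest of the program):

```julia
using Zygote

f(x) = x^2 + sin(x)      # toy stand-in for a cheap inner function

# Force the cheap inner reduction through ForwardDiff, so Zygote's
# slow mapreduce pullback is never invoked for it.
g(x) = Zygote.forwarddiff(x) do x
    sum(f, x)
end

Zygote.gradient(g, randn(100))
```

Whether this actually closes the gap on my model remains to be seen.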
To my knowledge the TF AD doesn’t support higher-order functions at all. If/when you get more info on good comparative examples, feel free to start a thread here and we can take a look at it.
Thanks for suggesting those alternative packages. I understand that there are many alternatives. I chose Zygote because it was supposed to be the future of AD, and I cannot afford to invest in something that will be replaced in a few months.
But then I just learned that Zygote will be replaced with something called Diffractor?
So what should I be investing in? Will I be relatively safe if I focus my development around ChainRules and use Zygote or any of these others only sparingly?
Thank you, that’s an important distinction I just hadn’t made, because I’m so used to higher-order functions in Julia that I don’t even think about them anymore. But would you say that avoiding them will automatically lead to faster Zygote gradients?
You shouldn’t have to invest. The AD systems (will) share a rule system, ChainRules.jl, so the modifications you make (i.e. added rules) would carry over to the other systems in a nice way.
I would say any function where a cheap-to-compute callback is called many, many times (e.g. sum(f, x)) doesn’t perform well right now relative to the forward pass. If f is expensive to compute, or the number of elements is small, it will be much less of an issue.
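A quick way to see that overhead factor on your own machine — assuming you have BenchmarkTools installed, and with a toy `cheap` standing in for your ns-scale callback:

```julia
using Zygote, BenchmarkTools

cheap(x) = x^2 + sin(x)   # ~ns-scale callback: pullback overhead dominates

xs = randn(10_000)
@btime sum(cheap, $xs)                            # forward pass
@btime Zygote.gradient(x -> sum(cheap, x), $xs)   # compare the slowdown factor
```

The ratio between the two timings is the overhead we’ve been discussing; with an expensive `f` it shrinks because the per-call work amortises the bookkeeping.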