Which autodiff package should I currently use for a neural network backend?

Dear community,

I’ve been trying out several of the autodiff packages out there and have run some tests and simulations on the type of problem I’m addressing: basically, neural networks with high-dimensional input and low-dimensional output.

For background, I’m trying to migrate parts of my code from PyTorch to Julia, hoping to gain the speed and expressive programming power that Julia may provide. :slight_smile:

So far the candidates I have in mind, along with my comments, are the following.

  • Knet’s gradient: Ran into speed problems when running on CPUs.
  • ReverseDiff: Couldn’t get it to work, and don’t understand why.
  • ForwardDiff: Works fine and easy to use, but is slower than Zygote for my problems.
  • Zygote: Works well and seems really quite fast. It beats PyTorch in differentiation speed, but the author warns that the package is not ready for use because it is still in its infancy.
  • Capstan: I really like the idea and believe it could be really fast and nice, but so far I haven’t seen much momentum in its development, even though Cassette has been released.

All in all, each of these packages is really cool and good at what it does in its own right. For my application, though, there may be experiences here in the community that I can learn from.

I need the package to work well on both GPUs and CPUs, and to be fast.

Any pointers or advice from others who have played around with this?

Missing from your list is Flux, which seems like the obvious answer here. I assume there will be an upgrade path to Zygote at some point, once Zygote is sufficiently mature.

I had the impression that ReverseDiff wasn’t working on Julia 0.7 yet, but I could be wrong or out of date.

Does Flux roll its own reverse-mode autodiff? If so, do you know how it performs compared to Zygote?

I am pretty much in the same situation with derivative-based MCMC. I am experimenting with Zygote and use ForwardDiff as a reliable fallback (not ideal for $\mathbb{R}^n \to \mathbb{R}$, but surprisingly effective nevertheless; larger chunk sizes seem to work better for me). I am aware that Zygote is a work in progress, but issues get fixed quickly, the maintainers seem committed, and I can always fall back to ForwardDiff.
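To make the chunk-size remark concrete: forward mode pushes a fixed number of directional derivatives (a "chunk" of dual-number partials) through the function per pass, so the full gradient of an $\mathbb{R}^n \to \mathbb{R}$ function takes roughly $\lceil n/k \rceil$ passes with chunk size $k$. Below is a toy sketch in Python, chosen only so it runs standalone; it is not ForwardDiff's implementation, and `Dual`, `chunked_gradient` and the test function `f` are illustrative names.

```python
import math


class Dual:
    """A value plus a tuple of partial derivatives (one 'chunk' of directions)."""

    def __init__(self, val, partials):
        self.val = val
        self.partials = tuple(partials)

    def _lift(self, other):
        # Treat plain numbers as constants with zero partials.
        if isinstance(other, Dual):
            return other
        return Dual(other, [0.0] * len(self.partials))

    def __add__(self, other):
        o = self._lift(other)
        return Dual(self.val + o.val,
                    [a + b for a, b in zip(self.partials, o.partials)])

    __radd__ = __add__

    def __mul__(self, other):
        o = self._lift(other)
        # Product rule applied to each seeded direction: d(uv) = u dv + v du.
        return Dual(self.val * o.val,
                    [self.val * b + o.val * a
                     for a, b in zip(self.partials, o.partials)])

    __rmul__ = __mul__


def sin(x):
    """sin with its derivative rule for Dual inputs."""
    if isinstance(x, Dual):
        return Dual(math.sin(x.val), [math.cos(x.val) * p for p in x.partials])
    return math.sin(x)


def chunked_gradient(f, x, chunk):
    """Gradient of f: R^n -> R via forward mode, seeding `chunk` inputs per pass."""
    n = len(x)
    grad = [0.0] * n
    passes = 0
    for start in range(0, n, chunk):
        width = min(chunk, n - start)  # the last chunk may be narrower
        duals = []
        for i, xi in enumerate(x):
            seed = [0.0] * width
            if start <= i < start + width:
                seed[i - start] = 1.0  # unit direction for input i
            duals.append(Dual(xi, seed))
        y = f(duals)
        grad[start:start + width] = y.partials
        passes += 1
    return grad, passes


def f(xs):
    # f(x) = sin(x_0) + sum_i x_i^2
    total = sin(xs[0])
    for v in xs:
        total = total + v * v
    return total


x = [0.5, 1.0, -2.0, 3.0, 0.25]
grad, passes = chunked_gradient(f, x, chunk=2)
# Expected gradient: [cos(0.5) + 1.0, 2.0, -4.0, 6.0, 0.5], computed in 3 passes.
```

With `chunk=5` the same gradient comes out of a single pass that carries five partials. Fewer, wider passes is one plausible reason larger chunks can help, though the actual speedup depends on the function and on the package's implementation.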

Note that MikeInnes is the author of both Flux and Zygote (which lives in the FluxML organization).
So it may be reasonable to stick with Flux; once Zygote is ready, making the switch ought to be seamless.

Also worth pointing out: because Zygote’s API is so simple, you don’t have to dedicate much code to building around it (e.g., there is no need to manage compiled tapes).

Re Flux: yes, it implements its own reverse-mode AD. For my purposes (which are n-to-1), this has proved faster than the alternatives you listed.
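The reason reverse mode tends to win for n-to-1 problems is that one backward sweep over the recorded operations yields the derivative of the scalar output with respect to every input at once, however many inputs there are. Here is a minimal reverse-mode sketch in Python that records a computation graph and then propagates adjoints; it is not Flux's internals, and `Var`, `backward` and the toy model `neuron` are hypothetical names used only for illustration.

```python
import math


class Var:
    """Graph node: a value, its parents with local derivatives, and an adjoint."""

    def __init__(self, val, parents=()):
        self.val = val
        self.parents = parents  # tuple of (parent Var, d self / d parent)
        self.grad = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.val + other.val, ((self, 1.0), (other, 1.0)))

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.val * other.val, ((self, other.val), (other, self.val)))

    __rmul__ = __mul__


def tanh(x):
    t = math.tanh(x.val)
    return Var(t, ((x, 1.0 - t * t),))


def backward(out):
    """Single reverse sweep: order nodes topologically, then propagate adjoints."""
    order, seen = [], set()

    def visit(v):
        if id(v) in seen:
            return
        seen.add(id(v))
        for parent, _ in v.parents:
            visit(parent)
        order.append(v)  # appended only after all of its inputs

    visit(out)
    out.grad = 1.0
    for v in reversed(order):  # every consumer is handled before its inputs
        for parent, local in v.parents:
            parent.grad += local * v.grad


def neuron(ws, xs):
    # Hypothetical n-to-1 model: y = tanh(sum_i w_i * x_i)
    s = Var(0.0)
    for w, x in zip(ws, xs):
        s = s + w * x
    return tanh(s)


xs = [0.1, -0.2, 0.3, 0.05]
ws = [Var(0.5), Var(-1.0), Var(0.25), Var(2.0)]
y = neuron(ws, xs)
backward(y)  # one sweep, regardless of len(ws)
grads = [w.grad for w in ws]
```

Each entry of `grads` equals $(1 - \tanh^2(s))\,x_i$ with $s = \sum_i w_i x_i$. A forward-mode approach would instead need one pass per weight (or per chunk of weights) to get the same vector.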

I have barely played with Zygote, but as pointed out above, it’s written by the same Mike and shares some syntax, e.g. for defining custom gradients.

You may want to check out


and

I have done many tests myself and found ReverseDiff.jl to be the most versatile package for reverse-mode differentiation. It’s one of the few that supports my use case, which involves nested differentiation.

AFAIK Nabla.jl does not (yet?) work on 0.7/1.0.

Thanks everyone for your suggestions and pointers. :slight_smile:

There is also AutoGrad.jl (part of Knet.jl).

Just a follow-up to this topic: I am experimenting with incorporating many of the above tools into a common framework to handle AD for the gradient (for use in Bayesian inference, primarily).

I found that

  1. When Zygote.jl works, it is amazing. But it is experimental, and the usual caveats apply.

  2. Flux.jl is pretty good for reverse-mode AD. But when it breaks, I find debugging difficult.

  3. ReverseDiff.jl is pretty reliable, except when it is missing AD rules for some methods; these then need to be defined by hand.

  4. In general, in the discussions around some of the AD packages, there is a sentiment that great AD tools are just around the corner and that maintaining or developing the existing ones is perhaps wasted effort (e.g. #81 in Nabla.jl, ReverseDiff.jl’s README, and some others I won’t link here). While this is understandable, and projects based on e.g. Cassette.jl are indeed very promising, I think there would be value in keeping the existing tools working during the transition period, which may take many months, if not years.
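On point 3: AD tools differentiate a program by composing derivative rules for its primitive operations, so a call with no registered rule stops the process until someone supplies one. A toy Python sketch of that idea, restricted to chains of scalar functions; the rule table `RULES`, `MissingRuleError`, `derivative_of_chain` and the example `softplus` are hypothetical and do not reflect ReverseDiff's actual mechanism for defining rules.

```python
import math

# Hypothetical rule table: primitive -> function returning its derivative at x.
RULES = {
    math.sin: math.cos,
    math.exp: math.exp,
}


class MissingRuleError(Exception):
    """Raised when the toy AD meets a primitive with no registered rule."""


def derivative_of_chain(fns, x):
    """Value and d/dx of fns[-1](...fns[0](x)...) by the chain rule, using RULES only."""
    deriv = 1.0
    for fn in fns:
        if fn not in RULES:
            raise MissingRuleError(f"no AD rule defined for {fn.__name__}")
        deriv *= RULES[fn](x)  # local derivative at the current intermediate value
        x = fn(x)
    return x, deriv


def softplus(x):
    # A user function the toy AD knows nothing about yet.
    return math.log1p(math.exp(x))


missing_message = None
try:
    derivative_of_chain([math.sin, softplus], 0.0)
except MissingRuleError as err:
    missing_message = str(err)  # "no AD rule defined for softplus"

# Supply the missing rule by hand: d/dx log(1 + e^x) = 1 / (1 + e^(-x)).
RULES[softplus] = lambda x: 1.0 / (1.0 + math.exp(-x))
value, deriv = derivative_of_chain([math.sin, softplus], 0.0)
```

After the rule is registered, the composite `softplus(sin(x))` differentiates at `x = 0` to `cos(0) * 0.5 = 0.5`. Real packages apply the same principle to whole programs, which is why a single missing rule can block an otherwise working model.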
