How efficient is automatic differentiation and optimization in Julia compared to PyTorch?

Hi all,

How efficient is the current Julia AD and optimization landscape in comparison to, let’s say, PyTorch?


For example in PyTorch I can do

arr1 = arr1.requires_grad_(True)
arr2 = arr2.requires_grad_(True)
optimizer = torch.optim.Adam([arr1, arr2])
loss = f(arr1, arr2)

# only with respect to arr1
arr1 = arr1.requires_grad_(True)
arr2 = arr2.requires_grad_(False)
optimizer = torch.optim.Adam([arr1])
loss = f(arr1, arr2)

I guess PyTorch stores the computational graph for arr1 and arr2, and loss.backward() skips the unnecessary computation for arr2 in the second case.


How do I achieve a similar behavior in Julia?
I always used ComponentArrays, Zygote and Optim.jl for such problems.

But Optim.jl only accepts one array, so I had to do:

# full
arr = ComponentArray(arr1=arr1, arr2=arr2)
loss = f(arr)
gradient(f, arr)
## put it in Optim

# only with respect to arr1 
# manual closure so that arr2 is not an optimizable argument
f2(arr1) = f(arr1, arr2)
gradient(f2, arr1)
## put it in optim

But does Zygote avoid the unnecessary AD computations for arr2? Also, how do I easily exclude some arrays from optimization as in PyTorch? Is there a native way in Julia? Optim.jl always expects a single array, so I feel like it does not work there. Is it not possible to provide an easy construct like the one in PyTorch?

This is more of an open question, but it has always kept me from starting to develop a toolbox for my field, since I don’t know how to solve this challenge nicely. Maybe I can learn from reading some Lux.jl source code? I’m especially worried about expensive unnecessary calculations in the gradient pass.

Any hints would be appreciated :slight_smile:



If you make the x -> f(x, constant) closure then the AD engine will take a derivative only with respect to the first argument (there is no derivative with respect to the second argument calculated and thrown away).
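A minimal sketch of that closure pattern with Zygote. The function f and the array values are made up for illustration and are not from the question:

```julia
using Zygote

# Hypothetical loss, for illustration only.
f(a, b) = sum(a .* b .^ 2)

a = [1.0, 2.0]
b = [3.0, 4.0]

# Differentiate with respect to both arguments: returns a 2-tuple.
ga_full, gb_full = gradient(f, a, b)

# Close over b: gradient now returns a 1-tuple, for the free argument only.
(ga,) = gradient(x -> f(x, b), a)

ga == ga_full   # the gradient with respect to a is b .^ 2 either way
```

The closure can be created wherever it is needed, so which arguments are held constant can change from one call to the next without touching f or the arrays.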

If I understand correctly, you also have a UX question. You do not like the style of declaring a closure like this and you prefer the PyTorch style of annotating the argument itself. I personally disagree and believe PyTorch’s style is rather cumbersome and more of an implementation detail leaking out. My reasoning is that requiring your arguments to have a requires_grad_ method makes it difficult to differentiate through arbitrary types. It is also a less natural way to write things, compared to how it would be written in a math textbook. Lastly, it forces you to decide in advance which arguments will be fixed, while you might want to have something more dynamic; the PyTorch code for something like that would be quite stateful.

Nonetheless, you can totally do it that way as well.

  • If you are using ForwardDiff, you can make sure to use Dual types only for the arguments you derive with respect to.
  • You can build a cute little wrapper type struct RequiresDerivative x end and a derivative function such that you can write derivative(f, arg1, arg2, arg3) that checks isa(arg, RequiresDerivative) and builds up a closure where only these arguments are left free.
  • It can be a macro instead of a function if you want a neater domain specific language.
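For the ForwardDiff route in the first bullet, here is a rough sketch (f, the arrays, and the eltypes log are made up for illustration). ForwardDiff.gradient only replaces the array it is handed with Dual numbers, so anything captured by the closure keeps its plain element type:

```julia
using ForwardDiff

# Hypothetical loss, for illustration only.
f(a, b) = sum(a .* b .^ 2)

a = [1.0, 2.0]
b = [3.0, 4.0]

eltypes = Any[]   # illustrative log of the element types seen during the call

ga = ForwardDiff.gradient(a) do x
    push!(eltypes, eltype(x))   # a ForwardDiff.Dual type: x is being differentiated
    push!(eltypes, eltype(b))   # Float64: b is captured as a plain constant
    f(x, b)
end
```

Note that forward mode scales with the number of inputs, so this is mainly attractive for small parameter arrays rather than large ones.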

Lastly, consider how in Flux it is pretty straightforward to just take derivatives with respect to any global variable, no matter how it shows up in the expression. That is probably even neater than what you describe. But I think one of the reasons Lux was created was because people did not like that part of the Flux API: it was super neat for demos, but it was unnecessarily magical for large engineering projects.

Edit: it seems the Flux folks also have opinions on the UX for taking derivatives. Check the comment at the very bottom of this page, Quick Start · Flux, about what the API looked like before version 0.14 and what it will look like in 0.15.

Edit: I think I am still missing the situation in which PyTorch is much more comfortable to use than the various Julia AD engines in terms of UX. If you could elaborate on that and show more complete snippets of unpleasant Julia code we might be better able to suggest more pleasant to use Julia constructs.


Unless you need optimizers which PyTorch doesn’t have, save yourself a headache and use Optimisers.jl ;). It’s the real equivalent to torch.optim and supports (nested) structures:

# full
# setup() optimizer with (arr1, arr2)
loss = f(arr1, arr2)
gradient(f, arr1, arr2)
# update! with (arr1, arr2) and optimizer

# only with respect to arr1
# setup() optimizer with just arr1
gradient(x -> f(x, arr2), arr1) # reusing f
# update! with just arr1 and optimizer
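Spelled out, the outline above could look roughly like this. It is a sketch assuming Optimisers.jl with Zygote for the gradients; the loss f, the array values, and the learning rate are placeholders chosen for illustration:

```julia
using Optimisers, Zygote

# Hypothetical loss, for illustration only: smallest when a .* b == 1.
f(a, b) = sum(abs2, a .* b .- 1)

arr1 = [1.0, 2.0]
arr2 = [3.0, 4.0]

# Full: one optimizer state tree for the tuple (arr1, arr2).
state = Optimisers.setup(Optimisers.Adam(0.01), (arr1, arr2))
grads = gradient(f, arr1, arr2)                  # tuple (g1, g2)
state, (arr1, arr2) = Optimisers.update!(state, (arr1, arr2), grads)

# Only with respect to arr1: set up and update arr1 alone, reusing f.
arr2_before = copy(arr2)
state1 = Optimisers.setup(Optimisers.Adam(0.01), arr1)
(g1,) = gradient(x -> f(x, arr2), arr1)
state1, arr1 = Optimisers.update!(state1, arr1, g1)
```

In the second part arr2 never enters setup or update!, so it cannot be modified by the optimizer.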

To the more general question. The currently available Julia ADs can be much, much faster than PyTorch if you fit 2/3 of the following criteria (depends on the AD):

  • Mostly work with scalars
  • Don’t need GPU support
  • Aren’t working with NNs

But since you’re currently using PyTorch, you probably fit 0/3.

As for whether Zygote avoids the unnecessary AD computations for arr2: sometimes. It’s probably the worst when it comes to avoiding unnecessary computation, more on that later.

To exclude arrays from differentiation the way PyTorch does, use an overloading-based AD like ForwardDiff, ReverseDiff, AutoGrad.jl (denizyuret/AutoGrad.jl, a Julia port of the Python autograd package), or Tracker.jl (FluxML/Tracker.jl, Flux’s ex AD). No, seriously. As @Krastanov noted, using the AD array/parameter type only for actual parameters is basically what PyTorch does. Depending on your appetite for bleeding-edge software, Enzyme would work as well even though it’s a source-to-source AD.

The natural follow-up question may be, “why doesn’t Zygote do this?” Well, doing this as a static compiler analysis is difficult and requires many smarts. Zygote’s source code transform is quite “dumb” by comparison. The best way to stop gradients from being calculated for certain code is to wrap it in API · ChainRules, but that requires changes to the actual functions being called. If you can’t do that, see @Krastanov’s suggestion or my alternate AD list above.
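One ChainRules tool that does this is ChainRulesCore.ignore_derivatives; whether it is the exact function the link above points to is an assumption on my part. A small sketch with a made-up loss:

```julia
using Zygote
using ChainRulesCore: ignore_derivatives

# Hypothetical loss; b is marked as a constant inside the function itself.
f(a, b) = sum(a .* ignore_derivatives(b) .^ 2)

a = [1.0, 2.0]
b = [3.0, 4.0]

ga, gb = gradient(f, a, b)
# ga is b .^ 2; gb is `nothing` because no gradient flows back through b.
```

The downside mentioned above is visible here: the decision is baked into f, so the same function can no longer be differentiated with respect to b without editing it.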

Small correction: this is what it looks like as of 0.14, right this moment. The old alternative path is deprecated (we still need to add the actual warnings) and will completely disappear in the next major release. If you’re using it now and read this post, consider it a pre-PSA.


Ah great, didn’t know that.
But I quite often use L-BFGS (which also exists in PyTorch).

But what about Optimization.jl? I thought it might take over the optimization landscape with a unified interface. But it requires the rosenbrock(x, p) style, with argument x and parameters p?

Yeah I require reverse AD, I usually optimize on big CuArrays.

Actually, I’m more a fan of closures too :slight_smile:

But, from time to time I hit julia#15276 which is really a pain.
It’s not 100% clear to me in which situations it occurs, so I always stick with

f = let a=a, b=b
    function f(c)
        a .* b .+ c
    end
end

But let’s say we have a code like this:

f(xx::ComponentArray) = xx.a .* xx.b .+ xx.c

We could probably do a macro:

f_new = @optimize_only c begin
     f(xx::ComponentArray) = xx.a .* xx.b .+ xx.c
end

which transforms it to:

f_new = let a=xx.a, b=xx.b
    f_new(xx) = a .* b .+ xx.c
end

You see the issue here? I think I need to use ComponentArrays for Optim.jl and Optimization.jl. But otherwise, the constant arrays would get optimized too.

Bigger example for Zygote

Let’s say we have a function like:

using FFTW  # provides fft and ifft
function conv(a, b)
    real.(ifft(fft(a) .* fft(b)))
end

and I define:

a = randn((100, 100))
b = randn((100, 100))

g(a) = conv(a, b)
gradient(g, a)

Does Zygote avoid calculating the gradient for b? If not, that decreases performance quite a bit. And there are more expensive cases, of course.

This is where I’m really worried, whereas in PyTorch this is handled correctly and efficiently.

Right, so L-BFGS is I think one of the few notable optimizers which is in PyTorch but not Optimisers.jl. I think nobody has added it so far because most people are training moderate-large sized NNs, but it should be a single PR away.

That’s precisely because it’s trying to have a uniform interface. Since libraries such as Optim.jl and ForwardDiff only accept arrays as parameters and not more complex structures, you’re stuck with that as a lowest common denominator.
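One way to read the (x, p) convention for the use case in this thread: x holds what gets optimized and p can carry the arrays that stay constant. A rough sketch, assuming Optimization.jl with the OptimizationOptimJL wrapper and Zygote as the AD backend (the loss and values are made up, and the exact names may differ between package versions):

```julia
using Optimization, OptimizationOptimJL, Zygote

# Hypothetical loss in the (x, p) style: x is optimized, p is held fixed.
loss(x, p) = sum(abs2, x .* p .- 1)

x0 = [1.0, 2.0]   # plays the role of arr1 (optimized)
p  = [3.0, 4.0]   # plays the role of arr2 (constant)

optf = OptimizationFunction(loss, Optimization.AutoZygote())
prob = OptimizationProblem(optf, x0, p)

# L-BFGS from Optim.jl, fully qualified to avoid clashes with other exported names.
sol = solve(prob, OptimizationOptimJL.Optim.LBFGS())

sol.u   # should approach 1 ./ p
```

Gradients are only requested with respect to x, so p is never touched by the optimizer. Whether the AD backend also skips all work related to p is a separate question, as discussed above.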

I’m not sure, but I don’t think it can avoid the calculation.

See my earlier comment about how other Julia ADs can handle this.