Automatic Differentiation (AD) in Julia vs. Python (or PyTorch)

(I found that there were related discussions already:

But they were at least four years old and do not answer my specific questions, so I am creating a new topic.)

One of my collaborators, who is familiar with both Julia and Python, shared with me his experience that it was much easier to build AD-compatible code in Julia than in Python. I had written a Julia package without AD compatibility in mind (actually without even knowing what AD is), but he found that my package was already AD-compatible and was able to use it in his code that requires AD. On the other hand, his Python package, also developed without AD compatibility in mind, took a significant amount of time to make AD-compatible (more than the time taken to develop the original code itself). His experience gave him the impression that Julia code requires almost no effort to be AD-compatible, whereas Python code requires a significant effort.

When I told this to another collaborator, who is a PyTorch proponent, she disagreed with the observation. (Maybe the difference comes from the first collaborator using plain Python and the second using PyTorch?) She has experience building a simple AD framework herself, and she says that any AD framework relies on chain rules and a lookup table mapping elementary functions to their derivatives (sin to cos, x to 1, etc.). She argues that Julia’s AD frameworks must rely on the same machinery, and that therefore there can’t be any Julia-specific advantage over Python.

I don’t know much about how AD is implemented in general, and I was not able to find a document that compares AD in Julia and PyTorch. So I asked ChatGPT what it thinks, and the answer seems quite favorable to Julia, which is encouraging to me because I want to promote Julia in my team. But I am wondering how accurate these statements are. I would appreciate insights from experts!

Here goes the answer from ChatGPT:


Short Answer

Yes, Julia’s AD libraries do rely on a table of derivatives for built-in functions (\sin, \exp, \log, etc.), just like PyTorch or any other AD system. However, Julia’s advantage lies in how its AD can automatically transform and analyze all of your Julia code—loops, conditionals, user-defined types—at the compiler or IR level. By contrast, Python-based AD often revolves around tracing Python calls (or bridging to C/C++ extensions), which can be less transparent and more limiting.


1. The Same “Derivative Table,” but That’s Just the Start

  • Elementary Function Derivatives:
    Like any AD system, both Julia’s AD and PyTorch have rules for \sin'(x) = \cos(x),\ \exp'(x) = \exp(x), \ldots.
  • The Real Difference:
    These rules are only a small part of the puzzle. The rest hinges on how the AD system processes loops, conditionals, and custom library code that compose these elementary operations.

2. Julia’s AD Operates at the Language/Compiler Level

  • Source-to-Source / IR Transformations:
    Julia tools like Zygote or Enzyme transform Julia’s internal IR or use LLVM-level analysis. They see and rewrite your entire function—including loops and branching—rather than building a separate “graph” or tape at runtime.
  • No Special Primitives Needed:
    This means ordinary for loops and conditionals in Julia are differentiable without rewriting them into specialized ops like lax.fori_loop (JAX) or tf.while_loop (TensorFlow).
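
    For instance, here is a minimal sketch (assuming Zygote.jl is installed) of differentiating an ordinary Julia function that contains a loop and a branch, with no special AD primitives:

using Zygote

# Ordinary Julia control flow: a loop and a branch, no special AD ops.
function mixed_sum(x)
    s = zero(eltype(x))
    for xi in x
        if xi > 0
            s += sin(xi)
        else
            s += xi^2
        end
    end
    return s
end

gradient(mixed_sum, [0.5, -1.0, 2.0])  # ([cos(0.5), -2.0, cos(2.0)],)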

3. Libraries Written in Julia, Not C/C++

  • Fewer Black Boxes:
    Many Python numeric libraries are C/C++ extensions that AD cannot easily inspect unless custom backward passes exist. In Julia, packages are predominantly written in Julia.
  • Uniform Visibility:
    Because AD sees all-Julia code down to the lowest levels, it can differentiate advanced library routines without hitting opaque boundaries.

4. Seamless Integration vs. Runtime Tracing

  • Tracing Overhead:
    PyTorch (and others) often trace operations on tensors at runtime to build a computational graph, which can become awkward for complex control flow.
  • Julia’s “Native” Approach:
    Directly transforms your code at compile-time or IR-level, so advanced patterns (recursion, custom data structures, etc.) work out of the box. There’s less need for rewriting or special AD-friendly APIs.

Conclusion

Yes, both Julia and PyTorch rely on a derivative “lookup table” for elementary math functions. But Julia’s AD also benefits from:

  1. Source/IR-level integration (rather than pure runtime tracing),
  2. Uniform, all-in-Julia libraries (minimizing black-box native extensions), and
  3. Natural handling of loops and conditionals (no special APIs needed).

All of this makes Julia’s AD feel more native and flexible, providing a straightforward path to differentiating any code you write in Julia—beyond just the basic math functions.

4 Likes

In my opinion the key differentiator is multiple dispatch, more so than anything else.

  • In Python, you cannot pass a PyTorch Tensor to a JAX function. At an abstract level, this is a consequence of Python being a single-dispatch language.
  • In Julia, you can use the same sum function over any array type you want – normal arrays, sparse arrays, CUDA arrays, Metal arrays, distributed arrays, fill arrays, lazy arrays, static arrays, block arrays, masked arrays, etc., etc. And it will call the library’s implementation of sum, rather than a generic unoptimized one. This is why your code will just work: you won’t have to worry about calling custom methods, like using jnp.sum instead of sum – the correct library-specific method will always be routed to.

This is why a lot of ADs just work out-of-the-box in Julia, but not in Python. Julia’s design, at its core, is very modular and makes it easier for libraries to fit together. Your codebase can use the generic functions, and those will be routed through to the library-specific (or AD-specific) method.
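
As a rough illustration (assuming ForwardDiff.jl and SparseArrays are available), a single untyped definition is dispatched to the right sum for each array type, and AD number types flow through it unchanged:

using ForwardDiff, SparseArrays

# One generic definition -- no jnp.sum vs. torch.sum split.
sumsq(x) = sum(abs2, x)

sumsq(rand(3))                        # dense Vector{Float64}
sumsq(sprand(10, 0.3))                # SparseVector: the sparse-aware methods are used
ForwardDiff.gradient(sumsq, rand(3))  # same definition, now running on Dual numbers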

I made this comparison to help Python people “get it”:

This is also what makes it hard for Python libraries to be compatible with each other, and as a result it is why PyTorch and JAX are wholly incompatible. (And also why Python AD doesn’t work out of the box!)

6 Likes

In Python, nearly all performance-critical functions need to be implemented via call-outs to libraries written in other languages, which AD tools can’t typically handle. So, your AD system needs explicit rules for every supported performance-sensitive library function, and often requires custom implementations of those functions in order to track the necessary state for backpropagation. That’s why JAX basically “re-implemented the universe” (it has its own numpy, its own scipy, its own ODE solver, …). This is practical if you’re Google, or if you only want to support a very small set of building blocks (e.g. standard neural-net components).

In Julia, while there are some important libraries in other languages that we exploit (e.g. LAPACK), much more software is written in Julia itself, which makes it more practical to compose generic AD tools like Enzyme or ForwardDiff with arbitrary Julia packages. There are still cases where custom chain rules are necessary or desirable, but they are a lot fewer.

That’s because the sum function is implemented in Julia. Both the built-in Python sum and numpy.sum are implemented in C — this is less about multiple dispatch and more about whether the language’s semantics allow high-performance abstractions. (e.g. the element-type uniformity of numpy arrays cannot be readily expressed in pure Python.) You can write a generic sum function in Python for any iterable container, but it is terribly slow.
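
To make that concrete, here is a sketch of a hand-written, fully generic sum; because it is plain Julia all the way down, ForwardDiff can push its Dual numbers straight through it, and it still compiles to fast machine code (unlike the pure-Python equivalent):

using ForwardDiff

# A generic reduction over any container of any number type.
function mysum(xs)
    s = zero(first(xs))
    for x in xs
        s += x
    end
    return s
end

ForwardDiff.gradient(mysum, [1.0, 2.0, 3.0])  # [1.0, 1.0, 1.0]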

6 Likes

Your code will likely work if it follows good practices. Your code generally needs to be composable in order to AD well. You can break composability with overly narrow typing and hard-coded array indexing, for example:

function addone!(x::Array{Float64,1})  # overly narrow: only accepts 1-D Float64 arrays
    for i in 1:length(x)               # assumes 1-based indexing
        x[i] = x[i] + 1
    end
end

Some packages want to call your supplied function with other types, such as dual numbers, and Float64 is too restrictive for that kind of AD. Some people want to call your function with arbitrarily offset arrays, e.g. indexing x[-5:5], and an assumption of 1-based indexing can cause problems elsewhere, even if it differentiates well. Also, why restrict to one-dimensional arrays? Might as well drop most assumptions, including some we might make in other languages “for safety.”
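
Here is a sketch of a more generic version that drops those assumptions (it still mutates its argument, which some reverse-mode tools such as Zygote do not support, but it works with ForwardDiff dual numbers, offset arrays, and arrays of any dimension):

function addone!(x::AbstractArray)  # any element type, dimensionality, and indexing scheme
    for i in eachindex(x)
        x[i] += 1
    end
    return x
end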

I’m not sure if there is a definitive guide to writing code that plays well with others, but code that is generic and type-stable can often be differentiated with no modification. There are edge cases to watch out for, and sometimes one AD package will work better for a particular problem than others.

In Python, you might get away with import jax.numpy as np and differentiating existing code, but it’s not good coding practice and makes many assumptions. For AD to work, you may have to pray to many gods. In Julia, it is often more than sufficient to toss some salt over your shoulder.

As a mere amateur, I have picked up a few pointers from Discourse and have enjoyed reliable AD in Julia without much thought.

1 Like

I tried to get such a guide started here (focused on playing well with AD packages): Differentiability · DifferentiationInterface.jl

Additions are welcome in the form of PRs!

3 Likes

ForwardDiff.jl cannot handle narrow typing (it needs to pass its own Dual number type through your function), but Zygote.jl and Enzyme.jl can. A contrived example:

julia> using Zygote

julia> f(x::Float32) = 2x

julia> f(x::Float64) = 3x

julia> f'(1f0) # Float32
2.0f0

julia> f'(1.0) # Float64
3.0

2 Likes

Speaking from my experience, it was always faster to get code differentiable in PyTorch than in Julia.

Yes, you restrict yourself to PyTorch code, but at least all PyTorch functions are differentiable. I guess there are also more PyTorch users than Julia users, hence a lot of things are already implemented in PyTorch.

Some years ago, you also had to write Zygote-flavored Julia code to get AD working. Now it is becoming a bit easier with DifferentiationInterface.jl + Enzyme.jl.
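
For example, a minimal sketch with DifferentiationInterface.jl driving Enzyme as the backend (assuming both packages are installed):

using DifferentiationInterface
import Enzyme

f(x) = sum(abs2, x)

# Pick the backend once; the rest of the code is backend-agnostic.
backend = AutoEnzyme()
gradient(f, backend, [1.0, 2.0, 3.0])  # [2.0, 4.0, 6.0]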

@roflmaostc, thanks for sharing your experience!

You say that as long as we use PyTorch functions, the resulting PyTorch code is differentiable. I am wondering, though, whether there are situations where it is difficult or overly complex to write code using only PyTorch functions.

For example, back in the days when I was coding in MATLAB, I had to vectorize everything and avoid loops in order to achieve good performance. This was fine (and honestly even fun, devising clever ways to vectorize things) until I had to implement a code that needed to apply different operations to the elements of an array depending on the properties of the individual elements. In that code I had to use a for loop, and the performance of the MATLAB code was unacceptably slow. I ported the code to Julia and observed an orders-of-magnitude speedup (from an hour to <1 sec). It was the experience that made me switch completely to Julia. I am wondering whether coding using only PyTorch functions might have similar limitations.

Yes, some code cannot be implemented efficiently in the vectorized style.

In my case, it was the Radon transform. I implemented it in Julia, but until Enzyme there was nothing that could differentiate this code efficiently in reverse mode. Zygote was not able to handle it.

Recently, I tested it with Enzyme, and it was within a performance factor of ~2-5 compared to my handwritten version.
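
For reference, a minimal sketch of what calling Enzyme in reverse mode on loop-heavy scalar code looks like (not the actual Radon transform, just the calling pattern):

using Enzyme

# Scalar loop code of the kind reverse-mode Zygote struggles with.
function sqnorm(x)
    s = zero(eltype(x))
    for xi in x
        s += xi * xi
    end
    return s
end

x  = [1.0, 2.0, 3.0]
dx = zero(x)  # gradient accumulator
Enzyme.autodiff(Reverse, sqnorm, Active, Duplicated(x, dx))
dx            # now holds 2x = [2.0, 4.0, 6.0]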

But in Julia, you will find many real-world examples where automatic differentiation fails because no one has defined a rule for a specific function. In PyTorch, I personally never encountered such a case.