[ANN] DifferentiationInterface - gradients for everyone

Julia’s composability has rather interesting consequences for its automatic differentiation ecosystem. Whereas Python programmers first choose a backend (like PyTorch or JAX) and then write code that is specifically tailored to it, Julians first write their code and then make it differentiable for one or more of the many available backends (like ForwardDiff.jl or Zygote.jl). Forward and reverse mode, numeric and symbolic, Julia has it all. But it’s not always obvious which option is best suited for a given application.

We hold these truths to be self-evident, that all autodiff backends are created equal, and that Julia users should be able to quickly try them out without diving into a dozen different APIs to figure out correct syntax and performance tricks. @hill and I are thus proud to present DifferentiationInterface.jl, an efficient and flexible interface to every differentiation method you could possibly dream of.

Basics

With DifferentiationInterface, an elementary gradient computation looks like this:

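```julia
using DifferentiationInterface
import ForwardDiff  # loading the backend package activates its extension

backend = AutoForwardDiff()  # backend object defined by ADTypes.jl

f(x) = sum(abs2, x)
x = rand(10)

g = gradient(f, backend, x)  # equals 2 .* x
```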

Backend objects are defined by ADTypes.jl and we support a dozen of them, from golden oldies ForwardDiff.jl and Zygote.jl to new fan favorites like Enzyme.jl… and even experimental ones like FastDifferentiation.jl and Tapir.jl!

We provide operators for first-order and second-order differentiation: derivative, gradient, jacobian, second_derivative, hessian, as well as the lower-level pushforward, pullback and hvp. These operators support only numbers or arrays as inputs and outputs. They use nearly optimal backend-specific implementations whenever possible, and reasonable fallbacks the rest of the time.
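To give a feel for how the same backend object plugs into the different operators, here is a small sketch with ForwardDiff.jl covering scalar and array inputs. It sticks to the high-level operators named above and leaves out `pushforward`, `pullback` and `hvp` for brevity; the test functions are just illustrative choices.

```julia
using DifferentiationInterface
import ForwardDiff

backend = AutoForwardDiff()

# Scalar input, scalar output
d1 = derivative(sin, backend, 0.0)          # cos(0) = 1
d2 = second_derivative(sin, backend, 0.0)   # -sin(0) = 0

# Array input, array output
h(x) = 2 .* x
J = jacobian(h, backend, [1.0, 2.0])        # Jacobian of x -> 2x is 2I

# Array input, scalar output
f(x) = sum(abs2, x)
H = hessian(f, backend, [1.0, 2.0])         # Hessian of sum(abs2, x) is 2I
```

Swapping `AutoForwardDiff()` for another backend object should be the only change needed to try a different package.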

Check out the tutorial for a simple example, the overview for a detailed walkthrough, and the list of backends for a taste of the endless possibilities.

Performance

The Julia community loves optimizing for maximum performance, and would never give up speed for convenience. That is why DifferentiationInterface was designed from the ground up with type stability and memory efficiency in mind. Users can very easily:

  • choose between function signatures f(x) = y and f!(y, x)
  • perform differentiation out-of-place (allocating a new gradient) or in-place (mutating the gradient)
  • prepare a config / tape / cache / etc. when they plan several differentiation calls in a row
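The three options above can be sketched as follows with ForwardDiff.jl. One assumption to flag: the preparation step follows the v0.3 convention, where the object returned by `prepare_gradient` (the "extras") is passed as the last positional argument; check the docs of the version you have installed, since this part of the API may evolve before 1.0.

```julia
using DifferentiationInterface
import ForwardDiff

backend = AutoForwardDiff()

# Signature f(x) = y
f(x) = sum(abs2, x)
x = rand(10)

# Out-of-place: allocates and returns a new gradient
g = gradient(f, backend, x)

# In-place: overwrites a preallocated gradient buffer
grad = similar(x)
gradient!(f, grad, backend, x)

# Preparation (v0.3 convention assumed): build the extras once,
# then reuse them for several calls with inputs of the same size and type
extras = prepare_gradient(f, backend, x)
for _ in 1:5
    x_new = rand(10)
    gradient!(f, grad, backend, x_new, extras)
end
gradient!(f, grad, backend, x, extras)

# Signature f!(y, x): mutating function, output buffer passed explicitly
function f!(y, x)
    y .= 3 .* x
    return nothing
end
y = zeros(10)
J = jacobian(f!, y, backend, x)
```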

Advanced

Among the fancier features offered by DifferentiationInterface, let us mention:

  • sparse Jacobians and Hessians with arbitrary backends (a generic and clean reimplementation of SparseDiffTools.jl, albeit still partial)
  • flexible backend combinations for second order, with efficient Hessian-vector products
  • DifferentiateWith(f, backend), which defines a chain rule so that Zygote.jl uses another backend (like Enzyme.jl) for part of the reverse pass
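Here is a rough sketch of two of these features. The mutating function is a toy chosen because Zygote cannot differentiate array mutation on its own, so wrapping it in `DifferentiateWith` hands that piece to ForwardDiff while Zygote handles the rest of the reverse pass. The second part builds a forward-over-reverse combination with `SecondOrder(outer, inner)`; the argument order and the choice of ForwardDiff over Zygote are my reading of the v0.3 docs, so double-check them before relying on this.

```julia
using DifferentiationInterface
import ForwardDiff, Zygote

# A function with mutation that Zygote alone would reject
function f_mutating(x)
    a = Vector{eltype(x)}(undef, 1)
    a[1] = sum(x)
    return only(a)
end

# Zygote drives the outer reverse pass; the rule for f_mutating uses ForwardDiff
dw = DifferentiateWith(f_mutating, AutoForwardDiff())
g = Zygote.gradient(x -> 2 * dw(x), [3.0, 5.0])[1]

# Second order: ForwardDiff (outer) over Zygote (inner)
sq(x) = sum(abs2, x)
backend2 = SecondOrder(AutoForwardDiff(), AutoZygote())
H = hessian(sq, backend2, [1.0, 2.0])
```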

Testing and benchmarking

The companion package DifferentiationInterfaceTest.jl allows you to quickly compare autodiff backends, on pre-defined or custom scenarios. It includes:

  • correctness tests, either against a ground truth or against a reference backend
  • type stability tests
  • detailed benchmarks
  • scenarios with weird array types (JLArrays.jl for GPU, StaticArrays.jl, ComponentArrays.jl)

We think this can be of great use for package developers to test which backends support their functions, and how fast differentiation can be. It has also enabled us to help diagnose and fix bugs in several autodiff packages.
Once again, the tutorial is a great place to start.
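To show what "correctness against a reference backend" boils down to, here is a hand-rolled sketch. It deliberately does not use DifferentiationInterfaceTest.jl (whose API is richer and handles many more operators and scenarios); it only reuses `gradient` from the Basics section, and the test function, tolerance and helper name `check_against_reference` are illustrative choices.

```julia
using DifferentiationInterface
import ForwardDiff, Zygote

# Hypothetical helper: compare each candidate backend's gradient to a reference
function check_against_reference(f, x, reference, candidates; rtol=1e-8)
    g_ref = gradient(f, reference, x)
    return [isapprox(gradient(f, b, x), g_ref; rtol=rtol) for b in candidates]
end

f(x) = sum(abs2, x) + prod(x)
x = rand(5)

results = check_against_reference(f, x, AutoForwardDiff(), [AutoZygote()])
```

The package goes further by also checking type stability and running benchmarks across many scenarios, which is why we recommend it over ad hoc scripts like this one.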

What about AbstractDifferentiation?

AbstractDifferentiation.jl was our main inspiration for this work, and we learned a lot of lessons from its design. To alleviate some of its current limitations (notably around mutation and caching), we imposed certain restrictions on our code. DifferentiationInterface accepts only one input x and one output y, and guarantees support for number or array types but nothing beyond that. Given the existence of ComponentArrays.jl, these rules seem fairly mild, and they made robust implementation and testing considerably easier.
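For instance, a function of two logical inputs can be squeezed into the single-input rule by packing them into one ComponentVector. This is a minimal sketch assuming ComponentArrays.jl is installed and that the gradient comes back with the same named fields (which is what we expect with ForwardDiff, though not every backend has been checked here); `loss`, `weights` and `bias` are made-up names.

```julia
using DifferentiationInterface
using ComponentArrays: ComponentVector
import ForwardDiff

# Two logical inputs, stored in a single array p
loss(p) = sum(abs2, p.weights) + 3 * p.bias

p = ComponentVector(weights=[1.0, 2.0], bias=0.5)
g = gradient(loss, AutoForwardDiff(), p)

grad_weights = g.weights  # expected [2.0, 4.0]
grad_bias = g.bias        # expected 3.0
```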

Still, some use cases may require differentiation with respect to multiple inputs or non-array types. This level of generality has always been the goal of AbstractDifferentiation, and so it will remain.
After discussing the matter with @mohamed82008 to avoid XKCD #927, the way forward looks like this:

  • DifferentiationInterface will keep its narrow focus, corresponding to the intersection of all autodiff packages in terms of functionalities
  • AbstractDifferentiation will wrap DifferentiationInterface for the simple cases, and aim for the union of all autodiff packages when it comes to handling funky inputs

Roadmap

Now that the foundations are solid, here are some things we would like to work on next:

  • help downstream users adopt our package, with Optimization.jl as a first target (thanks to @Vaibhavdixit02) and larger ecosystems like SciML and Turing in our sights (ping @ChrisRackauckas)
  • support chunked / batched pushforwards and pullbacks, to reduce the number of function calls necessary for Jacobian and Hessian computations
  • increase performance of second order operators, which are not yet fully optimized
  • improve sparsity handling, especially with respect to structured matrices
  • transfer the package to JuliaDiff once its design has finally stabilized

Interested?

DifferentiationInterface is not yet completely stable, but it follows semantic versioning, and the latest v0.3 release is very well tested.
We encourage you to try it out in your own projects, and open an issue if you find yourself struggling.
Oh, and of course, a star on the GitHub repo would be very much appreciated :wink:
