In my opinion, one of the key reasons why deep learning in Julia is light years behind PyTorch/JAX is the performance and convenience of Automatic Differentiation. There are many AD packages, and each makes its own tradeoff between speed and generality. I believe we should make it easy for users to pick the one that works best for them, which is the motivation behind creating DifferentiationInterface.jl with @hill. That way, AD packages can coexist and even compete without causing confusion for downstream users. In addition, it reduces code duplication: every ML ecosystem (Flux, Lux, Turing, SciML) currently has its own variant of an Enzyme/Zygote extension with gradient bindings, and we should just pool all of those.
I’ve been chatting with various power users of AD to see how they could leverage the interface. The takeaway is that integration is much easier when the thing you differentiate with respect to is a vector (or ComponentVector), and not some arbitrarily complex (callable) struct like a Flux layer. For this reason, I think Lux.jl, which keeps the parameters explicit and separate from the model, is better suited to easy AD integration and backend switching.
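To make the "vector in, backend as an argument" point concrete, here is a minimal sketch, assuming DifferentiationInterface.jl, ForwardDiff.jl and Zygote.jl are installed. The `loss` function and the flat parameter layout are hypothetical, purely for illustration; they are not taken from Flux or Lux.

```julia
using DifferentiationInterface  # also exports the ADTypes backend objects (AutoForwardDiff, AutoZygote, ...)
import ForwardDiff, Zygote      # backend packages must be loaded for their bindings to activate

# Hypothetical toy model: a 2x3 linear layer whose weights and bias
# are packed into one flat parameter vector of length 8.
function loss(p::AbstractVector)
    W = reshape(p[1:6], 2, 3)   # first 6 entries are the weights
    b = p[7:8]                  # last 2 entries are the bias
    x = [1.0, 2.0, 3.0]         # fixed input, for illustration only
    y = W * x .+ b
    return sum(abs2, y)
end

p = collect(range(0.1, 0.8; length=8))

# Same function, same input; only the backend object changes.
g_forward = gradient(loss, AutoForwardDiff(), p)
g_reverse = gradient(loss, AutoZygote(), p)
```

Because the parameters are a plain vector, swapping the backend is a one-argument change, and both calls should agree up to floating-point error. With a nested callable struct, each backend would first need to know how to traverse it.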
Of course, a common interface doesn’t get us all the way there, but to me it seems like a very important step.