How to compute partial derivatives efficiently with Zygote

Suppose we have a function f(x, y) and want to compute ∂f(x, y)/∂x for many values of y, ideally without any memory allocation. Using Zygote, I tried `Zygote.forwarddiff(y->f(x,y), y)`.

However, I believe this creates a new anonymous function every time it is called with a different x, so the call to `forwarddiff` is not efficient.

Also, we can’t use `gradient` or `forwarddiff` on f directly: in reality f has many more arguments (x, y, z, …), and computing and storing the derivatives with respect to all of them is again inefficient.
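A minimal sketch of the distinction, assuming a hypothetical three-argument function `g` in place of the real f: calling `Zygote.gradient` on `g` itself returns a partial for every argument, while wrapping `g` in a closure over the other arguments asks only for the derivative with respect to x. Whether this also saves work inside Zygote's reverse pass is a separate question and is not shown here.

```julia
using Zygote

# Hypothetical stand-in for a function with many arguments.
g(x, y, z) = x^2 * y + sin(z) * x

x, y, z = 1.5, 2.0, 0.3

# Differentiating g directly: one partial per argument (a 3-tuple).
all_grads = Zygote.gradient(g, x, y, z)

# Closing over y and z: only ∂g/∂x is requested (a 1-tuple).
(dgdx,) = Zygote.gradient(x_ -> g(x_, y, z), x)
```

Analytically, ∂g/∂x = 2xy + sin(z), which is also the first entry of `all_grads`.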

Is there a way to achieve what I intend to do with Zygote or any other AD package?

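Creating an anonymous function (closure) is essentially free in Julia once it has been compiled: the closure’s type depends only on the types of the captured variables, not on their values, so calling it with a different x does not trigger recompilation. I don’t think you need to worry about it.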
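A small sketch of why the closure is cheap, using a hypothetical helper `make_partial` and a toy function `h` (neither is from the thread): closures built from values of the same type share one concrete type, so code compiled for the first is reused for the rest. Only capturing a value of a different type produces a new closure type.

```julia
# Hypothetical helper: returns a closure that captures f and x.
make_partial(f, x) = y -> f(x, y)

h(x, y) = x * y^2

c1 = make_partial(h, 1.0)
c2 = make_partial(h, 2.0)

# The closure type is parameterized by typeof(h) and Float64,
# not by the captured values, so c1 and c2 have the same type.
same_type = typeof(c1) === typeof(c2)

# Capturing an Int instead of a Float64 yields a different closure type,
# which is the case that would require fresh compilation.
c3 = make_partial(h, 1)
different_type = typeof(c3) !== typeof(c1)
```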

But there is a reason why taking a partial derivative can be slow in Zygote at the moment: https://github.com/FluxML/Zygote.jl/issues/323. Hopefully, this will be fixed by https://github.com/FluxML/Zygote.jl/pull/291.
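For a single scalar argument, one possible alternative (an assumption on my part, not something proposed in the thread) is plain forward mode with ForwardDiff.jl, the package `Zygote.forwarddiff` builds on. The sketch below uses a hypothetical stand-in f and fills a preallocated output with ∂f/∂x at a fixed x for each y; a scalar input means a single Dual number is pushed through f, so only the x-derivative is computed.

```julia
using ForwardDiff

# Hypothetical stand-in for the real f(x, y).
f(x, y) = sin(x) * y^2 + x * y

# Fill `out` with ∂f/∂x evaluated at x for each y in `ys`.
# The `f::F ... where {F}` annotation forces Julia to specialize on f,
# so the call inside the closure is not dynamically dispatched.
function partial_x!(out, f::F, x, ys) where {F}
    for i in eachindex(out, ys)
        y = ys[i]
        # The closure captures y; only x is perturbed by the Dual number.
        out[i] = ForwardDiff.derivative(x_ -> f(x_, y), x)
    end
    return out
end

x0 = 0.7
ys = collect(range(0.0, 1.0; length = 5))
out = zeros(length(ys))
partial_x!(out, f, x0, ys)
```

The result can be checked against the analytic partial ∂f/∂x = cos(x)·y² + y.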


Thank you for the pointers, and let’s hope so!