How to compute partial derivatives efficiently with Zygote

Suppose we have a function f(x, y), and want to compute \partial_x f(x,y) for many y’s, ideally with no memory allocation. Using Zygote, I tried Zygote.forwarddiff(x->f(x,y), x).

However, I believe this creates a new anonymous function each time we call it with a different y. So the call to forwarddiff is not efficient.

Also, we can’t use gradient or forwarddiff on f directly, because in reality there are many more arguments (x,y,z,…) and computing and storing all of their derivatives is again inefficient.

Is there a way to achieve what I intend to do with Zygote or any other AD package?


Creating an anonymous function (closure) is more or less free in Julia once it has been compiled, so I don’t think you need to worry about it.
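For what it's worth, here is a minimal sketch of the closure approach. The function f below and the helper name partial_x are made up for illustration and are not from the original question; the only Zygote call used is Zygote.gradient. The closure captures y, so only the derivative with respect to x is computed.

```julia
using Zygote

# Hypothetical two-argument function, for illustration only.
f(x, y) = x^2 * y + sin(y)

# Partial derivative with respect to x at fixed y.
# The closure x_ -> f(x_, y) captures y, so nothing is computed for y.
partial_x(x, y) = Zygote.gradient(x_ -> f(x_, y), x)[1]

# Evaluate the partial derivative at x = 2.0 for several values of y.
ys = range(0.0, 1.0; length = 5)
dfdx = [partial_x(2.0, y) for y in ys]   # analytically 2 * x * y = 4y
```

Whether this is allocation-free depends on f and on the Zygote version, so it is worth checking with @allocated or BenchmarkTools on the real function.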

But there is a reason why taking a partial derivative can be slow in Zygote at the moment: see the issue “Slow backward pass when the forward pass touches a large array” (FluxML/Zygote.jl, Issue #323). Hopefully, this will be fixed by the pull request “[WIP] Use ChainRules” by oxinabox (FluxML/Zygote.jl, Pull Request #291).


Thank you for the pointers, and let’s hope so!