How to profile Zygote gradients

Hi,

I’m currently running an expensive optimisation with code that uses Zygote for automatic differentiation. The gradient computation seems to dominate the computational cost. Is it possible to profile the gradient computation to see which parts of the code make it expensive?

Any profiler you’re comfortable with should work. There are two main internal functions to look for:

  1. _pullback(::F, args...): this is the augmented primal function/forward pass of a particular callable of type F. You can expect to see it show up roughly in the same place you’d usually see the same method of F when not differentiating.
  2. Pullback{F, ...} or ∂(F): this is the wrapper that handles the backwards pass. It will forward to a ChainRules rrule or Zygote @adjoint if one exists, and otherwise contains an automatically generated pullback.
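
To get a rough feel for how the cost splits between those two phases, here is a minimal sketch using the public `Zygote.pullback` function, which returns the primal value together with a closure for the backwards pass. The `loss` function and the input `x` are hypothetical stand-ins for your own code, and `@elapsed` on a single call is only a coarse measurement.

```julia
using Zygote

# Hypothetical stand-in for the expensive objective (not from the original post).
loss(x) = sum(abs2, sin.(x) .* x)
x = rand(1000)

# Warm up both phases once so that compilation time is not measured.
y, back = Zygote.pullback(loss, x)
back(one(y))

# Forward pass: runs the augmented primal and builds the pullback closure.
t_forward = @elapsed Zygote.pullback(loss, x)

# Backwards pass: call the pullback closure with a seed of one.
y, back = Zygote.pullback(loss, x)
t_backward = @elapsed back(one(y))

println("forward: ", t_forward, " s, backward: ", t_backward, " s")
```

If one phase clearly dominates, that tells you whether to look for `_pullback` frames or `Pullback`/`∂(F)` frames first in the profiler output.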

See Changing the Primal · ChainRules for more on the terminology above. Note that because Zygote uses generated functions, some line numbers in the profile may look odd. Wherever possible though, Zygote tries to preserve the original source locations of function calls so that you can see which method is being sent through AD.
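
As a concrete starting point, here is a minimal sketch using the `Profile` standard library. The `loss` function, the input size, and the iteration count are all hypothetical placeholders; substitute your own objective. Any other profiler front end should work the same way.

```julia
using Zygote
using Profile

# Hypothetical stand-in for the expensive objective (not from the original post).
loss(x) = sum(abs2, sin.(x) .* x)
x = rand(1000)

# Run once before profiling so that compilation does not dominate the samples.
g = Zygote.gradient(loss, x)[1]

Profile.clear()
@profile for _ in 1:200
    Zygote.gradient(loss, x)
end

# Flat listing sorted by sample count; look for `_pullback` and `Pullback` frames.
Profile.print(format=:flat, sortedby=:count, mincount=5)
```

The default tree view (`Profile.print()` with no arguments) is often easier to read when you want to see which of your own functions a given `_pullback` call sits under.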

Thanks for the answer!
