Problems with AD inference on wrapper function

The answer is relatively straightforward, but it does require some background on how keyword function dispatch works under the hood (see the Functions chapter of the Julia manual, "Julia Functions · The Julia Language"). As seen from the example there, the auto-generated "keyword sorter" function (a method of Core.kwcall on newer versions of Julia) uses conditionals to handle the presence or absence of each kwarg. Because Zygote will unconditionally generate type-unstable code when it encounters branching control flow, all calls with keyword arguments will be type unstable under AD.
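To see this for yourself, here is a minimal sketch. The functions `inner` and `wrapper` are hypothetical stand-ins for the code in question, and the exact lowered output and inferred types will vary with your Julia and Zygote versions, so treat it as something to poke at rather than a guaranteed reproduction:

```julia
using InteractiveUtils  # provides @code_lowered outside the REPL
using Zygote

# Hypothetical example functions, not taken from the original question.
inner(x; scale=1.0) = sum(scale .* x)
wrapper(x) = inner(x; scale=2.0)

x = [1.0, 2.0, 3.0]

# Lowered code for the keyword call. On newer Julia versions this is the
# Core.kwcall method; look for the conditional jumps that check whether
# each keyword was supplied.
ci = @code_lowered inner(x; scale=2.0)
println(ci)

# The gradient value itself is fine.
g = Zygote.gradient(wrapper, x)
println(g)

# What inference concludes about the gradient call. A non-concrete result
# here would be the instability described above.
println(Base.return_types(Zygote.gradient, (typeof(wrapper), typeof(x))))
```

If the last line prints something non-concrete, that is the keyword sorter's branching showing up in the pullback.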

As you've noted, this can be worked around by defining an rrule for the function in question. This is how e.g. sum(...; dims=...) can be type stable with Zygote: there are rules for it in ChainRules.jl. Unfortunately, having an rrule precludes differentiating with respect to keyword arguments, because the pullback only returns tangents for the positional arguments. @oxinabox has a far more in-depth series of posts about this and possible workarounds in the thread "Rrule (or frule) with kwargs".
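For reference, a rough sketch of what such a rule can look like. `scaled_sum` is a made-up function for illustration (it is not one of the rules in ChainRules.jl), and I'm assuming real-valued array input:

```julia
using ChainRulesCore
using Zygote

# Hypothetical function with a keyword argument.
scaled_sum(x; scale=1.0) = scale * sum(x)

# The rrule accepts the same kwargs as the primal function, but the pullback
# returns tangents only for the function itself and the positional argument x.
# There is no slot for `scale`, so it cannot be differentiated through the rule.
function ChainRulesCore.rrule(::typeof(scaled_sum), x::AbstractArray{<:Real}; scale=1.0)
    y = scaled_sum(x; scale=scale)
    function scaled_sum_pullback(ȳ)
        x̄ = fill(scale * unthunk(ȳ), size(x))
        return NoTangent(), x̄
    end
    return y, scaled_sum_pullback
end

x = [1.0, 2.0, 3.0]
g = Zygote.gradient(v -> scaled_sum(v; scale=2.0), x)
println(g)
```

With the rule in place Zygote calls the rrule instead of tracing through the keyword sorter, so the branches never enter the generated pullback.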
