ChainRulesCore: Custom adjoint ignored in certain cases

Hi,
A colleague of mine ran into the issue that custom adjoints get ignored in certain cases in recent versions of ChainRulesCore. Here is a minimal working example:

using Flux, ChainRulesCore

my_fun(x) = x*x

function ChainRulesCore.rrule(::typeof(my_fun), x)
    z = my_fun(x)
    function pullback(Δy)
        @show "Calling custom pullback"
        return NoTangent(), zero(x)
    end
    return z, pullback
end

@show gradient(my_fun, 1.0)  # calls custom rrule
@show gradient(x -> sum(my_fun.(x)), [1.0])  # does not
@show gradient(x -> sum(map(my_fun, x)), [1.0])  # calls custom rrule

When running this example with ChainRulesCore version 0.10.13 or older, the custom gradient is used in all three cases. With ChainRulesCore version 1.0.1 or newer, however, the custom gradient is ignored in the second case. I ran into problems when trying to install ChainRulesCore version 1.0.0, so I could not test that version.

Is this a deliberate design choice or a bug?

It’s a performance trick.

When broadcasting, Zygote checks Base.issingletontype(typeof(my_fun)) and, if that holds, concludes that the function is pure enough that it can use ForwardDiff within the broadcast, rather than Zygote’s own differentiation (which calls ChainRules). This is often 100x faster, because Zygote’s differentiation tends to break type stability.
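To see what that check distinguishes, here is a minimal sketch using only Base. The names `make_scale`, `scale` and `a` are hypothetical and introduced purely for illustration; that Zygote consults exactly this predicate is taken from the explanation above, not verified against a particular Zygote version.

```julia
# A plain top-level function has no fields, so its type is a singleton.
my_fun(x) = x * x
is_plain = Base.issingletontype(typeof(my_fun))

# A closure that captures a variable stores it as a field,
# so its type is not a singleton.
# `make_scale` and `a` are hypothetical names, for illustration only.
make_scale(a) = x -> a * x
scale = make_scale(2.0)
is_closure_singleton = Base.issingletontype(typeof(scale))

@show is_plain              # true
@show is_closure_singleton  # false
```

So a broadcast over `my_fun` would qualify for the ForwardDiff path described above, while a broadcast over a capturing closure such as `scale` would not.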

You can write a method my_fun(x::ForwardDiff.Dual) to customise this. Or you can define a rule for the broadcast itself; Zygote has many rules that look like this:

@adjoint function broadcasted(::typeof(tanh), x::Numeric)
  y = tanh.(x)
  y, ȳ -> (nothing, ȳ .* conj.(1 .- y.^2))
end
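For the first option, a Dual method for my_fun might look like the sketch below. It assumes the goal is to reproduce the zero derivative returned by the custom pullback in the original example; the method body is an illustration, not code from Zygote or ForwardDiff. It is checked here only with ForwardDiff.derivative; that Zygote's broadcast picks up the same method is as described above.

```julia
using ForwardDiff

my_fun(x) = x * x

# Illustrative Dual method: keep the primal value but force the derivative
# to zero, mirroring the `zero(x)` returned by the custom pullback.
function my_fun(d::ForwardDiff.Dual{T}) where {T}
    x = ForwardDiff.value(d)
    return ForwardDiff.Dual{T}(my_fun(x), zero(x) * ForwardDiff.partials(d))
end

dydx = ForwardDiff.derivative(my_fun, 3.0)
@show dydx  # 0.0, not the 6.0 that differentiating x*x would give
```

Because the method dispatches on ForwardDiff.Dual, ordinary calls such as my_fun(3.0) still go through the original definition.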

Thanks!
Pretty neat that the call to gradient checks if forward mode will be beneficial! :slight_smile: I guess we had implicitly made the assumption that the gradient function would always use reverse mode, which is not really justified.

BTW: I was a bit confused about why ForwardDiff is used instead of ChainRulesCore.frule. I found an answer to this here.

Right, so ForwardDiff doesn’t use ChainRules (it’s old & stable, and maybe wouldn’t fit well). And the broadcasting needs to work even if the function is something you’ve written, or some composition of functions, not just elementary functions. I suppose it could check for the existence of a direct frule first, but someone would have to write that.
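As a rough sketch of what such a check could build on: ChainRulesCore's fallback frule returns nothing when no rule is defined, so the presence of a rule can be tested directly. The helper `has_frule` and the function `other_fun` are hypothetical names for illustration, and nothing here reflects what Zygote's broadcasting does today.

```julia
using ChainRulesCore

my_fun(x) = x * x
other_fun(x) = x + 1   # hypothetical function with no frule defined

# Forward rule for my_fun; the zero tangent mirrors the custom pullback
# from the original example.
function ChainRulesCore.frule((_, Δx), ::typeof(my_fun), x)
    return my_fun(x), zero(Δx)
end

# Hypothetical helper: the generic fallback `frule` returns `nothing`
# when no specific rule exists for `f`.
has_frule(f, x) = frule((NoTangent(), one(x)), f, x) !== nothing

@show has_frule(my_fun, 2.0)     # true
@show has_frule(other_fun, 2.0)  # false
```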

It is a bit of a pain that you potentially have to write 3 different rules for the same function, sorry about the surprise! I think we may still be some way from the optimal story here, although not many more factors of 100 surely.