Second order derivatives with ChainRules

How can I get the second-order derivative of a function using ChainRules? Below is my approach for the sin function.

See the following code:

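```julia
using ChainRules
using ChainRulesCore

# from the docs, first derivative of sin
x = 1.0
sinx, sin_pullback = rrule(sin, x)
# the pullback's argument is the seed (cotangent), not the point x;
# seeding with 1.0 returns the derivative cos(x)
sin_pullback(1.0) == (NoTangent(), cos(x)) # true
```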

Let’s try to make it more general:

```julia
using ChainRules
using ChainRulesCore

sin_pullback(u) = rrule(sin, u)[2]
sin_pullback(pi/4)(1) == (NoTangent(), cos(pi/4)) # true
dsin(u) = sin_pullback(u)(1)[2] # get the cos(u) value
dsin(pi/4) == cos(pi/4) # true

# does not work:
rrule(dsin, pi/4) # returns nothing; we want cos(pi/4) plus a pullback giving -sin(pi/4)
```

This happens because, as far as ChainRules is concerned, dsin is not cos: rrule dispatches on the function itself, and no rrule has been defined for dsin.
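To check that the missing rule really is the only obstacle, here is a minimal sketch that writes one by hand. It assumes dsin is defined as in the question (seeded with `one(u)` rather than the integer 1) and that we already know its derivative is `-sin(u)`; the inner name `dsin_pullback` is just an illustrative choice. Once this method exists, `rrule(dsin, pi/4)` dispatches to it instead of falling back to `nothing`.

```julia
using ChainRules
using ChainRulesCore

# dsin as in the question: evaluate sin's pullback at u with seed one(u)
dsin(u) = rrule(sin, u)[2](one(u))[2]

# Hand-written rule (assumption: we supply the math ourselves):
# d/du dsin(u) = d/du cos(u) = -sin(u)
function ChainRulesCore.rrule(::typeof(dsin), u::Real)
    y = dsin(u)
    dsin_pullback(Δ) = (NoTangent(), -sin(u) * Δ)
    return y, dsin_pullback
end

y, back = rrule(dsin, pi/4)
y              # ≈ cos(pi/4)
back(1.0)[2]   # ≈ -sin(pi/4)
```

This works, but it is exactly the by-hand step that does not scale: every new function needs its own rule.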

So how can we achieve this?

In short, you need an AD (automatic differentiation) engine.

The difference between rrule(f, x) and (say) Zygote.pullback(f, x) is that the first only exists when someone has written a rule by hand for this exact f.
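A small sketch of that difference, assuming Zygote is installed; `g` is a hypothetical function for which nobody has written a rule. ChainRules' `rrule` returns `nothing` for it, while `Zygote.pullback` assembles a pullback from the existing rules for `sin` and `*`.

```julia
using ChainRulesCore
using Zygote

# hypothetical user function with no hand-written rule
g(u) = sin(u) * u

no_rule = rrule(g, 1.0)          # nothing: the generic fallback

y, back = Zygote.pullback(g, 1.0)
# Zygote's pullback returns one gradient per argument (no entry for g itself),
# here cos(1.0) * 1.0 + sin(1.0) by the product rule
dg = back(1.0)[1]
```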

Zygote, by contrast, will look inside an arbitrary function and see which other functions it calls. Here, dsin(u) calls the sin_pullback function, which (at the very least) calls cos and Tuple, and then dsin calls getindex on the result. These functions individually have rules, but something has to splice them together, and that is the AD engine's job.
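For the original goal, one option is to let Zygote produce the second derivative directly. The sketch below uses `Zygote.hessian`, which, according to Zygote's documentation, works forward-over-reverse (ForwardDiff over Zygote) rather than nesting two reverse passes; it assumes Zygote and its ForwardDiff dependency are installed.

```julia
using Zygote

# second derivative of sin at pi/4, computed forward-over-reverse;
# mathematically this is -sin(pi/4)
d2 = Zygote.hessian(sin, pi/4)
```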
