Custom gradients for a function. Resources for rrule and frule

I am trying to create custom gradients for a Green's function whose term-by-term direct differentiation is not very accurate; we have a derivation of its derivatives and would like to use those. For the simple MWE below, how can we write frules and rrules? Also, can anybody direct me to a good resource on ChainRules? (I got confused by the documentation; I need an even simpler tutorial than that.)

Thank you,

"""m, n are length-3 vectors"""
function greens(m, n)
    return m' * n
end

"""Partial of greens with respect to the first argument m (say this is the accurate derivative)."""
function ∂greens(m, n)
    return 2 .* m .* n
end

f = greens([1,2,3], [3,2,0]) #7
∂f_∂m =  ∂greens([1,2,3], [3,2,0]) #[6,8,0]
# print("AD gradients \n ")
print(f)
print(" \n and \n")
print(∂f_∂m)

Hi @KapilKhanal!

A few questions about your MWE:

  • Does your true function take the same number of arguments? Why do you only consider the derivative with respect to the first one?
  • What do you mean by “derivative” here? Since it’s a vector-to-scalar function, the object you know how to compute is the gradient of the output with respect to m?
  • Why do you need a custom chain rule? Is your code not differentiable?

The ChainRulesCore.jl documentation is very good, but it can be a little hard to grasp for a beginner. I tried to explain some of the concepts in my JuliaCon talk last week; you can check it out here:

Assume you have a function f with two arguments x1 and x2, and its partial gradient functions g1 and g2.
ChainRulesCore.jl expects you to write a “pullback”, which basically tells Julia how a change on the output (dy) would be pulled back onto a change on the inputs (dx1 and dx2). In the case of a scalar-valued f, it’s very simple: the pullback takes the partial gradient and multiplies it (elementwise) by dy.
So you’re gonna want to write a chain rule that looks like this:

using ChainRulesCore

f(x1, x2) = ...  # some function
g1(x1, x2) = ...  # gradient of f at (x1, x2) wrt x1
g2(x1, x2) = ...  # gradient of f at (x1, x2) wrt x2

function ChainRulesCore.rrule(::typeof(f), x1, x2)
    y = f(x1, x2)
    function f_pullback(dy)
        df = NoTangent()
        dx1 = g1(x1, x2) .* dy
        dx2 = g2(x1, x2) .* dy
        return (df, dx1, dx2)
    end
    return y, f_pullback
end
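To make the mapping to your MWE concrete, here is the same pattern written out by hand for greens (no packages needed, just checking the math; note I'm using the true gradients of m' * n here rather than your placeholder ∂greens):

```julia
# Hand-rolled version of what the rrule computes for the greens MWE.
f(m, n) = m' * n        # vectors-to-scalar, like greens
g1(m, n) = n            # true gradient of f wrt m
g2(m, n) = m            # true gradient of f wrt n

# The rrule returns the primal value plus a pullback closure.
function rrule_by_hand(m, n)
    y = f(m, n)
    pullback(dy) = (g1(m, n) .* dy, g2(m, n) .* dy)
    return y, pullback
end

y, pb = rrule_by_hand([1.0, 2.0, 3.0], [3.0, 2.0, 0.0])
dm, dn = pb(1.0)  # seeding dy = 1 recovers the gradients: dm = [3, 2, 0], dn = [1, 2, 3]
```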
3 Likes

Hi, yes, sorry, gradients. I do have an analytical expression for the gradient of the function output with respect to the first argument of the function. The need for a custom chain rule is that this Green's function is approximated by a series, and term-by-term or automatic differentiation of that series has been shown to be not so accurate and also slower. Since I have access to an approximation of the gradients, I prefer to use that rather than differentiating the function approximation to get the gradients.

Added: I only mentioned the first argument because the gradients are antisymmetric. But I assume ChainRulesCore needs both, hence you defined both. I will define both, since they share computation.

Thank you for the link to your talk and the code snippets.

1 Like

While I have you here, I would appreciate your expertise on something similar.
I have a similar problem where I have an external code that I finite-difference to get the Jacobians; MWE below. How do I ask Zygote to finite-difference this external code when required? Is that also part of ChainRulesCore? Thank you again!


using PyCall
using Zygote
py"""
def demopython(r):
    return [r,r**3]
"""

demojulia(r) = [r,r^3]

#external function to call 
external = py"demopython"
r = 1.0
h = 1e-2

∂J_r_fd = (external(r+h) - external(r-h)) ./ (2*h)
print(∂J_r_fd)
∂J_r_ad = Zygote.jacobian(demojulia,1.0)
print( " compared to algo ")
print(∂J_r_ad)

Thinking about it further: the pullback here would just be a call to a finite difference function? I will try this.

Why not use another backend to compute the finite differences? Why do you want to trick Zygote into doing it?

So that when it builds the AD graph it can use it. I am not sure how Zygote builds its AD graph, but whenever it does, I would prefer it finite-difference that part of the code.

Then yeah, you can compute the finite differences inside the pullback. Alternatively, you can try out DifferentiateWith from DifferentiationInterface.jl, which is designed for this exact purpose of tricking Zygote into using another backend (typically FiniteDiff.jl here).
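A minimal sketch of the DifferentiateWith route, assuming a pure-Julia stand-in for the Python call (untested against your actual PyCall setup; needs DifferentiationInterface.jl, FiniteDiff.jl, and Zygote.jl):

```julia
using DifferentiationInterface  # also re-exports the AutoFiniteDiff backend type
using FiniteDiff, Zygote

# Stand-in for the external black box (in your case, the PyCall function).
blackbox(r) = [r, r^3]

# Wrap it so that a ChainRules-compatible backend (like Zygote) that hits it
# falls back to finite differences instead of tracing through its code.
fd_blackbox = DifferentiateWith(blackbox, AutoFiniteDiff())

J, = Zygote.jacobian(fd_blackbox, 1.0)  # ≈ [1.0, 3.0]
```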

1 Like

It seems this does not work because of the code in PyCall?

using PyCall
using Zygote, ChainRulesCore
py"""
def demopython(r):
    return [r,r**3]
"""

demojulia(r) = [r,r^3]

#external function to call 
external = py"demopython"
r = 1.0
h = 1e-2

∂J_r_fd(r) = (external(r+h) - external(r-h)) ./ (2*h)


function ChainRulesCore.rrule(::typeof(external), r)
    y = external(r)
    function f_pullback(dy)
        df = NoTangent()
        dx = ∂J_r_fd(r) .* dy
        return (df,dx)
    end
    return y, f_pullback
  end



∂J_r_py = Zygote.jacobian(external,1.0)
print( " compared to in julia ")
∂J_r_jl = Zygote.jacobian(demojulia,1.0)
print(∂J_r_jl)

Below is the error trace

ERROR: Compiling Tuple{typeof(PyCall._pycall!), PyObject, PyObject, Tuple{Float64}, Int64, Ptr{Nothing}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(::Zygote.Context{…}, ::typeof(PyCall._pycall!), ::PyObject, ::PyObject, ::Tuple{…}, ::Int64, ::Ptr{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:87
  [3] _pycall!
    @ ~/.julia/packages/PyCall/1gn3u/src/pyfncall.jl:11 [inlined]
  [4] _pullback(::Zygote.Context{false}, ::typeof(PyCall._pycall!), ::PyObject, ::PyObject, ::Tuple{Float64}, ::@Kwargs{})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [5] #_#114
    @ ~/.julia/packages/PyCall/1gn3u/src/pyfncall.jl:86 [inlined]
  [6] _pullback(::Zygote.Context{false}, ::PyCall.var"##_#114", ::@Kwargs{}, ::PyObject, ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [7] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
  [8] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
  [9] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [10] PyObject
    @ ~/.julia/packages/

I was assuming that since we provide the function for the Jacobians, it would treat the external function and the Julia wrapper as a black box.

Similar with forward mode as well.

That won’t work because external is just a PyObject — it doesn’t have a unique type to dispatch off of. The easiest thing to do would be to wrap your Python object in a Julia function.

1 Like

Gotcha. But it still gives an incorrect answer. I am not understanding the rrule functionality.



using PyCall
using Zygote, ChainRulesCore
py"""
def demopython(r):
    return [r,r**3]
"""

demojulia(r) = [r,r^3]

#external function to call 
function wrapper(r)
    return py"demopython"(r)
end 
r = 1.0
h = 1e-2

∂J_r_fd(r) = (wrapper(r+h) .- wrapper(r-h)) ./ (2*h)



function ChainRulesCore.rrule(::typeof(wrapper), r)
    y = wrapper(r)
    function f_pullback(dy)
        df = NoTangent()
        dx1,dx2 = ∂J_r_fd(r) .* dy
        return (df,dx1,dx2)
    end
    return y, f_pullback
  end



∂J_r_jl = Zygote.jacobian(demojulia,1.0)
 print(∂J_r_jl)
# print( " compared to in julia ")
∂J_r_py = Zygote.jacobian(wrapper,1.0)
print(∂J_r_py)
([1.0, 3.0],) != ([1.0000000000000009, 0.0],)

Here you’re working with something very different from what we discussed above: a scalar-to-vector function with one scalar argument, instead of a vectors-to-scalar function with two vector arguments. Thus, the chain rule snippet I gave you must be adapted.
In the case of a scalar-to-vector function, the pullback is the dot product between the derivative vector and the output perturbation dy. In both cases, the general principle is to compute the product between the transposed Jacobian matrix and an output perturbation. It just so happens that when either the input or the output is scalar, this coincides with other notions we are more familiar with, such as the gradient or the derivative.
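Concretely, for your scalar-to-vector MWE f(r) = [r, r^3], a hand sketch with no packages:

```julia
# Derivative vector of f(r) = [r, r^3]: the single column of its Jacobian.
d(r) = [1.0, 3r^2]

# Pullback: (Jacobian)' * dy, which here collapses to a dot product.
pullback(r, dy) = d(r)' * dy

# Seeding dy with basis vectors recovers the Jacobian one row at a time:
row1 = pullback(1.0, [1.0, 0.0])  # 1.0
row2 = pullback(1.0, [0.0, 1.0])  # 3.0
```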

2 Likes

This works for anyone in my boat!

function ChainRulesCore.rrule(::typeof(wrapper), r)
    y = wrapper(r)
    function f_pullback(dy)
        df = NoTangent()
        dx = ∂J_r_fd(r)' * dy
        return (df,dx)
    end
    return y, f_pullback
end
1 Like

Cool! If your problem is solved don’t hesitate to pick one of the messages as the solution, to let others know you no longer need help

For the sake of completeness, is it still going to be a dot product (for each additional dimension) for R^n → R^n function pullbacks as well?

Yes and no.

Yes, a pullback computes a (\text{vector})^T (\text{Jacobian matrix}) product, which by the familiar “rows × columns” rule is formally equivalent to a dot product of the vector with each row of the Jacobian (one for each output).

No, because it is often not computed this way. In large computations with many variables, the goal is often to avoid explicitly computing the Jacobian matrix. This is often possible because your Jacobian is internally the product of several simpler Jacobians (via the chain rule), and if you have a computation like (\text{vector})^T (\text{Jacobian 1}) (\text{Jacobian 2}) \cdots it is probably vastly more efficient to compute this product from left-to-right (outputs-to-inputs). This is the principle behind “reverse-mode”/“backpropagation”/“adjoint method” computation of gradients, which are advantageous when you have many inputs and one scalar output (or a few outputs), and this is the mode of differentiation that “pullback” functions are designed to implement.
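As a toy illustration of the left-to-right order (made-up Jacobians, plain linear algebra):

```julia
# Jacobians of two chained maps y = f2(f1(x)) at some point on R^2.
J1 = [1.0 2.0; 3.0 4.0]  # inner map
J2 = [0.0 1.0; 1.0 0.0]  # outer map
v  = [1.0, 0.0]          # output perturbation (selects the first output)

# Reverse mode pulls v back through the outer Jacobian first:
vjp = (v' * J2) * J1       # two cheap vector-matrix products

# Same answer as materializing the full Jacobian J2 * J1 first,
# but the full Jacobian never needs to be formed.
vjp_full = v' * (J2 * J1)
```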

3 Likes

Okay, that makes sense. In the context of using external code and finite differencing (or an analytical expression) to get the Jacobian, I have a simple MWE below that gives me either a dimension error or an index error no matter how I group the terms, left to right or otherwise. Would you be able to run it and find what's wrong here?

using PyCall
using Zygote, ChainRulesCore

py"""
def demopython(r1,r2,r3): # X = [r1,r2,r3]
    return [-r1*r2, r2*r3]
"""

function wrapper(r1,r2,r3)
    return py"demopython"(r1,r2,r3)
end 


h = 1e-3
∂J_r1_fd(f,r1,r2,r3) = (f(r1+h,r2,r3) .- f(r1-h,r2,r3)) ./ (2*h)
∂J_r2_fd(f,r1,r2,r3) = (f(r1,r2+h,r3) .- f(r1,r2-h,r3)) ./ (2*h)
∂J_r3_fd(f,r1,r2,r3) = (f(r1,r2,r3+h) .- f(r1,r2,r3-h)) ./ (2*h)

function ChainRulesCore.rrule(::typeof(wrapper), r1,r2,r3)
    y = wrapper(r1,r2,r3)
    function f_pullback(dy)
        df = NoTangent()
        dx = ∂J_r1_fd(wrapper,r1,r2,r3)' * ∂J_r2_fd(wrapper,r1,r2,r3)' *∂J_r3_fd(wrapper,r1,r2,r3)' * dy
        return (df,dx)
    end
    return y, f_pullback
  end
r1,r2,r3 = 1,2,3
Zygote.jacobian(X->wrapper(X...),[r1,r2,r3])

The pullback must return one tangent for each argument of your function (so dr1, dr2, dr3), plus one for the function itself (usually df = NoTangent()).
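For this MWE, a hand sketch of what a correct pullback would return: one tangent per scalar argument, each a Jacobian-column dot product (I'm showing the analytical columns in place of your finite differences):

```julia
# f(r1, r2, r3) = [-r1*r2, r2*r3]; one Jacobian column per argument.
col1(r1, r2, r3) = [-r2, 0.0]  # ∂f/∂r1
col2(r1, r2, r3) = [-r1, r3]   # ∂f/∂r2
col3(r1, r2, r3) = [0.0, r2]   # ∂f/∂r3

# The pullback takes dy and returns one scalar tangent per argument.
function pullback(r1, r2, r3, dy)
    dr1 = col1(r1, r2, r3)' * dy
    dr2 = col2(r1, r2, r3)' * dy
    dr3 = col3(r1, r2, r3)' * dy
    return (dr1, dr2, dr3)
end

tangents = pullback(1.0, 2.0, 3.0, [1.0, 0.0])  # first Jacobian row: (-2.0, -1.0, 0.0)
```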
At this point I think it would make sense for you to try reading the ChainRulesCore.jl documentation again, perhaps after watching my autodiff talk?

1 Like

Thank you for your patience. Your talk is excellent! I am revisiting the docs and getting expert input on my implementation here in public (so it may help other novice users with the same questions as me).

I do have other questions on managing complex outputs for Julia code and also some external code, but it makes more sense to drill into the documentation more, as you suggested.

1 Like