Implementing derivative in ChainRules for function with internal fields

I’m trying to define an rrule for a functional, e.g. numerical integration of a given function, where the gradient is taken with respect to the parameters of that function. This works with no problem in a simple implementation:

import Flux: gradient

dx = diff(0:0.01:1)
x = cumsum(dx)

# The functional to differentiate
integrate(f) = sum( @. f(x)*dx )

function main(p)
    f(x) = p[1]*x + p[2]
    return integrate(f)
end

p = [2, 1]
gradient(main, p)

returns ([0.505, 1.0],) as expected. The real implementation will try to minimize memory allocations with something like:

function integrate(f)
    return mapreduce(+, zip(x, dx)) do (x, dx)
        f(x)*dx
    end
end

but applying AD to that is not efficient:

julia> @btime main($p)
  217.615 ns (4 allocations: 96 bytes)

julia> @btime gradient(main, $p)
  8.656 ms (42274 allocations: 1.82 MiB)

so I’d like to define a custom rrule with ChainRules.jl (which is great, by the way). This works if I explicitly pass the parameters through all the functions:

import ChainRules: rrule

function integrate(f, p)
    return mapreduce(+, zip(x, dx)) do (x, dx)
        f(x, p)*dx
    end
end

# Reverse rule for d/dp ∫f(x; p) dx
function rrule(::typeof(integrate), f, p)
    ∇f(x, p) = gradient(f, x, p)[2]  # Gradient w.r.t. p
    I  = integrate( f, p)
    ∇I = integrate(∇f, p)
    function int_pullback(Ī)
        # Only return differential w.r.t. p, since we integrate over x
        return NoTangent(), NoTangent(), Ī.*∇I
    end
    return I, int_pullback
end

function main(p)
    f(x, p) = p[1]*x + p[2]
    return integrate(f, p)
end

julia> @btime gradient(main, $p)
  20.744 μs (729 allocations: 53.75 KiB)

I’m not sure that’s the most efficient way to implement this, but it seems to work. The problem is that there will ultimately be a lot of different parameters in the program, and it would be nice not to have to keep track of all of them in an explicit parameter array. There may also be a way to handle mapreduce better, but ultimately I’ll be using more complicated quadrature, so it would be nice to figure out how to handle the functional case more generally.

From reading the documentation and docstrings in ChainRulesCore.jl, it seems like the way to deal with this is by differentiating with respect to the internal fields (p) in f(x) = p[1]*x + p[2]. From the documentation:

From the mathematical perspective, one may have been wondering what all this Δself, ∂self is. Given that a function with two inputs, say f(a, b), only has two partial derivatives: ∂f/∂a, ∂f/∂b. Why then does a pushforward take in this extra Δself, and why does a pullback return this extra ∂self?

The reason is that in Julia the function f may itself have internal fields. For example a closure has the fields it closes over; a callable object (i.e. a functor) like a Flux.Dense has the fields of that object.

Thus every function is treated as having the extra implicit argument self, which captures those fields. So every pushforward takes in an extra argument, which is ignored unless the original function has fields. It is common to write function foo_pushforward(_, Δargs...) in the case when foo does not have fields. Similarly every pullback returns an extra ∂self, which for things without fields is NoTangent(), indicating there are no fields within the function itself.

There is something similar in the source code. This seems like what I’m looking for, but I can’t find any examples or descriptions of how you’d implement this in your own rrule. Could anyone give (or point me to) an example of how this could be used in a custom rrule? Or an frule, for that matter?
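To make the question concrete, here is roughly the shape I imagine for the closure case, reusing integrate, x and dx from above. This is untested guesswork: the Tangent{typeof(f)}(p = ...) construction, the hard-coded field name p, and the nested gradient call are all my assumptions from reading the docstrings.

import ChainRulesCore: rrule, NoTangent, Tangent
import Flux: gradient

# Untested sketch: f(x) = p[1]*x + p[2] is a closure, so typeof(f) has a
# field `p` holding the captured parameters. The idea is to report dI/dp
# through a structural Tangent in the ∂f slot of the pullback, instead of
# threading p through as an explicit argument.
function rrule(::typeof(integrate), f)
    I = integrate(f)
    function int_pullback(Ī)
        # Gradient of f(xi) w.r.t. f's captured field p, integrated over x.
        # Zygote's gradient(g -> g(xi), f)[1] returns a NamedTuple such as
        # (p = [xi, 1.0],) for this closure; `.p` assumes that field name.
        ∇I = integrate(xi -> gradient(g -> g(xi), f)[1].p)
        ∂f = Tangent{typeof(f)}(p = Ī .* ∇I)
        return NoTangent(), ∂f
    end
    return I, int_pullback
end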

Thanks for any suggestions.

Here is the rule for sum
https://github.com/JuliaDiff/ChainRules.jl/blob/26e46081e5210f4a81f41c9ab61ea466ca8dc006/src/rulesets/Base/mapreduce.jl#L66-L89
It doesn’t have a ∂self for sum itself, but it does handle f being a closure/functor (it is even a little bit smart about it, so it can avoid computing that tangent when f is not a closure).

It also features calling back into AD, which may or may not be useful to you.

That’s great, thank you. Looks like calling back into AD is what I’m after.
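For anyone following along, here is roughly what I have in mind, based on the RuleConfig section of the ChainRulesCore docs. It is a rough, untested sketch: it stores one pullback per quadrature node, so it trades memory for generality, and the names evals and int_pullback are mine.

import ChainRulesCore: rrule, RuleConfig, HasReverseMode, rrule_via_ad,
                       NoTangent, ZeroTangent, unthunk

# Untested sketch: ask the calling AD system (via the RuleConfig mechanism)
# for a value and pullback of f at each node, then accumulate the ∂self
# pieces, which carry the gradient w.r.t. f's internal fields.
function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(integrate), f)
    evals = map(xi -> rrule_via_ad(config, f, xi), x)  # (value, pullback) pairs
    I = sum(first(e)*dxi for (e, dxi) in zip(evals, dx))
    function int_pullback(Ī)
        ∂f = ZeroTangent()
        for (e, dxi) in zip(evals, dx)
            ∂self, _ = e[2](Ī*dxi)  # first slot of the pullback is f's fields
            ∂f += unthunk(∂self)
        end
        return NoTangent(), ∂f
    end
    return I, int_pullback
end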

The generator form of sum already avoids memory allocations:

julia> dx = diff(0:0.0001:1);

julia> x = cumsum(dx);

julia> integrate(f) = sum(@. f(x)*dx)
integrate (generic function with 1 method)

julia> @btime main($p)
  12.583 μs (8 allocations: 78.34 KiB)
2.0000999999999998

julia> integrate(f) = sum(f(x)*dx for (x,dx) in zip(x,dx))
integrate (generic function with 1 method)

julia> @btime main($p)
  10.041 μs (5 allocations: 128 bytes)
2.000099999999998

julia> function integrate(f)
           return mapreduce(+, zip(x, dx)) do (x, dx)
               f(x)*dx
           end
       end
integrate (generic function with 1 method)

julia> @btime main($p)
  10.125 μs (5 allocations: 128 bytes)
2.000099999999998

I’m not familiar with how that affects the gradient, though. Just FYI.

I’m looking for any example implementation of an rrule where ∂self is different from NoTangent(), to see how internal fields are handled; I can’t find one in the ChainRules docs either. For example, how would I directly implement an rrule for Dense(2,3)?

Here is one

It’s not really much different from the function taking it as an argument in any other position.
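For instance, here is a hand-written sketch for a simplified Dense-like functor (a made-up MyDense with no activation, not the real Flux/NNlib rule):

import ChainRulesCore: rrule, Tangent

struct MyDense{M,V}
    W::M
    b::V
end
(d::MyDense)(x) = d.W * x .+ d.b

# The functor itself sits in the first argument slot of rrule, so its
# fields come back as a structural Tangent in the ∂self position.
function rrule(d::MyDense, x)
    y = d(x)
    function dense_pullback(ȳ)
        ∂d = Tangent{typeof(d)}(W = ȳ * x', b = ȳ)  # tangent for d's fields
        ∂x = d.W' * ȳ
        return ∂d, ∂x
    end
    return y, dense_pullback
end

The only real difference from an ordinary argument is that the functor’s contribution comes back as that ∂self Tangent rather than in one of the argument slots.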

If you could PR the docs once you understand it, that would be appreciated.
