# Implementing derivative in ChainRules for function with internal fields

I’m trying to define an `rrule` for a functional (e.g. numerical integration of a given function), where the gradient is taken with respect to the function’s parameters. This works fine with a simple implementation:

``````
import Flux: gradient

dx = diff(0:0.01:1)
x = cumsum(dx)

# The functional to differentiate
integrate(f) = sum( @. f(x)*dx )

function main(p)
    f(x) = p[1]*x + p[2]
    return integrate(f)
end

p = [2, 1]
gradient(main, p)
``````

returns `([0.505, 1.0],)` as expected. The real implementation will try to minimize memory allocations with something like:

``````
function integrate(f)
    return mapreduce(+, zip(x, dx)) do (x, dx)
        f(x)*dx
    end
end
``````

but applying AD to that is not efficient:

``````
julia> @btime main($p)
  217.615 ns (4 allocations: 96 bytes)

julia> @btime gradient(main, $p)
  8.656 ms (42274 allocations: 1.82 MiB)
``````

so I’d like to define a custom `rrule` with ChainRules.jl (which is great, by the way). This works if I explicitly pass the parameters through all the functions:

``````
import ChainRules: rrule, NoTangent

function integrate(f, p)
    return mapreduce(+, zip(x, dx)) do (x, dx)
        f(x, p)*dx
    end
end

# Gradient of f with respect to p, for f(x, p) = p[1]*x + p[2]
∇f(x, p) = [x, one(x)]

# Reverse rule for d/dp ∫f(x; p) dx
function rrule(::typeof(integrate), f, p)
    I  = integrate(f, p)
    ∇I = integrate(∇f, p)
    function int_pullback(Ī)
        # Only return differential w.r.t. p, since we integrate over x
        return NoTangent(), NoTangent(), Ī.*∇I
    end
    return I, int_pullback
end

function main(p)
    f(x, p) = p[1]*x + p[2]
    return integrate(f, p)
end

# @btime gradient(main, $p):
#   20.744 μs (729 allocations: 53.75 KiB)
``````

I’m not sure that’s the most efficient way to implement this, but it seems to work. The problem is that there will ultimately be a lot of different parameters in the program and it would be nice to not have to keep track of all of them in an explicit parameter array. There may also be a way to handle `mapreduce` better, but ultimately I’ll be using more complicated quadrature, so it’d be nice to figure out how to handle the functional case more generally.

From reading the documentation and docstrings in ChainRulesCore.jl, it seems like the way to deal with this is by differentiating with respect to the internal fields (`p`) in `f(x) = p*x + p`. From the documentation:

> From the mathematical perspective, one may have been wondering what all this `Δself`, `∂self` is. Given that a function with two inputs, say `f(a, b)`, only has two partial derivatives: `∂f/∂a`, `∂f/∂b`. Why then does a `pushforward` take in this extra `Δself`, and why does a `pullback` return this extra `∂self`?
>
> The reason is that in Julia the function `f` may itself have internal fields. For example a closure has the fields it closes over; a callable object (i.e. a functor) like a `Flux.Dense` has the fields of that object.
>
> Thus every function is treated as having the extra implicit argument `self`, which captures those fields. So every `pushforward` takes in an extra argument, which is ignored unless the original function has fields. It is common to write `function foo_pushforward(_, Δargs...)` in the case when `foo` does not have fields. Similarly every `pullback` returns an extra `∂self`, which for things without fields is `NoTangent()`, indicating there are no fields within the function itself.

There is something similar in the source. This seems like what I’m looking for, but I can’t find any examples or descriptions of how you’d implement this in your own `rrule`. Could anyone give (or point me to) an example of how this could be used in a custom `rrule`? Or `frule`, for that matter?
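To make the question concrete, here is roughly the shape I imagine, with the parameters living inside the closure rather than passed explicitly. This is only a sketch, assuming `f(x) = p[1]*x + p[2]` (which matches the gradient `([0.505, 1.0],)` above, since `∂f/∂p = [x, 1]`); the `Tangent{typeof(f)}(p = ...)` construction is exactly the part I'm unsure about:

``````
import ChainRules: rrule, NoTangent, Tangent

# Sketch: return a structural tangent for the closure's captured field `p`
# instead of threading `p` through as an explicit argument.
function rrule(::typeof(integrate), f)
    I = integrate(f)
    function int_pullback(Ī)
        # d(∫f dx)/dp = ∫ ∂f/∂p dx; here ∂f/∂p = [x, 1] for f(x) = p[1]*x + p[2]
        ∇I = integrate(x -> [x, one(x)])
        # First NoTangent(): `integrate` itself has no fields.
        # Second element: tangent for f, whose only field is the captured `p`.
        return NoTangent(), Tangent{typeof(f)}(p = Ī .* ∇I)
    end
    return I, int_pullback
end
``````

This hardcodes `∂f/∂p` for one particular `f`, which is why I'd like to know the general mechanism.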

Thanks for any suggestions.

Here is the rule for `sum`

It doesn’t have a `∂self` for `sum`, but it does handle `f` being a closure/functor via `f̄` (it is even a little bit smart about it, so it can avoid computing `f̄` when `f` is not a closure).

It also features calling back into AD, which may or may not be useful to you

That’s great, thank you. Looks like calling back into AD is what I’m after.
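For anyone finding this later, my reading of the call-back-into-AD machinery is something like the following. This is only a sketch: `RuleConfig`, `HasReverseMode`, and `rrule_via_ad` come from ChainRulesCore, and I may well have the accumulation details wrong:

``````
using ChainRulesCore

# Sketch: let whatever AD backend is driving this rule differentiate f at
# each quadrature node, and accumulate the tangent for f's internal fields.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode},
                              ::typeof(integrate), f)
    I = integrate(f)
    function int_pullback(Ī)
        f̄ = sum(zip(x, dx)) do (xi, dxi)
            _, fx_pullback = rrule_via_ad(config, f, xi)
            ∂f, _∂x = fx_pullback(Ī * dxi)  # ∂f is the ∂self for the closure f
            ∂f
        end
        return NoTangent(), f̄
    end
    return I, int_pullback
end
``````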

The generator form of sum will avoid memory allocations already:

``````
julia> dx = diff(0:0.0001:1);

julia> x = cumsum(dx);

julia> integrate(f) = sum(@. f(x)*dx)
integrate (generic function with 1 method)

julia> @btime main($p)
  12.583 μs (8 allocations: 78.34 KiB)
2.0000999999999998

julia> integrate(f) = sum(f(x)*dx for (x,dx) in zip(x,dx))
integrate (generic function with 1 method)

julia> @btime main($p)
  10.041 μs (5 allocations: 128 bytes)
2.000099999999998

julia> function integrate(f)
           return mapreduce(+, zip(x, dx)) do (x, dx)
               f(x)*dx
           end
       end
integrate (generic function with 1 method)

julia> @btime main($p)
  10.125 μs (5 allocations: 128 bytes)
2.000099999999998
``````

I’m not familiar with how that affects the gradient, though. Just FYI.