I’m trying to define an `rrule` for a functional, e.g. numerical integration of a given function, where the gradient is taken with respect to the parameters of the function. This works without problems with a simple implementation:
```julia
import Flux: gradient

dx = diff(0:0.01:1)
x = cumsum(dx)

# The functional to differentiate
integrate(f) = sum( @. f(x)*dx )

function main(p)
    # Linear integrand with slope p[1] and offset p[2]
    f(x) = p[1]*x + p[2]
    return integrate(f)
end

p = [2, 1]
gradient(main, p)
```
This returns `([0.505, 1.0],)`, as expected. The real implementation will try to minimize memory allocations with something like:
```julia
function integrate(f)
    return mapreduce(+, zip(x, dx)) do (x, dx)
        f(x)*dx
    end
end
```
but applying AD to that is not efficient:
```julia
julia> @btime main($p)
  217.615 ns (4 allocations: 96 bytes)

julia> @btime gradient(main, $p)
  8.656 ms (42274 allocations: 1.82 MiB)
```
so I’d like to define a custom `rrule` with ChainRules.jl (which is great, by the way). This works if I explicitly pass the parameters through all the functions:
```julia
import ChainRules: rrule, NoTangent

function integrate(f, p)
    return mapreduce(+, zip(x, dx)) do (x, dx)
        f(x, p)*dx
    end
end

# Reverse rule for d/dp ∫f(x; p) dx
function rrule(::typeof(integrate), f, p)
    ∇f(x, p) = gradient(f, x, p)[2]  # Gradient w.r.t. p
    I = integrate(f, p)
    ∇I = integrate(∇f, p)
    function int_pullback(Ī)
        # Only return differential w.r.t. p, since we integrate over x
        return NoTangent(), NoTangent(), Ī.*∇I
    end
    return I, int_pullback
end

function main(p)
    f(x, p) = p[1]*x + p[2]
    return integrate(f, p)
end
```
```julia
julia> @btime gradient(main, $p)
  20.744 μs (729 allocations: 53.75 KiB)
```
I’m not sure that’s the most efficient way to implement this, but it seems to work. The problem is that there will ultimately be a lot of different parameters in the program, and it would be nice not to have to keep track of all of them in an explicit parameter array. There may also be a way to handle `mapreduce` better, but ultimately I’ll be using more complicated quadrature, so it’d be nice to figure out how to handle the functional case more generally.
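For reference, this is the identity the rule above relies on, written for the discrete sum. It assumes that `f` is differentiable in `p` and that the nodes $x_i$ and weights $\Delta x_i$ do not depend on `p`:

```latex
I(p) = \sum_i f(x_i; p)\,\Delta x_i
\quad\Longrightarrow\quad
\nabla_p I = \sum_i \nabla_p f(x_i; p)\,\Delta x_i ,
\qquad
\bar{p} = \bar{I}\,\nabla_p I .

% Worked case for f(x; p) = p_1 x + p_2 with x_i = 0.01\,i,\ \Delta x_i = 0.01,\ i = 1,\dots,100:
\frac{\partial I}{\partial p_1} = \sum_i x_i\,\Delta x_i = 0.01 \cdot 0.01 \cdot 5050 = 0.505,
\qquad
\frac{\partial I}{\partial p_2} = \sum_i \Delta x_i = 1 .
```

This matches the `([0.505, 1.0],)` result from the simple implementation.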
From reading the documentation and docstrings in ChainRulesCore.jl, it seems like the way to deal with this is by differentiating with respect to the internal fields (here `p`) of the closure `f(x) = p[1]*x + p[2]`. From the documentation:
> From the mathematical perspective, one may have been wondering what all this `∂self` is. Given that a function with two inputs, say `f(a, b)`, only has two partial derivatives: `∂f/∂a`, `∂f/∂b`. Why then does a `pushforward` take in this extra `Δself`, and why does a `pullback` return this extra `∂self`?
>
> The reason is that in Julia the function `f` may itself have internal fields. For example a closure has the fields it closes over; a callable object (i.e. a functor) like a `Flux.Dense` has the fields of that object.
>
> Thus every function is treated as having the extra implicit argument `self`, which captures those fields. So every `pushforward` takes in an extra argument, which is ignored unless the original function has fields. It is common to write `function foo_pushforward(_, Δargs...)` in the case when `foo` does not have fields. Similarly every `pullback` returns an extra `∂self`, which for things without fields is `NoTangent()`, indicating there are no fields within the function itself.
There is something similar in the source. This seems like what I’m looking for, but I can’t find any examples or descriptions of how you’d implement this in your own `rrule`. Could anyone give (or point me to) an example of how this could be used in a custom `rrule`, or `frule`, for that matter?
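For concreteness, here is a rough sketch of what I imagine such a rule might look like, based only on my reading of the passage above. Everything specific in it is an assumption for illustration: the callable struct `Line` stands in for the closure, `xs` and `dxs` are renamed grid variables, and the gradient is hand-derived for this linear integrand rather than computed by AD. It only calls `rrule` directly and says nothing about how Zygote would treat a real closure.

```julia
using ChainRulesCore

# Hypothetical callable struct standing in for the closure f(x) = p[1]*x + p[2]
struct Line
    p::Vector{Float64}
end
(f::Line)(x) = f.p[1]*x + f.p[2]

# Quadrature nodes and weights (hypothetical names)
dxs = diff(0:0.01:1)
xs = cumsum(dxs)

integrate(f) = mapreduce(((x, dx),) -> f(x)*dx, +, zip(xs, dxs))

function ChainRulesCore.rrule(::typeof(integrate), f::Line)
    I = integrate(f)
    # Hand-derived for this linear integrand: ∂I/∂p[1] = Σ x*dx, ∂I/∂p[2] = Σ dx
    ∇I = [sum(xs .* dxs), sum(dxs)]
    function integrate_pullback(Ī)
        # Structural tangent for the fields of `f`, instead of NoTangent()
        ∂f = Tangent{Line}(p = Ī .* ∇I)
        return NoTangent(), ∂f
    end
    return I, integrate_pullback
end

# Call the rule directly, without going through an AD system
I, back = rrule(integrate, Line([2.0, 1.0]))
∂integrate, ∂f = back(1.0)
∂f.p  # tangent with respect to the field `p`
```

The point of the sketch is only that the pullback returns a `Tangent` carrying the fields of `f` in the slot where my current rule returns `NoTangent()`; whether that is the intended usage for closures is exactly what I’m unsure about.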
Thanks for any suggestions.