I’m trying to define an rrule for a functional, e.g. numerical integration of a given function, where the gradient is taken with respect to the parameters of the function. This works without problems with a simple implementation:
import Flux: gradient
dx = diff(0:0.01:1)
x = cumsum(dx)
# The functional to differentiate
integrate(f) = sum( @. f(x)*dx )
function main(p)
    f(x) = p[1]*x + p[2]
    return integrate(f)
end
p = [2, 1]
gradient(main, p)
returns ([0.505, 1.0],) as expected. The real implementation will try to minimize memory allocations with something like:
function integrate(f)
    return mapreduce(+, zip(x, dx)) do (x, dx)
        f(x)*dx
    end
end
but applying AD to that is not efficient:
julia> @btime main($p)
217.615 ns (4 allocations: 96 bytes)
julia> @btime gradient(main, $p)
8.656 ms (42274 allocations: 1.82 MiB)
so I’d like to define a custom rrule with ChainRules.jl (which is great, by the way). This works if I explicitly pass the parameters through all the functions:
import ChainRules: rrule, NoTangent  # NoTangent is used in the pullback below
function integrate(f, p)
    return mapreduce(+, zip(x, dx)) do (x, dx)
        f(x, p)*dx
    end
end
# Reverse rule for d/dp ∫f(x; p) dx
function rrule(::typeof(integrate), f, p)
    ∇f(x, p) = gradient(f, x, p)[2] # Gradient w.r.t. p
    I = integrate( f, p)
    ∇I = integrate(∇f, p)
    function int_pullback(Ī)
        # Only return differential w.r.t. p, since we integrate over x
        return NoTangent(), NoTangent(), Ī.*∇I
    end
    return I, int_pullback
end
function main(p)
    f(x, p) = p[1]*x + p[2]
    return integrate(f, p)
end
julia> @btime gradient(main, $p)
20.744 μs (729 allocations: 53.75 KiB)
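The rule above relies on differentiating under the sum: d/dp Σ f(x_i; p) Δx_i = Σ ∂f/∂p(x_i; p) Δx_i, so the pullback is just Ī times the integral of ∂f/∂p. A minimal check of that identity without any AD, assuming the same linear f as above and a hand-written ∂f/∂p (the name ∇f_manual and the finite-difference step h are illustrative choices, not part of the original code):

```julia
# Same grid as in the post
dx = diff(0:0.01:1)
x = cumsum(dx)

# Quadrature with explicit parameters, as in the post
function integrate(f, p)
    return mapreduce(+, zip(x, dx)) do (x, dx)
        f(x, p) * dx
    end
end

f(x, p) = p[1]*x + p[2]

# Assumption: gradient of this particular f w.r.t. p, written by hand
∇f_manual(x, p) = [x, one(x)]

p = [2.0, 1.0]

# Integral of ∂f/∂p, i.e. what the rrule computes as ∇I
∇I = integrate(∇f_manual, p)

# Central finite differences of the integral itself, for comparison
h = 1e-6
basis = ([1.0, 0.0], [0.0, 1.0])
fd = [(integrate(f, p .+ h .* e) - integrate(f, p .- h .* e)) / (2 * h) for e in basis]
```

Both ∇I and fd should agree with the ([0.505, 1.0],) result reported earlier, up to rounding.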
I’m not sure that’s the most efficient way to implement this, but it seems to work. The problem is that there will ultimately be a lot of different parameters in the program, and it would be nice not to have to keep track of all of them in an explicit parameter array. There may also be a way to handle mapreduce better, but ultimately I’ll be using more complicated quadrature, so it’d be nice to figure out how to handle the functional case more generally.
From reading the documentation and docstrings in ChainRulesCore.jl, it seems like the way to deal with this is by differentiating with respect to the internal fields (p) in f(x) = p[1]*x + p[2]. From the documentation:
From the mathematical perspective, one may have been wondering what all this Δself, ∂self is. Given that a function with two inputs, say f(a, b), only has two partial derivatives: ∂f/∂a, ∂f/∂b. Why then does a pushforward take in this extra Δself, and why does a pullback return this extra ∂self?
The reason is that in Julia the function f may itself have internal fields. For example a closure has the fields it closes over; a callable object (i.e. a functor) like a Flux.Dense has the fields of that object.
Thus every function is treated as having the extra implicit argument self, which captures those fields. So every pushforward takes in an extra argument, which is ignored unless the original function has fields. It is common to write function foo_pushforward(_, Δargs...) in the case when foo does not have fields. Similarly every pullback returns an extra ∂self, which for things without fields is NoTangent(), indicating there are no fields within the function itself.
There is something similar in the source. This seems like what I’m looking for, but I can’t find any examples or descriptions of how you’d implement this in your own rrule. Could anyone give (or point me to) an example of how this could be used in a custom rrule? Or frule, for that matter?
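To make the question concrete, here is a minimal sketch of how the ∂self idea might look when the parameters live in a callable struct rather than a closure. Everything named here (the Line struct, line_pullback, integrate_pullback) is hypothetical and only for illustration; it uses ChainRulesCore directly, handles this one struct type only, and does not cover arbitrary closures:

```julia
using ChainRulesCore

# Hypothetical functor: the parameters are fields of the callable object
struct Line{T}
    a::T
    b::T
end
(f::Line)(x) = f.a * x + f.b

# Rule for calling the functor itself. The first value returned by the
# pullback is ∂self: a structural Tangent over the fields of f.
function ChainRulesCore.rrule(f::Line, x)
    y = f(x)
    function line_pullback(ȳ)
        ∂self = Tangent{typeof(f)}(; a = ȳ * x, b = ȳ)
        ∂x = ȳ * f.a
        return ∂self, ∂x
    end
    return y, line_pullback
end

# Same grid as in the post
dx = diff(0:0.01:1)
x = cumsum(dx)

integrate(f) = sum(f(xi) * dxi for (xi, dxi) in zip(x, dx))

# Rule for the functional. Here f is an ordinary argument, so its
# cotangent is the accumulated ∂self of every call f(xi).
function ChainRulesCore.rrule(::typeof(integrate), f::Line)
    I = integrate(f)
    function integrate_pullback(Ī)
        ∂f = ZeroTangent()
        for (xi, dxi) in zip(x, dx)
            _, pb = rrule(f, xi)
            ∂f += pb(Ī * dxi)[1]  # keep only the ∂self part
        end
        # NoTangent() for integrate itself (it has no fields), then ∂f
        return NoTangent(), ∂f
    end
    return I, integrate_pullback
end

# Usage: call the rule by hand
f = Line(2.0, 1.0)
I, pb = rrule(integrate, f)
_, ∂f = pb(1.0)
```

With these inputs, ∂f.a and ∂f.b should come out close to the 0.505 and 1.0 found earlier. For an arbitrary f, ChainRulesCore's rrule_via_ad together with a RuleConfig is, as far as I understand, the mechanism for calling back into the AD system from inside a rule, but that is not shown here.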
Thanks for any suggestions.