I’m trying to define an `rrule` for a functional (e.g. numerical integration of a given function), where the gradient is taken with respect to the parameters of that function. This works without problems in a simple implementation:

```
import Flux: gradient
dx = diff(0:0.01:1)
x = cumsum(dx)
# The functional to differentiate
integrate(f) = sum( @. f(x)*dx )
function main(p)
    f(x) = p[1]*x + p[2]
    return integrate(f)
end
p = [2, 1]
gradient(main, p)
```

returns `([0.505, 1.0],)` as expected. The real implementation will try to minimize memory allocations with something like:

```
function integrate(f)
    return mapreduce(+, zip(x, dx)) do (x, dx)
        f(x)*dx
    end
end
```

But applying AD to the `mapreduce` version is not efficient:

```
julia> @btime main($p)
217.615 ns (4 allocations: 96 bytes)
julia> @btime gradient(main, $p)
8.656 ms (42274 allocations: 1.82 MiB)
```

So I’d like to define a custom `rrule` with ChainRules.jl (which is great, by the way). This works if I explicitly pass the parameters through all the functions:

```
# rrule is imported so it can be extended; NoTangent is used in the pullback
import ChainRulesCore: rrule
using ChainRulesCore: NoTangent
function integrate(f, p)
    return mapreduce(+, zip(x, dx)) do (x, dx)
        f(x, p)*dx
    end
end
# Reverse rule for d/dp ∫f(x; p) dx
function rrule(::typeof(integrate), f, p)
    ∇f(x, p) = gradient(f, x, p)[2] # Gradient w.r.t. p
    I = integrate(f, p)
    ∇I = integrate(∇f, p)
    function int_pullback(Ī)
        # Only return a tangent w.r.t. p, since we integrate over x
        return NoTangent(), NoTangent(), Ī.*∇I
    end
    return I, int_pullback
end
function main(p)
    f(x, p) = p[1]*x + p[2]
    return integrate(f, p)
end
julia> @btime gradient(main, $p)
20.744 μs (729 allocations: 53.75 KiB)
```

I’m not sure that’s the most efficient way to implement this, but it seems to work. The problem is that there will ultimately be a lot of different parameters in the program, and it would be nice not to have to keep track of all of them in an explicit parameter array. There may also be a better way to handle `mapreduce`, but ultimately I’ll be using more complicated quadrature, so it’d be nice to figure out how to handle the functional case more generally.
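For a general quadrature rule with nodes $x_i$ and weights $w_i$ (the Riemann sum above is the special case $w_i = \Delta x_i$), and assuming the nodes and weights do not depend on $p$, the parameter gradient is itself a quadrature of the pointwise gradient. This is the quantity the explicit-parameter `rrule` above builds with `∇f`, and the pullback just scales it by the incoming cotangent $\bar I$:

$$
I(p) = \sum_i w_i\, f(x_i; p)
\quad\Longrightarrow\quad
\nabla_p I(p) = \sum_i w_i\, \nabla_p f(x_i; p),
\qquad
\bar p = \bar I \, \nabla_p I(p).
$$

For the example $f(x; p) = p_1 x + p_2$ with $w_i = 0.01$ and $x_i = 0.01\,i$, this gives $\nabla_p I = \left(\sum_i 0.01\,x_i,\ \sum_i 0.01\right) = (0.505,\ 1.0)$.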

From reading the documentation and docstrings in ChainRulesCore.jl, it seems like the way to deal with this is to differentiate with respect to the internal fields (`p`) of `f(x) = p[1]*x + p[2]`. From the documentation:

> From the mathematical perspective, one may have been wondering what all this `Δself`, `∂self` is. Given that a function with two inputs, say `f(a, b)`, only has two partial derivatives: `∂f/∂a`, `∂f/∂b`. Why then does a `pushforward` take in this extra `Δself`, and why does a `pullback` return this extra `∂self`?
>
> The reason is that in Julia the function `f` may itself have internal fields. For example a closure has the fields it closes over; a callable object (i.e. a functor) like a `Flux.Dense` has the fields of that object.
>
> Thus every function is treated as having the extra implicit argument `self`, which captures those fields. So every `pushforward` takes in an extra argument, which is ignored unless the original function has fields. It is common to write `function foo_pushforward(_, Δargs...)` in the case when `foo` does not have fields. Similarly every `pullback` returns an extra `∂self`, which for things without fields is `NoTangent()`, indicating there are no fields within the function itself.

There is something similar in the source. This seems like what I’m looking for, but I can’t find any examples or descriptions of how you’d implement this in your own `rrule`. Could anyone give (or point me to) an example of how this could be used in a custom `rrule`? Or `frule`, for that matter?
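As a hedged sketch of how the field tangents described in the quoted documentation can appear in custom rules, the block replaces the closure with a hypothetical callable struct `Affine` (a closure is essentially an anonymous struct whose fields are its captured variables). Because `f` is an *argument* of `integrate`, the tangent for its field `p` is returned as the second output of the pullback (and received as the second entry of the `frule` tangent tuple), in the form of a `Tangent{typeof(f)}` with a `p` field. The `∂self` of `integrate` itself is `NoTangent()` because `integrate` has no fields. The derivative with respect to `p` is hand-derived in the hypothetical helper `integrate_dp`, so this only illustrates the tangent plumbing, not a general way to differentiate an arbitrary `f`.

```julia
using ChainRulesCore

# Hypothetical callable struct standing in for the closure
# `f(x) = p[1]*x + p[2]`: the parameters live in the field `p`,
# just as a closure stores the variables it captures.
struct Affine{T<:AbstractVector}
    p::T
end
(f::Affine)(x) = f.p[1] * x + f.p[2]

# Same grid as in the question.
const dx = diff(0:0.01:1)
const x = cumsum(dx)

# Riemann sum of f over the grid, without a temporary array.
integrate(f) = sum(f(xi) * dxi for (xi, dxi) in zip(x, dx))

# Hypothetical helper, specific to Affine: ∂f/∂p[1] = x and ∂f/∂p[2] = 1,
# each integrated over the grid, gives ∂I/∂p.
integrate_dp(::Affine) = [sum(x .* dx), sum(dx)]

# Reverse rule: the tangent for the argument `f` mirrors its fields,
# so it is a Tangent with a single field `p`.
function ChainRulesCore.rrule(::typeof(integrate), f::Affine)
    I = integrate(f)
    ∂p = integrate_dp(f)
    function integrate_pullback(Ī)
        ∂f = Tangent{typeof(f)}(; p = unthunk(Ī) .* ∂p)
        return NoTangent(), ∂f   # (tangent for integrate, tangent for f)
    end
    return I, integrate_pullback
end

# Forward rule: the input tangent for `f` carries a tangent for its
# field `p`; the output tangent is the dot product ∂I/∂p ⋅ Δf.p.
function ChainRulesCore.frule((Δintegrate, Δf), ::typeof(integrate), f::Affine)
    I = integrate(f)
    ΔI = sum(integrate_dp(f) .* Δf.p)
    return I, ΔI
end
```

Calling the rules directly with `f = Affine([2.0, 1.0])`, the pullback applied to `1.0` should return a `p` tangent of approximately `[0.505, 1.0]`, matching the gradient from the explicit-parameter version.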

Thanks for any suggestions.