ChainRulesCore.rrule for custom struct: does the pullback need to support Composite explicitly?

I’m trying to differentiate code involving GeometryBasics.Point construction (always with two arguments).
The following rule seems to allow Zygote to compute the gradient correctly:

using ChainRulesCore, GeometryBasics

# rrule for the two-argument Point constructor
function ChainRulesCore.rrule(::Type{Point}, x::Number, y::Number)
    p = Point(x,y)
    function Point_pullback(Δp)
        @show Δp typeof(Δp)
        NO_FIELDS, Δp[1], Δp[2]
    end
    p, Point_pullback
end

using Test, Zygote

# Euclidean norm of a 2-D point, written out by hand
n2(p) = √(p[1]^2+p[2]^2)
let (x,y) = rand(2)
    @test collect(gradient((x,y)->n2(Point(x,y)), x, y)) ≈ collect(gradient((x,y)->hypot(x,y), x, y))
end

This passes, and the @show output reveals that the input Δp is a StaticArrays.MArray, which seems reasonable.
But the same definition fails the check performed by ChainRulesTestUtils.test_rrule:

test_rrule(Point,1.0,2.0) 

which apparently tries to call it with a Composite input.
I can make it pass by adding a separate method for Composite to the pullback function, but it seems odd that I should do this for every custom rrule. Did I get something wrong here?


It looks like ChainRulesTestUtils.rand_tangent only generates array tangents for StridedArrays and falls back to Composite for all other AbstractArrays: ChainRulesTestUtils.jl/generate_tangent.jl at 340e21c3aa277f55609681337ad85702c8b6d302 · JuliaDiff/ChainRulesTestUtils.jl · GitHub

The easiest solution here would be to just overload rand_tangent for Point, but perhaps the default should be changed?

> which apparently tries to call it with a Composite input.

This is the automatic tangent generation kicking in.

You can teach ChainRulesTestUtils what the tangent type for your type is by overloading ChainRulesTestUtils.rand_tangent(::AbstractRNG, ::Point).

Or you can pass in the tangent manually by using the output_tangent=... keyword argument.

You just need to make that return a valid differential for your type.
Basically, that is a vector-space-like type that overloads +, zero, and scalar multiplication.
You can read more about differential types here and here, though I don't think either of those parts of our docs is fantastic.
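As a rough sketch of the rand_tangent overload suggested above: the method below returns a plain Vector, on the assumption that a Vector is an acceptable differential for a Point in these tests (it already supports +, zero, and scalar multiplication). It has not been checked against every ChainRulesTestUtils version, and since it extends a function from one package for a type from another, it probably belongs in test code only.

```julia
using ChainRulesTestUtils, GeometryBasics, Random

# Assumption: a Vector with one entry per coordinate is a valid
# "natural" differential for a Point.
function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, p::Point)
    return randn(rng, float(eltype(p)), length(p))
end

# Alternative mentioned in the thread: pass the tangent by hand, e.g.
# test_rrule(Point, 1.0, 2.0; output_tangent = [0.3, -0.7])
```

Whether test_rrule then passes also depends on how the finite-difference comparison handles Point, which this sketch does not address.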

Composite is the generic differential that is valid for all structs.

> The easiest solution here would be to just overload rand_tangent for Point, but perhaps the default should be changed?

What other default is possible for a struct we know nothing about?

I guess in this case we know it is a subtype of AbstractArray.
Maybe we should have a default for that, which is an Array?
That kind of thing has landed us in trouble before for structured sparse arrays like Diagonal.
But it might cause less trouble than erroring; I doubt the error in this case is very clear.


Yes, sorry, that’s what I meant.

So where did the MArray differential come from? Is this something Zygote is doing, rather than ChainRulesCore?
If I choose some differential type, I would expect to be able to declare it in one place so that it will be used both by Zygote and ChainRulesTestUtils.

> So where did the MArray differential come from? Is this something Zygote is doing, rather than ChainRulesCore?

This is a good question.
In this case there is a ZygoteRules.@adjoint returning it.
(ZygoteRules is the legacy system that ChainRulesCore replaces, but it is still used a lot, e.g. internally in Zygote.)
It says that the gradient of getindex on any AbstractArray is always something produced by the _zeros function.
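To illustrate the idea (this is a hand-written sketch, not Zygote's actual code; getindex_with_pullback is a hypothetical name), a pullback for getindex can allocate a zeroed container shaped like the input and write the incoming sensitivity into the indexed slot. The container type is then whatever similar returns for the input, which is likely how a static array such as Point ends up with a mutable MArray differential.

```julia
# Hypothetical helper: returns x[i] together with a pullback for it.
function getindex_with_pullback(x::AbstractArray, i::Integer)
    y = x[i]
    function getindex_pullback(Δy)
        Δx = similar(x, typeof(Δy))  # container type is chosen by `similar`
        fill!(Δx, zero(Δy))          # zero everywhere ...
        Δx[i] = Δy                   # ... except at the indexed position
        return Δx
    end
    return y, getindex_pullback
end

y, back = getindex_with_pullback([3.0, 4.0], 2)
Δx = back(1.0)
```

For a plain Vector input the differential is again a Vector; the point is only that the array-shaped differential originates in the getindex rule, not in the rrule for Point.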

It could be a good call, or it might not.
I stand by the uncertainty I expressed before:

> That kind of thing has landed us in trouble before for structured sparse arrays like Diagonal.
> But it might cause less trouble than erroring; I doubt the error in this case is very clear.

I have opened an issue to discuss this: Should we generate AbstractArray rand_tangent for all AbstractArray subtypes? · Issue #198 · JuliaDiff/ChainRulesTestUtils.jl · GitHub

Surprisingly tricky. This would be easy, and in fact probably required, in a static language.
Differential types, while conceptually paired with primal types, are in practice specified separately in every rule.

Because Julia functions don't declare return types (largely due to Julia's dynamic nature), we can't enforce this pairing,
even though the return type of the primal should determine the input type of the pullback.
This has some downsides: it is not a closed system in which you can be sure what your inputs are,
since anyone can add or change a rule at any time.
Mostly this is safe and things just work (all differentials should only need to support linear operations), but as you are seeing, this doesn't always work out.
We are working on it.
It is a challenge, but it also gives some opportunities.
See https://juliadiff.org/ChainRulesCore.jl/dev/design/many_differentials.html

In short: one would really like to write rules in terms of natural differential types, like arrays, but AD systems without the help of rules can only generate structural differentials like Composite.
We can support both, which lets people write rules with natural differential types while still letting the AD do its thing otherwise; but we can run into tricky cases, since the two can be mixed.
I am confident we will work that out fully and elegantly eventually, and in practice it mostly works right now.
I am sorry that you are running into a case where it doesn't.
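A small package-free sketch of the natural versus structural distinction, under stated assumptions: MyPoint and mypoint_with_pullback are hypothetical names, and a NamedTuple stands in for what a Composite would carry. The pullback gets one method per kind of differential, which is roughly the workaround described in the original question.

```julia
# Hypothetical stand-in for a 2-D point type with a single tuple field.
struct MyPoint
    data::NTuple{2,Float64}
end

# Constructor plus a hand-written pullback that accepts both kinds of input.
function mypoint_with_pullback(x::Float64, y::Float64)
    p = MyPoint((x, y))
    # "Natural" differential: indexable like an array.
    pullback(Δp::AbstractVector) = (Δp[1], Δp[2])
    # "Structural" differential: mirrors the struct's fields.
    pullback(Δp::NamedTuple{(:data,)}) = (Δp.data[1], Δp.data[2])
    return p, pullback
end

p, back = mypoint_with_pullback(1.0, 2.0)
natural = back([0.5, -1.0])
structural = back((data = (0.5, -1.0),))
```

Both calls yield the same pair of input sensitivities; the mixing problem mentioned above arises when a rule is written for only one of the two forms.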
