How to write a generic rrule?

Consider the following example:

function ancestor_accumulate(x::AbstractMatrix, ancestors::AbstractVector{Int})
    @assert size(x, 2) == length(ancestors)
    @assert issorted(ancestors) && all(a .< 1:length(ancestors))
    X = zero(x)
    for (i, a) in enumerate(ancestors)
        if iszero(a)
            X[:,i] .= x[:,i]
        else
            X[:,i] .= x[:,i] .+ X[:,a]
        end
    end
    return X
end

function ChainRulesCore.rrule(::typeof(ancestor_accumulate), x::AbstractMatrix, ancestors::AbstractVector{Int})
    @assert size(x, 2) == length(ancestors)
    function ancestor_accumulate_pullback(dX)
        dx = zero(dX)
        for (i, a) in Iterators.reverse(enumerate(ancestors))
            dx[:,i] .+= dX[:,i]
            if a > 0
                dx[:,a] .+= dx[:,i]
            end
        end
        return NoTangent(), dx, NoTangent()
    end
    return ancestor_accumulate(x, ancestors), ancestor_accumulate_pullback
end

By comparing to FiniteDifferences I am more or less sure that this gradient rule is correct. However, if I do:

ChainRulesTestUtils.test_rrule(ancestor_accumulate, x, a)

It complaints about some type support:

test_rrule: ancestor_accumulate on Matrix{Float64},Vector{Int64}: Error During Test at /home/cossio/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:191
  Got exception outside of a @test
  MethodError: no method matching getindex(::ZeroTangent, ::Colon, ::Int64)
  Closest candidates are:
    getindex(::AbstractZero, ::Any) at /home/cossio/.julia/packages/ChainRulesCore/8vlYQ/src/tangent_types/abstract_zero.jl:33
  Stacktrace:
    [1] maybeview
      @ ./views.jl:132 [inlined]
    [2] dotview
      @ ./broadcast.jl:1212 [inlined]
    [3] (::var"#ancestor_accumulate_pullback#44"{Vector{Int64}})(dX::Thunk{ChainRulesTestUtils.var"#50#54"{Matrix{Float64}}})
      @ Main ./In[90]:6
    [4] macro expansion
      @ ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:228 [inlined]
    [5] macro expansion
      @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
    [6] test_rrule(::RuleConfig, ::Any, ::Any, ::Vararg{Any, N} where N; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:194
    [7] test_rrule
      @ ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:186 [inlined]
    [8] #test_rrule#47
      @ ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:168 [inlined]
    [9] test_rrule(::Any, ::Any, ::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:167
   [10] top-level scope
      @ In[91]:1
   [11] eval
      @ ./boot.jl:360 [inlined]
   [12] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
      @ Base ./loading.jl:1116
   [13] softscope_include_string(m::Module, code::String, filename::String)
      @ SoftGlobalScope ~/.julia/packages/SoftGlobalScope/u4UzH/src/SoftGlobalScope.jl:65
   [14] execute_request(socket::ZMQ.Socket, msg::IJulia.Msg)
      @ IJulia ~/.julia/packages/IJulia/e8kqU/src/execute_request.jl:67
   [15] #invokelatest#2
      @ ./essentials.jl:708 [inlined]
   [16] invokelatest
      @ ./essentials.jl:706 [inlined]
   [17] eventloop(socket::ZMQ.Socket)
      @ IJulia ~/.julia/packages/IJulia/e8kqU/src/eventloop.jl:8
   [18] (::IJulia.var"#15#18")()
      @ IJulia ./task.jl:411
Test Summary:                                                    | Pass  Error  Total
test_rrule: ancestor_accumulate on Matrix{Float64},Vector{Int64} |    7      1      8

So the way I wrote the rrule above is not generic enough? What should I do here?

The first thing that comes to mind is that you need to explicitly call unthunk(dX) within the pullback. and then only use the unthunked cotangent.

If you provide example inputs for the function, I could give more useful advice. Also, the 3rd line of ancestor_accumulate includes a variable a that is not defined in your example, so does this function even work as written?