Consider the following example:
function ancestor_accumulate(x::AbstractMatrix, ancestors::AbstractVector{Int})
@assert size(x, 2) == length(ancestors)
@assert issorted(ancestors) && all(a .< 1:length(ancestors))
X = zero(x)
for (i, a) in enumerate(ancestors)
if iszero(a)
X[:,i] .= x[:,i]
else
X[:,i] .= x[:,i] .+ X[:,a]
end
end
return X
end
function ChainRulesCore.rrule(::typeof(ancestor_accumulate), x::AbstractMatrix, ancestors::AbstractVector{Int})
@assert size(x, 2) == length(ancestors)
function ancestor_accumulate_pullback(dX)
dx = zero(dX)
for (i, a) in Iterators.reverse(enumerate(ancestors))
dx[:,i] .+= dX[:,i]
if a > 0
dx[:,a] .+= dx[:,i]
end
end
return NoTangent(), dx, NoTangent()
end
return ancestor_accumulate(x, ancestors), ancestor_accumulate_pullback
end
By comparing to FiniteDifferences I am more or less sure that this gradient rule is correct. However, if I do:
ChainRulesTestUtils.test_rrule(ancestor_accumulate, x, a)
It complaints about some type support:
test_rrule: ancestor_accumulate on Matrix{Float64},Vector{Int64}: Error During Test at /home/cossio/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:191
Got exception outside of a @test
MethodError: no method matching getindex(::ZeroTangent, ::Colon, ::Int64)
Closest candidates are:
getindex(::AbstractZero, ::Any) at /home/cossio/.julia/packages/ChainRulesCore/8vlYQ/src/tangent_types/abstract_zero.jl:33
Stacktrace:
[1] maybeview
@ ./views.jl:132 [inlined]
[2] dotview
@ ./broadcast.jl:1212 [inlined]
[3] (::var"#ancestor_accumulate_pullback#44"{Vector{Int64}})(dX::Thunk{ChainRulesTestUtils.var"#50#54"{Matrix{Float64}}})
@ Main ./In[90]:6
[4] macro expansion
@ ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:228 [inlined]
[5] macro expansion
@ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
[6] test_rrule(::RuleConfig, ::Any, ::Any, ::Vararg{Any, N} where N; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:194
[7] test_rrule
@ ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:186 [inlined]
[8] #test_rrule#47
@ ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:168 [inlined]
[9] test_rrule(::Any, ::Any, ::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:167
[10] top-level scope
@ In[91]:1
[11] eval
@ ./boot.jl:360 [inlined]
[12] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
@ Base ./loading.jl:1116
[13] softscope_include_string(m::Module, code::String, filename::String)
@ SoftGlobalScope ~/.julia/packages/SoftGlobalScope/u4UzH/src/SoftGlobalScope.jl:65
[14] execute_request(socket::ZMQ.Socket, msg::IJulia.Msg)
@ IJulia ~/.julia/packages/IJulia/e8kqU/src/execute_request.jl:67
[15] #invokelatest#2
@ ./essentials.jl:708 [inlined]
[16] invokelatest
@ ./essentials.jl:706 [inlined]
[17] eventloop(socket::ZMQ.Socket)
@ IJulia ~/.julia/packages/IJulia/e8kqU/src/eventloop.jl:8
[18] (::IJulia.var"#15#18")()
@ IJulia ./task.jl:411
Test Summary: | Pass Error Total
test_rrule: ancestor_accumulate on Matrix{Float64},Vector{Int64} | 7 1 8
So the way I wrote the rrule above is not generic enough? What should I do here?