Hi!
I am trying to write the correct rrule
for StructArray
. However I struggle quite a lot with some of the errors returned by test_rrule
.
Here is what I got so far:
ProjectTo(::StructArray{T}) where {T} = ProjectTo{StructArray{T}}()
function (project::ProjectTo{StructArray{T}})(dx::AbstractArray{Y}) where {T,Y<:Union{T,Tangent{T}}}
fields = fieldnames(T)
components = ntuple(length(fields)) do i
getfield.(dx, fields[i])
end
StructArray{T}(components)
end
(project::ProjectTo{StructArray{T}})(dx::StructArray{Y}) where {T,Y<:Union{T,Tangent{T}}} = dx
function rrule(T::Type{StructArray}, x::Union{Tuple,NamedTuple})
y = T(x)
proj = ProjectTo(y)
function StructArray_rrule(Δ::AbstractArray)
return NoTangent(), proj(ChainRulesCore.backing.(Δ))
end
function StructArray_rrule(Δ::Tangent)
return NoTangent(), StructArray(Δ.components)
end
return y, StructArray_rrule
end
Which is failing for this example
using StructArrays, ChainRulesTestUtils
a = randn(5)
b = rand(5)
test_rrule(StructArray, (a, b))
which returns
Got exception outside of a @test
AssertionError: T <: NamedTuple
Stacktrace:
[1] test_approx(actual::Tangent{Tuple{Vector{Float64}, Vector{Float64}}, Tuple{Vector{Float64}, Vector{Float64}}}, expected::Any, msg::Any; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/check_result.jl:128
[2] test_approx(x::Any, y::Tangent, msg::Any; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/check_result.jl:140
[3] _test_cotangent(accum_cotangent::Any, ad_cotangent::Any, fd_cotangent::Any; check_inferred::Any, kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:299