`rrule` for `StructArrays`

Hi!

I am trying to write the correct rrule for StructArray. However I struggle quite a lot with some of the errors returned by test_rrule.

Here is what I got so far:

ProjectTo(::StructArray{T}) where {T} = ProjectTo{StructArray{T}}()

function (project::ProjectTo{StructArray{T}})(dx::AbstractArray{Y}) where {T,Y<:Union{T,Tangent{T}}}
    fields = fieldnames(T)
    components = ntuple(length(fields)) do i
        getfield.(dx, fields[i])
    end
    StructArray{T}(components)
end
(project::ProjectTo{StructArray{T}})(dx::StructArray{Y}) where {T,Y<:Union{T,Tangent{T}}} = dx

function rrule(T::Type{StructArray}, x::Union{Tuple,NamedTuple})
    y = T(x)
    proj = ProjectTo(y)
    function StructArray_rrule(Δ::AbstractArray)
        return NoTangent(), proj(ChainRulesCore.backing.(Δ))
    end
    function StructArray_rrule(Δ::Tangent)
        return NoTangent(), StructArray(Δ.components)
    end
    return y, StructArray_rrule
end

Which is failing for this example

using StructArrays, ChainRulesTestUtils
a = randn(5)
b = rand(5)
test_rrule(StructArray, (a, b))

which returns

Got exception outside of a @test
  AssertionError: T <: NamedTuple
  Stacktrace:
    [1] test_approx(actual::Tangent{Tuple{Vector{Float64}, Vector{Float64}}, Tuple{Vector{Float64}, Vector{Float64}}}, expected::Any, msg::Any; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/check_result.jl:128
    [2] test_approx(x::Any, y::Tangent, msg::Any; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/check_result.jl:140
    [3] _test_cotangent(accum_cotangent::Any, ad_cotangent::Any, fd_cotangent::Any; check_inferred::Any, kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:299

I am not sure that’s the solution but I iterated further to:

ProjectTo(sa::StructArray{T}) where {T} = ProjectTo{StructArray{T}}(;axes=axes(sa))

function (project::ProjectTo{StructArray{T}})(dx::AbstractArray{Y}) where {T,Y<:Union{T,Tangent{T}}}
    fields = fieldnames(T)
    components = ntuple(length(fields)) do i
        getfield.(dx, fields[i])
    end
    StructArray{T}(backing.(components))
end
(proj::ProjectTo{StructArray{T}})(dx::Tangent{<:StructArray{T}}) where {T} = begin 
    StructArray{T}(backing(dx.components))
end
function (project::ProjectTo{StructArray{T}})(dx::StructArray{Y}) where {T,Y<:Union{T,Tangent{T}}}
    StructArray{T}(StructArrays.components(backing.(dx)))
end

function rrule(::Type{StructArray}, x::T) where {T<:Union{Tuple,NamedTuple}}
    y = StructArray(x)
    function StructArray_rrule(Δ)
        return NoTangent(), Tangent{T}(StructArrays.components(backing.(Δ))...)
    end
    return y, StructArray_rrule
end
function rrule(::Type{StructArray{X}}, x::T) where {X,T<:Union{Tuple,NamedTuple}}
    y = StructArray{X}(x)
    function StructArray_rrule(Δ)
        return NoTangent(), Tangent{T}(StructArrays.components(backing.(Δ))...)
    end
    return y, StructArray_rrule
end