Using Zygote with ComponentArrays

I am computing gradients of functions using Zygote. The argument of my functions is a ComponentArray, because they depend on DifferentialEquations.jl solvers using the adjoint method. I had been naively assuming that Zygote.jl plays well with ComponentArrays.jl, and on that basis the following feels like a bug to me:

```julia
using Zygote, ComponentArrays

x = (a=2.0, b=3.0) |> ComponentArray

g(x) = x.a + x.b

# two-step definition of `h`:
h(a, b) = a * b
h(x) = h(x...)

f(x) = g(x)*h(x)

# julia> f(x)
# 30.0

gradient(f, x)

# ERROR: MethodError: no method matching +(::Tuple{Float64, Float64}, ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(a = 1, b = 2)}}})
```
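For reference, the gradient that `gradient(f, x)` ought to return can be worked out by hand with the product rule, using $g(x) = a + b = 5$ and $h(x) = ab = 6$ at the point above:

$$
\frac{\partial f}{\partial a} = h(x) + g(x)\,b = 6 + 5 \cdot 3 = 21,
\qquad
\frac{\partial f}{\partial b} = h(x) + g(x)\,a = 6 + 5 \cdot 2 = 16.
$$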

I can remove the problem in three ways:

  1. Leave `x` as a NamedTuple.
  2. Replace the two-step definition of `h` with the one-liner `h(x) = x.a * x.b`.
  3. Add the following hack (fix?) to Zygote's `accum`:

```julia
Zygote.accum(x::AbstractArray, y::Tuple) = Zygote.accum(x, collect(y))
Zygote.accum(x::Tuple, y::AbstractArray) = Zygote.accum(collect(x), y)
```
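A minimal sketch of workarounds 1 and 2, assuming recent versions of Zygote.jl and ComponentArrays.jl are installed. The names `h1`, `f1`, `h2` and `f2` are hypothetical, chosen only so both variants can live in one script. By the product rule, both should give ∂f/∂a = h + g·b = 21 and ∂f/∂b = h + g·a = 16 at a = 2, b = 3.

```julia
using Zygote, ComponentArrays

g(x) = x.a + x.b

# Workaround 1: keep the argument a NamedTuple, so the contribution from `g`
# and the contribution from the splat are both NamedTuples and can be added.
h1(a, b) = a * b
h1(x) = h1(x...)
f1(x) = g(x) * h1(x)
grad_nt = gradient(f1, (a=2.0, b=3.0))[1]

# Workaround 2: avoid the splat entirely and access the fields by name.
h2(x) = x.a * x.b
f2(x) = g(x) * h2(x)
grad_ca = gradient(f2, ComponentArray(a=2.0, b=3.0))[1]
```

In the first variant the gradient comes back as a NamedTuple; in the second no splat occurs, so no Tuple-shaped contribution is ever produced.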

So, is this a bug, or am I expecting too much of Zygote?

Zygote thinks the gradient of a splat is always a Tuple. This is a long-standing bug, and I presume it is where this Tuple comes from. If it were fixed, the gradient here would perhaps be an array, the “natural” representation.
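One way to look at this, as a sketch assuming a recent Zygote.jl: differentiate through a splat of a plain `Vector` and inspect the cotangent that comes back. `Zygote.pullback` is used rather than `gradient` so that we see what the splat rule itself produces, presumably before any final projection onto the input's type. The exact type printed may depend on the Zygote version; only the values are checked.

```julia
using Zygote

h(a, b) = a * b

v = [2.0, 3.0]

# Pull back through `h(v...)`, which splats the Vector into two arguments.
y, back = Zygote.pullback(w -> h(w...), v)
(dv,) = back(1.0)

# Per the explanation above, this is expected to be a Tuple rather than a Vector.
println(typeof(dv))
```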

However, the gradient of `g(x) = x.a + x.b` is “structural”, a NamedTuple. By default these two gradient representations cannot be added, although in this case overloading `accum` may work.
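Overloading `accum` in this way is workaround 3 from the question. Here it is run end to end as a sketch, assuming recent Zygote.jl and ComponentArrays.jl. Note that adding methods to `Zygote.accum` for types Zygote does not own is type piracy: it changes behaviour for every package in the session, so it is best treated as a local patch rather than a fix.

```julia
using Zygote, ComponentArrays

# Let Zygote add a Tuple (from the splat) to an array (from `g`) by
# converting the Tuple to a Vector first. Globally visible: type piracy.
Zygote.accum(x::AbstractArray, y::Tuple) = Zygote.accum(x, collect(y))
Zygote.accum(x::Tuple, y::AbstractArray) = Zygote.accum(collect(x), y)

g(x) = x.a + x.b
h(a, b) = a * b
h(x) = h(x...)
f(x) = g(x) * h(x)

x = ComponentArray(a=2.0, b=3.0)
gx = gradient(f, x)[1]
```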

ChainRules has a mechanism, `ProjectTo`, for standardising on one of these types. It’s used here to turn both such representations into another ComponentArray, which would ideally allow them to be added.
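To illustrate the kind of standardisation meant, here is a small sketch of ChainRulesCore's `ProjectTo` on two built-in types, assuming ChainRulesCore.jl is available in the environment (it is a dependency of Zygote but may need adding explicitly). `ProjectTo(x)` builds a projector that maps an incoming gradient onto the representation natural for `x`. The ComponentArray case mentioned above is not reproduced here.

```julia
using ChainRulesCore, LinearAlgebra

# Projecting onto a real number: the imaginary part is dropped.
p_real = ProjectTo(1.0)
r = p_real(2.0 + 3.0im)

# Projecting onto a Diagonal matrix: off-diagonal entries are discarded
# and the result is again a Diagonal.
p_diag = ProjectTo(Diagonal([1.0, 2.0]))
d = p_diag([1.0 2.0; 3.0 4.0])
```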

tl;dr: you should probably always avoid splatting arrays.
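A sketch of one way to avoid the splat while keeping the two-argument `h`: pass the elements explicitly by index. This assumes the length of `x` is known and fixed (here 2); the indexing form `h(x[1], x[2])` is an illustrative choice, not something from the thread.

```julia
using Zygote, ComponentArrays

g(x) = x.a + x.b
h(a, b) = a * b
# Instead of `h(x...)`, index the elements explicitly so no splat is differentiated.
h(x) = h(x[1], x[2])
f(x) = g(x) * h(x)

x = ComponentArray(a=2.0, b=3.0)
gx = gradient(f, x)[1]
```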