I am computing gradients of functions using Zygote. The argument of my function is a ComponentArray, because my functions depend on DifferentialEquations.jl solvers using the adjoint method. I have been naively assuming that Zygote.jl plays well with ComponentArrays.jl, and on that basis the following feels like a bug to me:
```julia
using Zygote, ComponentArrays

x = (a=2.0, b=3.0) |> ComponentArray
g(x) = x.a + x.b
# two-step definition of `h`:
h(a, b) = a * b
h(x) = h(x...)
f(x) = g(x)*h(x)
# julia> f(x)
# 30.0
gradient(f, x)
# ERROR: MethodError: no method matching +(::Tuple{Float64, Float64}, ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(a = 1, b = 2)}}})
```
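For reference, the gradient that `gradient(f, x)` should return can be worked out by hand, since here $f(a, b) = (a + b)\,ab = a^2 b + a b^2$:

$$
\frac{\partial f}{\partial a} = 2ab + b^2 = 21, \qquad
\frac{\partial f}{\partial b} = a^2 + 2ab = 16 \qquad \text{at } (a, b) = (2, 3).
$$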
I can remove the problem in three ways:

- Leave `x` as a named tuple.
- Replace the two-step definition of `h` with the one-liner `h(x) = x.a * x.b`.
- Add the following hack (fix?) for the definition of `accum` in Zygote:

  ```julia
  Zygote.accum(x::AbstractArray, y::Tuple) = Zygote.accum(x, collect(y))
  Zygote.accum(x::Tuple, y::AbstractArray) = Zygote.accum(collect(x), y)
  ```
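For concreteness, here is a minimal sketch of the second workaround, assuming only Zygote.jl and ComponentArrays.jl. It keeps the two-argument `h(a, b)` but passes the fields explicitly instead of splatting `x...`, so (presumably) no `Tuple`-shaped gradient like the `Tuple{Float64, Float64}` in the error above ever has to be added to the `ComponentVector` gradient coming from `g`:

```julia
using Zygote, ComponentArrays

x = ComponentArray(a = 2.0, b = 3.0)

g(x) = x.a + x.b
h(a, b) = a * b
# Access the fields by name instead of splatting `x...`; the splat is what
# appears to produce the Tuple-valued gradient in the MethodError.
h(x) = h(x.a, x.b)
f(x) = g(x) * h(x)

(dx,) = gradient(f, x)
# By hand: df/da = 2ab + b^2 = 21, df/db = a^2 + 2ab = 16 at (a, b) = (2, 3)
println(dx.a, " ", dx.b)
```

If this runs cleanly, it suggests the failure is tied to differentiating through the splat rather than to property access on the ComponentArray itself.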
So, is this a bug, or am I expecting too much of Zygote?