Simple Reshaping in Zygote Throwing MethodError

Hi, I’m using Zygote to find the Vector-Jacobian Product (VJP) of a function. A minimal working example (MWE) of the function is as follows:

function ffunc(x::Vector{Float64})
    # first reshape
    xMat = reshape(x, (:, 2))
    # apply function
    res = [
        xMat[:,1] .+ xMat[:,2],
        xMat[:,1] .- xMat[:,2]
        ] # return value is a vector of two vectors

    return vcat(vcat(res'...)...)
end

# test
xvec = [1., 1., 1., 2., 2., 2.]
ffunc(xvec) # returns 6-element Vector{Float64}: [3.0, -1.0, 3.0, -1.0, 3.0, -1.0]

The function takes in a vector, and produces a vector as an output. The specific reshaping of the output before returning is key here (and the source of all my heartburn).
I use the following function to obtain a VJP:

using Zygote  # provides pullback

function VJP(func::Function,
             primal::Vector{Float64},
             cotangent::AbstractArray)

    _, func_back = pullback(func, primal)
    vjp_result, = func_back(cotangent)

    return vjp_result
end

When I try to obtain the VJP of this function, however, I get one error after another. For the above version of ffunc I get:

ERROR: MethodError: no method matching getindex(::ChainRulesCore.Tangent{Any, NTuple{6, Float64}}, ::UnitRange{Int64}, ::Colon)

This is related to the use of the splat operator, and is the open Issue 110.
I then tried replacing the splats in the return statement with reduce(vcat, ...):

return reduce(vcat, reduce(vcat, res'))

But this throws a different error:

ERROR: MethodError: no method matching adjoint(::Nothing)

I also tried using Zygote.Buffer() to fill the output by mutation and return a copy() of it, but that didn’t work either.

I’m at the end of my tether here, so any help would be greatly appreciated!

The problem is that there’s no AD rule for reduce(vcat, ::AbstractMatrix), and res' is a matrix: taking the adjoint of a 2-element vector yields a 1×2 matrix. This can be worked around by taking the adjoint of the inner arrays only, so that res itself stays a vector:

function ffunc(x::Vector{Float64})
    # first reshape
    xMat = reshape(x, (:, 2))
    # apply function
    res = [
        (xMat[:,1] .+ xMat[:,2])',
        (xMat[:,1] .- xMat[:,2])'
    ] # res is a vector of two 1×n row vectors (adjoints)

    return vec(reduce(vcat, res))
end

I’ve taken the liberty of using vec instead of the outer vcat because it’s more AD-friendly and faster.
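
To see why this keeps the same output ordering as the version in the question, here is a small sketch that runs without Zygote. The variable names (a, b, rows, stacked, original, reworked) are made up for illustration; a and b stand in for the two columns of xMat.

```julia
using LinearAlgebra  # harmless if already loaded; makes sure adjoint (') on arrays is available

a = [1.0, 1.0, 1.0]   # stands in for xMat[:, 1]
b = [2.0, 2.0, 2.0]   # stands in for xMat[:, 2]

# Version from the question: a vector of column vectors, adjointed as a whole.
res = [a .+ b, a .- b]
@show res' isa AbstractMatrix        # the adjoint of a vector is a 1×2 matrix
original = vcat(vcat(res'...)...)    # double splat, as in the question

# Reworked version: adjoint the inner arrays only, so the container stays a vector.
rows = [(a .+ b)', (a .- b)']
stacked = reduce(vcat, rows)         # 2×3 matrix, one row per inner array
reworked = vec(stacked)              # column-major flattening interleaves the rows

@show original == reworked
```

Because vec walks the 2×3 matrix column by column, the sums and differences end up interleaved exactly as in the original return value.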

If res always holds a fixed number of arrays, however, you could consider making it a tuple instead of an array:

function ffunc(x::Vector{Float64})
    # first reshape
    xMat = reshape(x, (:, 2))
    # apply function
    res = (
        (xMat[:,1] .+ xMat[:,2])',
        (xMat[:,1] .- xMat[:,2])'
    ) # res is a tuple of two 1×n row vectors (adjoints)

    return vec(vcat(res...))
end

Splatting generally works fine with tuples under Zygote.
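
As a rough sanity check, the sketch below runs the tuple version through Zygote.pullback and compares the result with a VJP worked out by hand. The names ffunc_tuple, back, vjp_zygote and vjp_manual, and the cotangent values, are invented for this example, and it assumes a Zygote version in which the tuple splat differentiates as described above. The reference follows from the output layout: with a and b the two halves of x, the outputs are interleaved as a[i] + b[i], a[i] - b[i].

```julia
using Zygote

# Tuple-based variant from above, renamed (hypothetical name) to avoid clashing with ffunc.
function ffunc_tuple(x::Vector{Float64})
    xMat = reshape(x, (:, 2))
    res = (
        (xMat[:, 1] .+ xMat[:, 2])',
        (xMat[:, 1] .- xMat[:, 2])'
    )
    return vec(vcat(res...))
end

xvec = [1.0, 1.0, 1.0, 2.0, 2.0, 2.0]
cotangent = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]   # arbitrary test cotangent

_, back = Zygote.pullback(ffunc_tuple, xvec)
vjp_zygote, = back(cotangent)

# Hand-derived reference:
#   d/da[i] picks up cotangent[2i-1] + cotangent[2i]
#   d/db[i] picks up cotangent[2i-1] - cotangent[2i]
n = length(xvec) ÷ 2
vjp_manual = vcat(
    [cotangent[2i - 1] + cotangent[2i] for i in 1:n],
    [cotangent[2i - 1] - cotangent[2i] for i in 1:n],
)

@show vjp_zygote ≈ vjp_manual
```

If the two disagree, the mismatch points at the flattening order rather than at Zygote itself, so checking the primal output first is a reasonable next step.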

Thank you for the suggestions, @ToucheSir - I’ll give these a shot and return if I still have issues. I didn’t know that vec is more AD-friendly than vcat. It seems like that would help too.