Hi, I’m using Zygote to find the Vector-Jacobian Product (VJP) of a function. The MWE version of the function is as follows:
function ffunc(x::Vector{Float64})
    # first reshape
    xMat = reshape(x, (:, 2))
    # apply function
    res = [
        xMat[:,1] .+ xMat[:,2],
        xMat[:,1] .- xMat[:,2]
    ] # return value is a vector of two vectors
    return vcat(vcat(res'...)...)
end
# test
xvec = [1., 1., 1., 2., 2., 2.]
ffunc(xvec) # returns 6-element Vector{Float64}: [3.0, -1.0, 3.0, -1.0, 3.0, -1.0]
The function takes in a vector, and produces a vector as an output. The specific reshaping of the output before returning is key here (and the source of all my heartburn).
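To make the output ordering concrete: the result interleaves the elementwise sums and differences, i.e. [s1, d1, s2, d2, ...]. Below is a minimal plain-loop sketch of that same ordering. The name `interleave_ref` is hypothetical and used only for illustration; because it writes into a preallocated array, it serves as a reference for the expected values rather than something Zygote could differentiate as written.

```julia
# Hypothetical reference helper (not part of the code above): builds the same
# interleaved output as ffunc with an explicit loop.
function interleave_ref(x::Vector{Float64})
    n = length(x) ÷ 2                    # number of rows of the reshaped n×2 matrix
    out = Vector{Float64}(undef, 2n)
    for i in 1:n
        out[2i - 1] = x[i] + x[n + i]    # column 1 .+ column 2
        out[2i]     = x[i] - x[n + i]    # column 1 .- column 2
    end
    return out
end

interleave_ref([1., 1., 1., 2., 2., 2.])  # [3.0, -1.0, 3.0, -1.0, 3.0, -1.0]
```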
I use the following function to obtain a VJP:
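using Zygote  # provides pullback

function VJP(func::Function,
             primal::Vector{Float64},
             cotangent::AbstractArray)
    # pullback returns the primal output and a closure that maps an output cotangent to input gradients
    _, func_back = pullback(func, primal)
    vjp_result, = func_back(cotangent)
    return vjp_result
end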
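For reference, since the MWE is linear in x, the expected VJP can be written out by hand: each x[i] enters one sum and one difference with coefficient +1, while each x[n+i] enters the sum with +1 and the difference with -1. The sketch below encodes that under the interleaved output ordering described earlier; `vjp_ref` is a hypothetical helper name, intended only as a value to compare a working Zygote result against.

```julia
# Hand-derived VJP of the interleaved sum/difference map (hypothetical helper,
# useful only as a reference value; it does not use Zygote).
function vjp_ref(cotangent::Vector{Float64})
    n = length(cotangent) ÷ 2
    g = Vector{Float64}(undef, 2n)
    for i in 1:n
        s, d = cotangent[2i - 1], cotangent[2i]  # cotangents of the i-th sum and difference
        g[i]     = s + d    # x[i] contributes +1 to both the sum and the difference
        g[n + i] = s - d    # x[n+i] contributes +1 to the sum and -1 to the difference
    end
    return g
end

vjp_ref(ones(6))  # [2.0, 2.0, 2.0, 0.0, 0.0, 0.0]
```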
When I try to obtain the VJP of this function, however, I get endless error upon error. For the above version of ffunc I get:
ERROR: MethodError: no method matching getindex(::ChainRulesCore.Tangent{Any, NTuple{6, Float64}}, ::UnitRange{Int64}, ::Colon)
This is related to the use of the splat operator, and is the open Issue 110.
I used vcat with reduce in the return statement instead:
return reduce(vcat, reduce(vcat, res'))
But this throws a different error:
ERROR: MethodError: no method matching adjoint(::Nothing)
I used Zygote.Buffer() to help mutate the output before returning a copy() of it, but that didn’t work either.
I’m at the end of my tether here, so any help would be greatly appreciated!