Yota.jl getindex error

I am trying to calculate the gradient of this function:

using Yota 

f(x) = sum(getindex(x, [1 3; 4 2]))
s = rand(2,2)
val, g = Yota.grad(f, s)

With s = rand(2) my code works, but the code above returns this error:


> Failed to find a derivative for %11 = hvcat(%6, %7, %8, %9, %10) ::Matrix{Int64} at position one

Why isn’t multidimensional indexing supported by default? Should I use the @diffrule macro?
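For context, the `hvcat` in the error appears to come from building the literal `[1 3; 4 2]` itself, because Julia lowers a matrix literal like this to an `hvcat` call. An integer index matrix has no meaningful derivative anyway. The gradient that is actually wanted is easy to write by hand. Below is a minimal plain-Julia sketch that does not use Yota, with `grad_sum_getindex` as a hypothetical helper. Each entry of `x` gets one unit of gradient for every time its linear index appears in the index array.

```julia
# Gradient of f(x) = sum(x[idx]) with respect to x, computed by hand.
# x[k] enters the sum once per occurrence of the linear index k in idx,
# so the partial derivative with respect to x[k] equals that count.
function grad_sum_getindex(x::AbstractArray, idx::AbstractArray{<:Integer})
    g = zeros(eltype(x), size(x))   # gradient has the same shape as x
    for k in idx                     # idx holds linear indices into x
        g[k] += one(eltype(x))       # accumulate, so repeated indices add up
    end
    return g
end

idx = [1 3; 4 2]                     # the index matrix from the question
s = rand(2, 2)
g = grad_sum_getindex(s, idx)        # every index appears once here
```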
I edited my post.

What version of Julia are you using? I can’t run your function on Julia 1.5.0:

julia> using Yota

julia> f(x) = sum(exp.(x), getindex(x, [1 3; 4 2]))
f (generic function with 1 method)

julia> s = rand(2,2)
2×2 Array{Float64,2}:
 0.235071  0.271145
 0.994342  0.97109

julia> f(s)
ERROR: MethodError: objects of type Array{Float64,2} are not callable
Use square brackets [] for indexing an Array
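For anyone who hits the same MethodError: if you pass an array as the first of two arguments, `sum` treats it as the `sum(f, A)` form and tries to call that array as a function. Here is a small sketch of the difference in plain Julia. The 2×2 input is only an example, and the thread does not make clear which expression was intended.

```julia
x = rand(2, 2)
idx = [1 3; 4 2]

# sum(f, A) applies the function f to each element of A before summing,
# so sum(exp, x) gives the same result as sum(exp.(x)) without a temporary array.
a = sum(exp, x)
b = sum(exp.(x))

# Passing an array where sum expects a function triggers the
# "objects of type ... are not callable" MethodError.
err = try
    sum(exp.(x), x[idx])
    nothing
catch e
    e
end
```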

Most likely this issue is not related to getindex() but to sum(). I guess you're trying to use this method of sum():

sum(A::AbstractArray, w::StatsBase.UnitWeights; dims)

We don’t have a diff rule for that method. Is that the case?

No. If I use these functions I still see the same error:

f2(x) = sum(getindex(exp.(x), [1 3; 4 2]))
f3(x) = sum(exp.(getindex(x, [1 3; 4 2]))
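One way to know what a correct gradient for `f2` and `f3` should look like, whichever AD version is used, is to compare it with finite differences. Below is a plain-Julia sketch that does not use Yota. The index matrix is moved into a global `idx`, and `analytic_grad` and `fd_grad` are hypothetical helpers. Both functions should have the gradient `count(k in idx) * exp(x[k])` for each entry.

```julia
idx = [1 3; 4 2]
f2(x) = sum(getindex(exp.(x), idx))
f3(x) = sum(exp.(getindex(x, idx)))

# Analytic gradient: d/dx[k] = (number of times k appears in idx) * exp(x[k]).
function analytic_grad(x, idx)
    counts = zeros(eltype(x), size(x))
    for k in idx
        counts[k] += 1
    end
    return counts .* exp.(x)
end

# Central finite differences over every linear index of x.
function fd_grad(f, x; h = 1e-6)
    g = similar(x)
    for k in eachindex(x)
        xp = copy(x); xp[k] += h
        xm = copy(x); xm[k] -= h
        g[k] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

s = rand(2, 2)
g = analytic_grad(s, idx)
```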

OK, with Julia 1.5.0 and Yota 0.4.4 I see a different error, related to the recent changes to getindex() in Yota:

ERROR: BoundsError: attempt to access 1×2×2 Array{Float64,3} at index [1:1, 4, 1]

The changes were needed to support repeated indices, e.g. getindex(x, [1, 2, 1]) (see the related issue in Zygote if you’re curious about the whole story), but the implementation is a bit of a hack, so we get what we get.
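Here is a minimal illustration of why repeated indices need special care. It assumes a small vector `x` and an upstream gradient of all ones. If the pullback of `getindex` simply writes the upstream gradient back into place, a repeated index overwrites its earlier contribution instead of adding to it. The correct pullback is a scatter-add. The loop below is a plain-Julia stand-in for that operation and is not Yota's or NNlib's implementation.

```julia
x = [10.0, 20.0, 30.0]
idx = [1, 2, 1]          # index 1 is repeated
dy = ones(length(idx))   # upstream gradient of sum(x[idx]) w.r.t. x[idx]

# Naive pullback: assign dy back into place. With repeated indices,
# the later write to position 1 overwrites the earlier one.
naive = zeros(length(x))
naive[idx] .= dy

# Correct pullback: scatter-add, accumulating every contribution.
scattered = zeros(length(x))
for (j, k) in enumerate(idx)
    scattered[k] += dy[j]
end
```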

I see 2 options for you here:

  1. If you don’t plan to use repeated indices, just use Yota@0.4.2. I just checked your f2() and f3() on it and they work perfectly fine.
  2. If your code may have repeating indices, you will have to wait for this PR in NNlib, which I will then use to replace my hacky implementation of scatter operations.
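To decide which option applies, you can check whether your index arrays contain repeats. Base's `allunique` works on any iterable, including matrices. The `has_repeats` name below is a hypothetical helper. If option 1 fits, the older version can be installed from the Pkg REPL with `] add Yota@0.4.2`.

```julia
# True if any index occurs more than once in idx.
has_repeats(idx) = !allunique(idx)

has_repeats([1 3; 4 2])   # false: option 1 (Yota 0.4.2) should be enough
has_repeats([1, 2, 1])    # true: option 2 (wait for scatter support)
```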