Using `CartesianIndex` with Flux.jl

Is it possible to use `CartesianIndex` with Flux.jl inside a `Flux.withgradient` block?

Specifically, I have `a::Vector{Float32}` and `b::Vector{CartesianIndex}`, and `a[b]` is inside the gradient block. When I ran this, I got:

```
ERROR: LoadError: ArgumentError: unable to check bounds for indices of type CartesianIndex{2}
Stacktrace:
  [1] checkindex(::Type{Bool}, inds::Base.OneTo{Int64}, i::CartesianIndex{2})
    @ Base ./abstractarray.jl:751
  [2] checkindex
    @ ./abstractarray.jl:767 [inlined]
  [3] checkbounds
    @ ./abstractarray.jl:689 [inlined]
  [4] checkbounds
    @ ./abstractarray.jl:699 [inlined]
  [5] _getindex
    @ ./multidimensional.jl:955 [inlined]
  [6] getindex
    @ ./abstractarray.jl:1342 [inlined]
  [7] rrule
    @ ~/.julia/packages/ChainRules/14CDN/src/rulesets/Base/indexing.jl:63 [inlined]
  [8] rrule
    @ ~/.julia/packages/ChainRulesCore/Vsbj9/src/rules.jl:138 [inlined]
  [9] chain_rrule
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/chainrules.jl:234 [inlined]
 [10] macro expansion
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:-1 [inlined]
 [11] _pullback(::Zygote.Context{false}, ::typeof(getindex), ::Matrix{Float16}, ::Vector{CartesianIndex})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:81
```

From the stack trace, it looks like you're calling `getindex(::Matrix{Float16}, ::Vector{CartesianIndex})`. Indexing a matrix with a vector of `CartesianIndex`es doesn't fail on its own:

```julia
julia> let a = randn(3,3)
           b = [CartesianIndex(1,1), CartesianIndex(2,2)]
           gradient(x -> sum(x[b]), a)
       end
([1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 0.0],)
```
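One difference worth checking: in the example above, the literal `[CartesianIndex(1,1), CartesianIndex(2,2)]` has the concrete element type `CartesianIndex{2}`, while the stack trace shows a `Vector{CartesianIndex}`, whose element type is abstract. A small sketch in plain Base Julia to compare the two (the names `b_concrete` and `b_abstract` are illustrative only):

```julia
# Element type inferred from the literal: the concrete CartesianIndex{2}
b_concrete = [CartesianIndex(1, 1), CartesianIndex(2, 2)]

# Typed array literal forcing the abstract element type seen in the trace
b_abstract = CartesianIndex[CartesianIndex(1, 1), CartesianIndex(2, 2)]

println(typeof(b_concrete))                   # Vector{CartesianIndex{2}}
println(typeof(b_abstract))                   # Vector{CartesianIndex}
println(isconcretetype(eltype(b_concrete)))   # true
println(isconcretetype(eltype(b_abstract)))   # false
```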

Can you reduce your case to a minimal working example (MWE) that still shows the failure?

This example reproduces the error:

```julia
using Flux, Statistics

function to_cartesian_mask(is::Vector{Int})::Vector{CartesianIndex}
    [CartesianIndex(j, i) for (i, j) in enumerate(is)]
end

a = randn(3, 3)
mask = BitVector([1,0,1])
is = [2, 1]
loss, grad = Flux.withgradient(a) do a
    #js = [CartesianIndex(j, i) for (i, j) in enumerate(is)]
    #z = a[:, mask][js]
    z = a[:, mask][to_cartesian_mask(is)]
    mean(z .^ 2)
end
```

I think the problem is the `to_cartesian_mask` call: if I replace it with the inlined version (the commented-out lines), the error goes away.

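The return-type annotation `::Vector{CartesianIndex}` is usually redundant, and in this case it is detrimental: it converts the concretely typed result, a `Vector{CartesianIndex{2}}`, into a `Vector{CartesianIndex}` with an abstract element type, which is exactly the argument type that appears in the stack trace.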
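To see the conversion directly, here is a sketch comparing the annotated function from the question with an unannotated copy (the name `to_cartesian_mask_plain` is hypothetical, added only for this comparison). A return-type annotation makes Julia `convert` the returned value to the declared type, so the comprehension's `Vector{CartesianIndex{2}}` comes back as a `Vector{CartesianIndex}`:

```julia
# Annotated version from the question: the result is converted on return
function to_cartesian_mask(is::Vector{Int})::Vector{CartesianIndex}
    [CartesianIndex(j, i) for (i, j) in enumerate(is)]
end

# Same body without the annotation (hypothetical name, for comparison only)
function to_cartesian_mask_plain(is::Vector{Int})
    [CartesianIndex(j, i) for (i, j) in enumerate(is)]
end

is = [2, 1]
println(typeof(to_cartesian_mask(is)))        # Vector{CartesianIndex}
println(typeof(to_cartesian_mask_plain(is)))  # Vector{CartesianIndex{2}}
```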
When I removed the return-type annotation, leaving:

```julia
function to_cartesian_mask(is::Vector{Int})
    [CartesianIndex(j, i) for (i, j) in enumerate(is)]
end
```

the example seemed to work.
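If you'd rather keep an annotation, a sketch assuming the indices are always 2-D: declare the concrete type `Vector{CartesianIndex{2}}` (or build the vector by broadcasting the constructor). Since the indices don't depend on `a`, they can also be computed once, outside the gradient closure. The name `to_cartesian_mask2` is hypothetical, and this snippet only exercises the indexing with Base and Statistics, not Flux itself:

```julia
using Statistics

# Concrete return type: no conversion to an abstract element type happens
function to_cartesian_mask2(is::Vector{Int})::Vector{CartesianIndex{2}}
    # element k is CartesianIndex(is[k], k): row is[k], column k
    CartesianIndex.(is, eachindex(is))
end

a = randn(3, 3)
mask = BitVector([1, 0, 1])
is = [2, 1]

idx = to_cartesian_mask2(is)   # computed once, independent of `a`
z = a[:, mask][idx]            # picks a[2, 1] and a[1, 3]
loss = mean(z .^ 2)
```

In the question's example, the body of the `Flux.withgradient` block would then reduce to `z = a[:, mask][idx]` followed by `mean(z .^ 2)`.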
