Simple challenge: Is there a more "Julian" solution?

I am porting someone else’s MATLAB code and found an indexing pattern that was baked in rather than generated programmatically. After hacking at it for a bit I managed to generalize the pattern. My solution isn’t very pretty or clever, nor is it a function that demands performance optimization. However, I am always curious to see demos of good idiomatic Julia.

So…for those that want to try solving it, here is an example of the pattern when N = 13:

index1=[
1,1,1,1,1,1,1,1,1,1,1,1,1,2,3,4,5,6,7,8,9,10,11,12,13,13,13,13,13,13,13,13,13,
13,13,13,13,12,11,10,9,8,7,6,5,4,3,2,2,2,2,2,2,2,2,2,2,2,2,3,4,5,6,7,8,9,10,11,
12,12,12,12,12,12,12,12,12,12,12,11,10,9,8,7,6,5,4,3,3,3,3,3,3,3,3,3,3,4,5,6,7,
8,9,10,11,11,11,11,11,11,11,11,11,10,9,8,7,6,5,4,4,4,4,4,4,4,4,5,6,7,8,9,10,10,
10,10,10,10,10,9,8,7,6,5,5,5,5,5,5,6,7,8,9,9,9,9,9,8,7,6,6,6,6,7,8,8,8,7,7]

I have a gist here with my own MWE of a working solution. You can either look at the gist and try to improve it or, for bonus difficulty, ignore the gist and just use the unrolled vector as a reference.
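For anyone using the unrolled vector as the reference: the values appear to be the row coordinate visited by a clockwise, inward spiral over an N×N grid. That reading is an observation from the N = 13 listing, not something stated in the original MATLAB code. A minimal sketch that generates a reference vector under that assumption (the name `spiral_rows` is hypothetical):

```julia
# Assumption: the pattern is the row index along a clockwise inward spiral.
function spiral_rows(n::Int)
    rows = Int[]
    sizehint!(rows, n^2)
    top, bottom, left, right = 1, n, 1, n
    while top <= bottom && left <= right
        # walk right along the top row: the row index stays at `top`
        append!(rows, fill(top, right - left + 1))
        # walk down the right column: the row index increases
        append!(rows, top+1:bottom)
        # walk left along the bottom row (only if it differs from the top row)
        top < bottom && append!(rows, fill(bottom, right - left))
        # walk up the left column (only if it differs from the right column)
        left < right && append!(rows, bottom-1:-1:top+1)
        # shrink the box by one ring
        top += 1; bottom -= 1; left += 1; right -= 1
    end
    return rows
end
```

This allocates the whole vector up front, so it is meant as a readable baseline to compare other solutions against rather than as a clever answer.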

Didn’t look at the gist. This builds the sequence lazily:

struct MyCustomIndex
    len::Int
end
Base.length(a::MyCustomIndex) = a.len*a.len

custom_state(i::Int,state::Int,current_n::Int,max_n::Int,counter::Int) = custom_state(i,Val(state),current_n,max_n,counter)

function custom_state(i::Int,state::Val{1},current_n::Int,max_n::Int,counter::Int) # constant increasing: hold i, then switch to increasing
    if current_n == max_n
        return (i+1,10,1,max_n,counter+1)
    else
        return (i,1,current_n+1,max_n,counter+1)
    end 
end
function custom_state(i::Int,state::Val{10},current_n::Int,max_n::Int,counter::Int) # increasing: step i up, then switch to constant decreasing
    if current_n+1 == max_n
        return (i,-1,1,max_n-1,counter+1)
    else
        return (i+1,10,current_n+1,max_n,counter+1)
    end 
end
function custom_state(i::Int,state::Val{-1},current_n::Int,max_n::Int,counter::Int) # constant decreasing: hold i, then switch to decreasing
    if current_n+1 == max_n
        return (i,-10,1,max_n,counter+1)
    else
        return (i,-1,current_n+1,max_n,counter+1)
    end 
end

function custom_state(i::Int,state::Val{-10},current_n::Int,max_n::Int,counter::Int) # decreasing: step i down, then switch to constant increasing
    if current_n == max_n
        return (i,1,1,max_n-1,counter+1)
    else
        return (i-1,-10,current_n+1,max_n,counter+1)
    end 
end

function Base.iterate(S::MyCustomIndex, state=(1,1,1,S.len,1))
    if state[end] > S.len^2
        return nothing
    else
        return (state[1],custom_state(state...))
    end
end

To test:

a = MyCustomIndex(13)
collect(a) == index1 #true

To explain my code: I took the challenge as building a state machine capable of spitting out those values. I observed 4 states:

  • constant increasing (1): the value is kept constant, the next state will increase the value
  • increasing (10): the value is increasing, the next state will maintain the value
  • constant decreasing (-1): the value is kept constant, the next state will decrease the value
  • decreasing (-10): the value is decreasing, the next state will maintain the value

The states follow the order 1 → 10 → -1 → -10 → 1 …

Then I built some functions that pass those states around. The termination criterion lives in the iteration protocol rather than in the state functions, and it uses the counter variable.

I used the following Julia capabilities:

  • Multiple dispatch on values: custom_state(x) = custom_state(Val(x)) allows me to dispatch on the value (in this case the state’s number). Another option could be to build some structs to represent those states, but I didn’t feel like it :sweat_smile:. This allowed me to define the logic in separate methods instead of one big if-else branch.
  • Iterator interface: I can define Base.iterate for my custom type, which lets me build iterators that are computed lazily. For example, MyCustomIndex(10_000) doesn’t allocate a vector of length 100_000_000; you just consume the values one at a time.
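As a stripped-down illustration of those two features on their own, here is a small sketch. The `Countdown` type and the `state_name` function are hypothetical toys, not part of the solution above.

```julia
# Hypothetical toy iterator: counts down from `start` to 1 without storing anything.
struct Countdown
    start::Int
end

Base.length(c::Countdown) = c.start

# iterate returns (value, next_state), or nothing once the iterator is exhausted.
Base.iterate(c::Countdown, state=c.start) = state < 1 ? nothing : (state, state - 1)

# Dispatch on a value by lifting it into the type domain with Val.
state_name(s::Int) = state_name(Val(s))
state_name(::Val{1}) = "constant increasing"
state_name(::Val{10}) = "increasing"
state_name(::Val) = "other"   # fallback for any other value

for x in Countdown(3)
    println(x)
end
println(state_name(10))
```

Because nothing is materialized, something like `first(Countdown(10^9))` only ever calls `iterate` once.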

Shorter Version:

struct MyCustomIndex
    len::Int
end

Base.length(a::MyCustomIndex) = a.len*a.len

custom_state(i,state,n,max_n,j) = custom_state(i,Val(state),n,max_n,j)

custom_state(i,state::Val{1},n,max_n,j) = n == max_n ? (i+1,10,1,max_n,j+1) : (i,1,n+1,max_n,j+1)

custom_state(i,state::Val{10},n,max_n,j) = n+1 == max_n ? (i,-1,1,max_n-1,j+1) : (i+1,10,n+1,max_n,j+1)

custom_state(i,state::Val{-1},n,max_n,j) = n+1 == max_n ? (i,-10,1,max_n,j+1) : (i,-1,n+1,max_n,j+1)

custom_state(i,state::Val{-10},n,max_n,j) = n == max_n ? (i,1,1,max_n-1,j+1) : (i-1,-10,n+1,max_n,j+1)

function Base.iterate(a::MyCustomIndex, state=(1,1,1,a.len,1))
    if state[end] > length(a)
        return nothing
    else
        return (state[1],custom_state(state...))
    end
end

Note that the dynamic dispatch you’re doing with Val makes your solution 1000x slower than the one that the OP linked to.

Is an enum an option?

No, you can’t dispatch on enum values. Instead, you can do pattern matching:

using Rematch
struct MyCustomIndex
    len::Int
end

Base.length(a::MyCustomIndex) = a.len*a.len

custom_state(i, state, n, max_n, j) = @match state begin
    1   => n   == max_n ? (i+1,10,1,max_n,j+1) : (i,1,n+1,max_n,j+1)
    10  => n+1 == max_n ? (i,-1,1,max_n-1,j+1) : (i+1,10,n+1,max_n,j+1)
    -1  => n+1 == max_n ? (i,-10,1,max_n,j+1)  : (i,-1,n+1,max_n,j+1)
    -10 => n   == max_n ? (i,1,1,max_n-1,j+1)  : (i-1,-10,n+1,max_n,j+1)
end

function Base.iterate(a::MyCustomIndex, state=(1,1,1,a.len,1))
    if state[end] > length(a)
        return nothing
    else
        return (state[1],custom_state(state...))
    end
end

Or just write if-else statements.

Added some types so it dispatches on real types, hahaha:


struct P1 end
struct P10 end
struct M10 end
struct M1 end

struct MyCustomIndex
    len::Int
end

Base.length(a::MyCustomIndex) = a.len*a.len


custom_state(i,state::P1,n,max_n,j) = n == max_n ? (i+1,P10(),1,max_n,j+1) : (i,state,n+1,max_n,j+1)

custom_state(i,state::P10,n,max_n,j) = n+1 == max_n ? (i,M1(),1,max_n-1,j+1) : (i+1,state,n+1,max_n,j+1)

custom_state(i,state::M1,n,max_n,j) = n+1 == max_n ? (i,M10(),1,max_n,j+1) : (i,state,n+1,max_n,j+1)

custom_state(i,state::M10,n,max_n,j) = n == max_n ? (i,P1(),1,max_n-1,j+1) : (i-1,state,n+1,max_n,j+1)

function Base.iterate(S::MyCustomIndex, state=(1,P1(),1,S.len,1))
    if state[end] > S.len^2
        return nothing
    else
        return (state[1],custom_state(state...))
    end
end

Yeah, in the end the fastest code is a simple if-else (which is what Rematch is doing, I suppose; nice package!). I recommend adding an eltype method to specify the element type of the array created when collecting the results, if necessary:

struct MyCustomIndex
    len::Int
end

Base.eltype(::Type{MyCustomIndex}) = Int #important when collecting
Base.length(a::MyCustomIndex) = a.len*a.len

function custom_state(i::Int,state,n,max_n,j)
    if state == 1
        n == max_n ? (i+1,10,2,max_n,j+1) : (i,state,n+1,max_n,j+1)
    elseif state == 10
        n == max_n ? (i,-1,2,max_n-1,j+1) : (i+1,state,n+1,max_n,j+1)
    elseif state == -1
        n == max_n ? (i,-10,1,max_n,j+1) : (i,state,n+1,max_n,j+1)
    else
        n == max_n ? (i,1,1,max_n-1,j+1) : (i-1,state,n+1,max_n,j+1)
    end
end

function Base.iterate(S::MyCustomIndex, state=(1,1,1,S.len,1))
    i,state,n,max_n,j = state
    if j > S.len^2
        return nothing
    else
        return (i,custom_state(i,state,n,max_n,j))
    end
end
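To see what the eltype method changes, here is a small sketch with two hypothetical toy iterators (`Untyped` and `Typed`) that differ only in whether eltype is defined. Without it, eltype falls back to Any, so collect builds a vector with element type Any.

```julia
# Hypothetical toy iterators that both yield 1, 2, ..., n.
struct Untyped
    n::Int
end
struct Typed
    n::Int
end

Base.length(x::Union{Untyped,Typed}) = x.n
Base.iterate(x::Union{Untyped,Typed}, k=1) = k > x.n ? nothing : (k, k + 1)

# The only difference between the two types:
Base.eltype(::Type{Typed}) = Int

println(eltype(collect(Untyped(3))))   # element type chosen without an eltype method
println(eltype(collect(Typed(3))))     # element type chosen with eltype defined
```

The collected values compare equal either way, which is why the `collect(a) == index1` check passes even without eltype; the difference is in the container type that later code has to work with.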

Is this indexing algorithm somehow related to this image?
[image]

Dynamically dispatching on “real” types doesn’t solve the problem. It’s just that whenever you can’t statically predict what types go into a function, you should expect significant slowdowns, since type-level work has to happen at run time. Your example may fare better in practice thanks to the small-union optimization, though.

Could you explain? I thought dynamic dispatch in Julia helps choose an optimized code path rather than creating slowdowns. Is it being used incorrectly above?

Dynamic dispatch has a cost (you need to look up which method to call). Depending on how much work the function does, that cost may or may not be significant.
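A rough way to see that cost for yourself is sketched below. All names here (`next_val`, `next_branch`, `cycle_states`) are hypothetical, the timings depend on machine and Julia version, and `@elapsed` is a crude substitute for a proper benchmarking package.

```julia
# The same 4-state transition written two ways.
next_val(s::Int) = next_val(Val(s))   # the type Val{s} is only known at run time
next_val(::Val{1}) = 10
next_val(::Val{10}) = -1
next_val(::Val{-1}) = -10
next_val(::Val{-10}) = 1

# Plain branches: no method lookup is needed while the loop runs.
next_branch(s::Int) = s == 1 ? 10 : s == 10 ? -1 : s == -1 ? -10 : 1

# Apply a transition function n times, starting from state 1.
function cycle_states(f, n)
    s = 1
    for _ in 1:n
        s = f(s)
    end
    return s
end

# Warm up first so that compilation time is not included in the measurement.
cycle_states(next_val, 10)
cycle_states(next_branch, 10)

t_val = @elapsed cycle_states(next_val, 10^6)
t_branch = @elapsed cycle_states(next_branch, 10^6)
println("Val dispatch: ", t_val, " s, branches: ", t_branch, " s")
```

Each transition does almost no work here, so the lookup dominates; in a function that does heavy work per call, the same lookup would be a much smaller fraction of the total.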

This might qualify as overly clever rather than anything else. Enjoy.

julia> function index(n, i)
           a = n^2 - i
           b = floor(Int, sqrt(a))
           c = a - b^2
           return (n + 1) ÷ 2 - min(c - b ÷ 2, (b + 1) ÷ 2) * (-1)^b
       end
index (generic function with 1 method)

julia> index1 == index.(13, 1:169)
true
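One caveat, worked out by hand rather than taken from the thread: the formula as posted seems to be tied to odd n. For n = 4, for example, `index(4, 1)` evaluates to 4, while the iterator versions start at 1. A sketch of a variant that should handle both parities, assuming the same pattern definition (the name `index_any` is hypothetical, and `isqrt` is swapped in for `floor(Int, sqrt(a))` to stay in integer arithmetic):

```julia
# Hypothetical generalization of `index`; agrees with it when n is odd.
function index_any(n::Integer, i::Integer)
    a = n^2 - i                  # distance from the end of the sequence
    b = isqrt(a)                 # exact integer square root
    c = a - b^2
    offset = min(c - b ÷ 2, (b + 1) ÷ 2)
    s = isodd(b + n) ? 1 : -1    # equals (-1)^b when n is odd, flipped when n is even
    return n ÷ 2 + 1 - offset * s   # n ÷ 2 + 1 is the final value of the sequence
end

println(index_any.(4, 1:16))
```

The idea is that the sequence read backwards from its last element has the same shape for every n, but its direction is mirrored when n is even, so only the sign and the final value need adjusting.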

There are two patterns, the second being just an offset of the first (it starts with the increment instead of repeating integers). They’re used as a pair of indexes to slice into an ndarray of images for some volumetric deconvolution.

This is nice, and more along the lines of what I expected :smiley: Storing the pattern as a vector (as in my own solution) is not nearly as clever as transforming any given index. (Although, for my intended application I believe it’s worth it to compute and store the values just once.)

Indeed, later on the algorithm has to look up the same values many thousands of times! :sweat_smile: However, that also depends on whether I rearrange the logic/computation after the straight port. When I stumbled upon the baked-in pattern (and subsequently had to scratch my chin for a surprising amount of time), I thought it was a great challenge, akin to MATLAB Cody, that some might enjoy flexing on.