Type stability while looping over Tuple

In the performance critical part of our code we encountered a type instability. Is there an easy way to resolve this?

We have some input

struct Sinput{X,D}
	input::X
end
s_input = Sinput{Tuple{Float64,Int64}, 2}((1.3,4))

which gets passed to a function

function test_for(x::Sinput{X,D}) where {X,D}
	for i=1:D
		x.input[i]^2
	end
end

In reality the loop in the function sits inside another loop and calls a function which dispatches on the type of x.input[i]. Whats the best way to resolve this and get type-stable code? Why is this not resolved by the compiler, despite knowing the loop size and the type of the input at compile time?

1 Like

Here is a possible solution using Unrolled as long as the tuple is not too long.

Base.length(::Type{Sinput{X,D}}) where {X,D} = D
@unroll function test_for_unrolled(x::Sinput{X,D}) where {X,D}
                             
                       @unroll for i in 1:length(x)
                                     x.input[i]^2
                             end
                     return nothing
                     end
end

Unfortunately, in our application the usage of @unroll gets blocked by the outer loop and its nesting with the inner loop.

Maybe I am misunderstanding, but why can’t you wrap the inner loop in a function and use @unroll on that inner function? Maybe a MWE code closer to the structure you have can help…

Also this thread is useful for a similar discussion.

map(a -> a^2, x.input) should be type-stable.

Mostly because it results in errors such as AssertionError: Can only unroll a loop over one of the function's arguments or BoundsError: attempt to access 0-element Vector{Any} at index [1] depending on how I try to break things up.

Here is a more complete example, where it seems to be not possible to trivialize the problem away that easily. Apart from the practical aspect in our code, I’m also interested in the conceptual answer to this question.

Here is the more complete example:

abstract type AbstractBaseType{T,N} end

mutable struct BaseStruct_1{T, N} <: AbstractBaseType{T,N}
	some_parameter::T
end
mutable struct BaseStruct_2{T, N} <: AbstractBaseType{T,N}
	some_parameter::T
end

Base.length(::Type{BaseStruct_1{T,N}}) where {T,N} = N + 1
Base.length(::Type{BaseStruct_2{T,N}}) where {T,N} = N^2
Base.eltype(::Type{<:AbstractBaseType{T,N}}) where {T,N} = T

struct Sinput{T,X,D}
	input::X
	someArray::Array{T,D}
	buffer::Vector{T}
end


function Sinput(input::X) where {X}
	D = length(input)
	T = eltype(typeof(input[1]))
	dims = Tuple([length(typeof(x)) for x in input])
	someArray = rand(T, dims)
	buffer = ones(T, D)

	Sinput{T,X,D}(input, someArray, buffer)
end


function innercall(el::BaseStruct_1{T,N}, x::T) where {T,N}
	x^2 + el.some_parameter - 3.23
	el.some_parameter += 1.2
end
function innercall(el::BaseStruct_2{T,N}, x::T) where {T,N}
	x - el.some_parameter^2
	el.some_parameter += 1.2
end

function resetInternalParam(el::AbstractBaseType)
	el.some_parameter = 0.
end


function (intP::Sinput{T,X,D})(x::T) where {T,X,D}
	result = zero(T)
	for ind in CartesianIndices(intP.someArray)
		for i=1:D
			if ind[i]==1
				resetInternalParam(intP.input[i])
				intP.buffer[i] = innercall(intP.input[i], x)
			else
				intP.buffer[i] = innercall(intP.input[i], x)
				break
			end
		end
		result += intP.someArray[ind] * prod(intP.buffer)
	end

	return result
end


base_1 = BaseStruct_1{Float64, 5}(1.);
base_2 = BaseStruct_2{Float64, 6}(1.);
s_input = Sinput((base_1, base_2));

s_input(1.2)
@code_warntype s_input(1.2)

Discovering this thread it can be resolve by recursion

function RecursiveUnroll(x::Tuple, i::Int)
	if length(x) == 0
		return nothing
	end
	# do what you want here
	RecursiveUnroll(x[2:end], i+1)
end

It still seems to me that this should be something the compiler could infer, as the range of the loop is static. Maybe someone could chip in here.

1 Like