As far as I understand, the problem is in `cat(...,dims=2)`. That call can return an array of a different kind depending on its input. For example:
```julia
julia> B = [1,2,3]
3-element Array{Int64,1}:
 1
 2
 3

julia> cat(B,B,dims=2)
3×2 Array{Int64,2}:
 1  1
 2  2
 3  3
```
In this case, the input is a vector and the output is a two-dimensional array. In the code that contains that conditional clause, the compiler has no way to know in advance whether this change in dimensions will happen, because it depends on a parameter of the function that is only known at run time. Thus, the code that follows is not specialized for the type of array you are operating on. This does not depend on whether the conditional is actually entered: the problem is the mere possibility of entering it, which changes the type of the variable for all the operations that follow.
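To make that concrete, here is a small sketch. `maybe_widen` is a hypothetical helper that mimics the conditional rebinding; `Base.return_types` is an internal reflection utility (not public API) that reports what inference concluded for a given signature. Since the result is a `Vector` on one path and a `Matrix` on the other, no single concrete type can describe it.

```julia
# Hypothetical helper mirroring the conditional in your code: for a Vector
# input, `A` stays a Vector if N != 0 but is rebound to a Matrix if N == 0.
function maybe_widen(A, N)
    if N == 0
        A = cat(A, A, dims=2)
    end
    return A
end

B = [1, 2, 3]
typeof(maybe_widen(B, 1))  # Vector{Int}
typeof(maybe_widen(B, 0))  # Matrix{Int}

# Inference has to cover both run-time possibilities, so the inferred
# return type for a Vector{Int} argument cannot be a single concrete type.
rt = only(Base.return_types(maybe_widen, (Vector{Int}, Int)))
isconcretetype(rt)         # false
```

The exact non-concrete type reported (a `Union`, `Array{Int}`, `Any`, ...) may vary between Julia versions; the point is only that it is not concrete.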
If you put that in an outer function, as I suggested, the inner function will be specialized for the type of array it receives each time it is called. Consider the following minimal example:
```julia
julia> function f(A,N)
           if N == 0
               println("got here")
               A = cat(A,A,dims=2)
           end
           s = 0.
           for a in A
               s += a
           end
           s
       end
f (generic function with 1 method)

julia> function outer(A,N)
           if N == 0
               println("got here")
               A = cat(A,A,dims=2)
           end
           g(A)
       end
       function g(A)
           s = 0.
           for a in A
               s += a
           end
           s
       end
g (generic function with 1 method)

julia> A = rand(3,3);

julia> using BenchmarkTools

julia> @btime f($A,1)
  269.185 ns (18 allocations: 288 bytes)
4.425721907655399

julia> @btime outer($A,1)
  17.008 ns (1 allocation: 16 bytes)
4.425721907655399
```
Note that this minimal example has the same problem as your code, and the same solution. If you run `@code_warntype` on `f(A,1)` or on `outer(A,1)`, you will see that both are type-unstable. Yet, whatever the type of the argument passed to `g(A)`, `g` itself is always type-stable, because it is compiled specifically for that argument type. Since `g` is the function doing the expensive computation, this specialization is decisive for performance.
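If you prefer a programmatic check over reading `@code_warntype` output, the standard library's `Test.@inferred` throws an error when the inferred return type of a call does not match the type actually returned. A minimal sketch, re-declaring `g` from above so the snippet stands alone, and checking it against a few argument types:

```julia
using Test

# Same kernel as `g` above: sums the elements of any array it receives
function g(A)
    s = 0.
    for a in A
        s += a
    end
    s
end

# Each call passes only because `g` is compiled for the concrete argument type
@inferred g(rand(3, 3))                          # Matrix{Float64}
@inferred g([1, 2, 3])                           # Vector{Int}: `s` stays Float64 since it starts as 0.
@inferred g(cat([1, 2, 3], [1, 2, 3], dims=2))   # Matrix{Int}
```

Whether `@inferred f(A,1)` fails depends on how well your Julia version infers `cat` with the `dims` keyword, so that check is left out here.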
Sometimes, depending on the conditional involved, the compiler can figure out how to split the function into two type-stable branches, and the problem does not occur. Here, however, it cannot.
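One way to do that split by hand, sketched under the assumption that the kernel is the same `g` as above: instead of rebinding `A`, each branch passes its own value to `g`, so the variable `A` keeps the type it had on entry. `outer2` is a hypothetical name, and `g` still acts as the function barrier; whether this is measurably faster than `outer` would need benchmarking in your case.

```julia
# Same summation kernel as `g` above
function g(A)
    s = 0.
    for a in A
        s += a
    end
    s
end

# Hypothetical variant of `outer`: no variable is rebound to a value of a
# different type, so each branch calls `g` with a single, fixed type.
function outer2(A, N)
    if N == 0
        return g(cat(A, A, dims=2))
    else
        return g(A)
    end
end

outer2([1, 2, 3], 1)  # 6.0
outer2([1, 2, 3], 0)  # 12.0
```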