Type instability in FLoop reduction

I’m attempting to speed up a residual computation function with FLoops, but I’m getting some type instability. The resid_simple function is the version that uses LinearAlgebra; its allocations make it slow. The FLoops version is faster, but I would like to eliminate the type instability.

using FLoops
using LinearAlgebra

# Reference version: stage2 .- stage1 allocates a temporary array
function resid_simple(stage1::AbstractArray{T}, stage2::AbstractArray{T}) where {T<:AbstractFloat}
    stage1_denom = norm(stage1, 2)
    if isinf(stage1_denom) || iszero(stage1_denom)
        resid = T(NaN)  # T(NaN) keeps the result type equal to T
    else
        resid = norm(stage2 .- stage1, 2) / stage1_denom
    end
    resid
end

# Threaded version: both sums of squares are computed with @floop reductions
function resid(stage1::AbstractArray{T}, stage2::AbstractArray{T}) where {T<:AbstractFloat}

    numerator = zero(T)
    stage1_denom = zero(T)

    # sum of squares of stage1
    @floop for I in CartesianIndices(stage1)
        a = stage1[I]^2
        @reduce(stage1_denom += a)
    end
    stage1_denom = sqrt(stage1_denom)

    if isinf(stage1_denom) || iszero(stage1_denom)
        resid = zero(T)  # zero(T) rather than 0.0, so the type follows T
    else
        # sum of squared differences
        @floop for I in CartesianIndices(stage1)
            a = (stage2[I] - stage1[I])^2
            @reduce(numerator += a)
        end
        resid = sqrt(numerator) / stage1_denom
    end
    resid
end

U = rand(4, 10_000, 10_000);
Unp1 = rand(4, 10_000, 10_000);

ρ = @view U[1,:,:];
ρnp1 = @view Unp1[1,:,:];

@code_warntype resid(ρ,ρnp1)

This shows some type instability in the resid, numerator and stage1_denom variables:

Variables
  #self#::Core.Const(resid)
  stage1::SubArray{Float64, 2, Array{Float64, 3}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}}, true}
  stage2::SubArray{Float64, 2, Array{Float64, 3}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}}, true}
  @_4::Union{}
  @_5::Union{}
  resid::Any
  numerator::Any
  grouped_accs#270::Any
  result#269::Any
  __##combine_function#268::var"#__##combine_function#268#6"
  __##reducing_function#267::var"#__##reducing_function#267#5"{SubArray{Float64, 2, Array{Float64, 3}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}}, true}, SubArray{Float64, 2, Array{Float64, 3}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}}, true}}
  __##oninit_function#266::var"#__##oninit_function#266#4"
  stage1_denom::Any
  grouped_accs#263::Any
  result#262::Any
  __##combine_function#261::var"#__##combine_function#261#3"
  __##reducing_function#260::var"#__##reducing_function#260#2"{SubArray{Float64, 2, Array{Float64, 3}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}}, true}}
  __##oninit_function#259::var"#__##oninit_function#259#1"
  @_19::Any

Body::Any

How do I get rid of this? I’m running Julia 1.6 with FLoops v0.1.10.

You can specify the identity element of + using @reduce(stage1_denom = zero(T) + a) and @reduce(numerator = zero(T) + a). This often helps resolve type instability.
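For reference, this is roughly what that suggestion looks like in a minimal reduction. It is a sketch assuming FLoops' init form of @reduce (acc = init + x); the function name sumsq is hypothetical. Because the init value is given explicitly, the accumulator s needs no value before the loop, and each base case starts from zero(T), which is why the init has to be the identity of +.

```julia
using FLoops

# Hypothetical helper: sum of squares with the identity element of +
# passed explicitly as the init value of the reduction.
function sumsq(x::AbstractArray{T}) where {T<:AbstractFloat}
    @floop for I in CartesianIndices(x)
        a = x[I]^2
        @reduce(s = zero(T) + a)
    end
    return s
end

sumsq([1.0, 2.0, 3.0])
```

The value is correct and has type T at runtime, but the caller can still see it as Any, because the reduced value is returned through a Task.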

Ok, I tried your suggestion, but still see the same instabilities. I updated to FLoops v0.1.11 as well, but that made no difference.

Ah, actually, the type instability in the caller is expected, since fetch(::Task) is not inferrable. Maybe I can fix this properly in Julia 1.8. Internally (i.e., in the base-case reductions running on each task), there should be no type instability. Meanwhile, you can write stage1_denom = sqrt(stage1_denom::T) or stage1_denom = sqrt(convert(T, stage1_denom)), etc., to make the user-level code of resid type-stable.
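Here is a sketch of the assertion approach applied to the resid function from the question. It is an illustration, not a documented FLoops recipe: resid_asserted is a hypothetical name, and the asserted value is bound to a new local, denom, instead of being reassigned to stage1_denom (a naming choice made only for this sketch). @inferred from the Test standard library throws an error if the inferred return type differs from the type of the value actually returned.

```julia
using FLoops
using Test  # for @inferred

function resid_asserted(stage1::AbstractArray{T}, stage2::AbstractArray{T}) where {T<:AbstractFloat}
    numerator = zero(T)
    stage1_denom = zero(T)

    @floop for I in CartesianIndices(stage1)
        a = stage1[I]^2
        @reduce(stage1_denom += a)
    end
    # The reduced value comes back via fetch(::Task), so inference sees Any;
    # asserting ::T restores a concrete type for everything downstream.
    denom = sqrt(stage1_denom::T)

    if isinf(denom) || iszero(denom)
        return zero(T)
    end

    @floop for I in CartesianIndices(stage1)
        a = (stage2[I] - stage1[I])^2
        @reduce(numerator += a)
    end
    return sqrt(numerator::T) / denom  # same assertion on the second reduction
end

U = rand(4, 100, 100)
@inferred resid_asserted(view(U, 1, :, :), view(U, 2, :, :))
```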

Would it be possible to parametrize Task{OutputType}?

I don’t think we can provide such an API, since OutputType depends on how specific type inference is. What could be implemented instead is this: if the compiler can figure out the exact code corresponding to a fetch call, it can insert a type assertion for the return value of fetch.
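The inference behaviour discussed here can be checked directly in Base, without FLoops. This is a minimal sketch; untyped_fetch and typed_fetch are hypothetical names, and the ::Float64 assertion stands in by hand for the kind of assert the compiler could insert.

```julia
# fetch(::Task) is inferred as Any: a Task does not carry its result type.
untyped_fetch(t::Task) = fetch(t)

# A type assertion on the fetched value lets inference see a concrete type.
typed_fetch(t::Task) = fetch(t)::Float64

t = Threads.@spawn sum(abs2, [1.0, 2.0, 3.0])
typed_fetch(t)

Base.return_types(untyped_fetch, (Task,))  # inferred: Any
Base.return_types(typed_fetch, (Task,))    # inferred: Float64
```

The assertion is checked at runtime, so a wrong type fails loudly instead of silently propagating Any.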

Adding stage1_denom::T = sqrt(stage1_denom) gets rid of the type instability on stage1_denom, but it doesn’t significantly change the runtime. I’ve compared against a version that uses @tturbo from LoopVectorization.jl and the runtime is comparable. It doesn’t appear that the type instability hampers the runtime in this case. I may be wrong, but I’ll move on.
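One plausible reason the remaining instability is cheap: it only affects the scalar returned by each reduction, while the per-element loop runs inside a specialized, type-stable method. The sketch below imitates that split without FLoops, using a Ref{Any} as a hypothetical stand-in for the untyped fetch result; sumsq_kernel and norm_via_any are illustrative names, not part of any library.

```julia
# Type-stable kernel: all per-element work happens here, specialized on T.
function sumsq_kernel(x::AbstractArray{T}) where {T<:AbstractFloat}
    s = zero(T)
    @inbounds @simd for i in eachindex(x)
        s += x[i]^2
    end
    return s
end

# Imitates a reduction whose result is only known as Any (like fetch(::Task)).
# The dynamic dispatch on sqrt happens once per call, not once per element.
function norm_via_any(x::AbstractArray)
    result = Ref{Any}(sumsq_kernel(x))
    return sqrt(result[])
end

x = rand(10^6)
norm_via_any(x) ≈ sqrt(sum(abs2, x))
```

Base.return_types should report a concrete Float64 for sumsq_kernel and Any only for the wrapper, which mirrors the @code_warntype output in the question.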

Thanks for your help!