Confusion about type stability of function within a struct

I have a struct called Operators that holds a number of functions, but I can't get code that uses it to be type stable. I have read the thread "Type Unstable: Function within a Struct" and implemented its solution; that helped in other situations, but not for this specific problem.
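For context, a parametric struct whose fields hold functions is not type unstable by itself: every Julia function has its own singleton type, so the type parameters are filled in with concrete types when the struct is constructed. A minimal sketch, where FunOps is a hypothetical stand-in for Operators and sin and cos are placeholder functions:

```julia
# Hypothetical stand-in for the Operators struct in the post.
struct FunOps{F1, F2}
    K::F1
    T::F2
end

# Each function has its own singleton type (typeof(sin), typeof(cos)),
# so F1 and F2 become concrete types at construction.
ops = FunOps(sin, cos)

println(typeof(ops))
println(isconcretetype(typeof(ops)))   # true
```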

Currently each iteration of the loop causes 24 allocations (only 9 in my production code; I can't figure out why the numbers differ, but the problem is the same).

MWE:

using Memoization
using FFTW
using BenchmarkTools

# This structure is type unstable
struct Operators{F1, F2}
    K̂::F1
    T̂::F2
end # Operators

function Operators(α, x_order, ω)
    # Create this function with memoization since it is called repeatedly
    @memoize function K̂(dx)
        println("Computing and caching K(dx = $dx)")
        if α== 0
            ifftshift(cis.(dx*ω.^2/2))
        elseif α > 0
            ifftshift(cis.(dx*α*ω.^3))
        #elseif ...
            #more functions to select from
        end
    end

    # Select function 2
    if x_order == 2 && α == 0
        T̂ = T₂ˢ 
    elseif x_order == 2 && α > 0
        T̂ = T₂ˢʰ
    # elseif ...
        # many more functions to select from
    end

    ops = Operators(K̂, T̂)

    return ops
end

function T₂ˢ(ψ, dx, ops)
    # This line is 18 allocs (only 3 in my production code)
    ψ .= ops.K̂(dx) .* ψ 
end #T₂ˢ

function run(N)
    M = 512
    x_order = 2
    α = 0
    dx = 1e-4
    ψ = Array{Complex{Float64}}(undef, M, N)
    ψ[:, 1] = rand(M) + im*rand(M)
    ω = rand(M)

    # Decide which functions to use
    ops = Operators(α, x_order, ω)

    # 24 allocs per loop iteration (only 9 in production code)
    for i = 1:N-1
        @views ψ[:, i+1] .= ops.T̂(ψ[:, i], dx, ops) # 6 allocs from calling ops.T̂
    end

end
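One way to see where the instability in this MWE likely comes from: the concrete type of the value returned by Operators(α, x_order, ω) depends on the runtime values of α and x_order, not on their types, so inference cannot pin it down inside run. The sketch below reduces this to a single branch; SelectedOps, step_a, step_b and select_ops are hypothetical names, not part of the original code.

```julia
# Hypothetical, simplified version of Operators(α, x_order, ω): which
# function is stored depends on the *value* of α.
struct SelectedOps{F}
    T::F
end

step_a(x) = 2x   # stand-ins for T₂ˢ and T₂ˢʰ
step_b(x) = 3x

select_ops(α) = α == 0 ? SelectedOps(step_a) : SelectedOps(step_b)

# Ask inference what select_ops returns for an Int argument.
rt = only(Base.return_types(select_ops, (Int,)))
println(rt)                   # a Union of two SelectedOps types
println(isconcretetype(rt))   # false
```

With many more branches, as in the real constructor, the inferred type widens to something like the Operators{_A,_B} shown by @code_warntype.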

@code_warntype warns as follows:

@code_warntype run(4)

  ops::Operators{_A,_B} where _B where _A  <-- RED
  @_10::Union{Nothing, Tuple{Int64,Int64}} <-- YELLOW

...
│   %31 = (%28)(%29, %30, ops)::Any <-- RED
│   %32 = Base.broadcasted(Base.identity, %31)::Base.Broadcast.Broadcasted{_A,Nothing,typeof(identity),_B} where _B<:Tuple where _A<:Union{Nothing, Base.Broadcast.BroadcastStyle} <-- RED
│         Base.materialize!(%27, %32)
│         (@_10 = Base.iterate(%17, %24))
...
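The ::Any on the call %31 is consistent with this: when only the UnionAll type of ops is known, the result of calling a function stored in one of its fields cannot be inferred. A small sketch using Base.return_types, where FieldOps and apply_K are hypothetical names:

```julia
struct FieldOps{F}
    K::F
end

# Calls the stored function, like ops.K̂(dx) inside T₂ˢ.
apply_K(ops, dx) = ops.K(dx)

# Only the UnionAll type FieldOps is known (as for `ops` inside run).
rt_unknown = only(Base.return_types(apply_K, (FieldOps, Float64)))

# The concrete type FieldOps{typeof(sin)} is known.
rt_known = only(Base.return_types(apply_K, (FieldOps{typeof(sin)}, Float64)))

println(rt_unknown)   # Any
println(rt_known)     # Float64
```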

I would appreciate any thoughts; this is quite confusing to me.

A friend helped me optimize this. The end result is the following:

using Memoization
using FFTW
using BenchmarkTools

# Parametric struct: its type is concrete once F1 and F2 are known
struct Operators{F1<:Function, F2<:Function}
    K̂::F1
    T̂::F2
end # Operators

function Operators(α, x_order, ω)
    # Select function 
    function K̂(α)
        fun = if α == 0 
            @memoize function K_cubic(dx::Real)
                println("Computing and caching K(dx = $dx) for cubic NLSE")
                ifftshift(cis.(dx*ω.^2/2))
            end
            K_cubic
        elseif α > 0
            @memoize function K_hirota(dx::Real)
                println("Computing and caching K(dx = $dx) for Hirota Equation")
                ifftshift(cis.(dx*α*ω.^3))
            end
            K_hirota
        end
        return fun
    end

    # Select function 2
    T̂ = T₂ˢ 
    if x_order == 2 && α == 0
        T̂ = T₂ˢ 
    elseif x_order == 2 && α > 0
        T̂ = T₂ˢʰ
    #elseif ...
       # many more functions to select from
    end

    ops = Operators(K̂(α), T̂)

    return ops
end

function T₂ˢ(ψ, dx, ops)
    # 3 allocs
    ψ .= ops.K̂(dx) .* ψ 
end #T₂ˢ

function run(N)
    M = 512
    x_order = 2
    α = 0
    dx = 1e-4
    ψ = Array{Complex{Float64}}(undef, M, N)
    ψ[:, 1] = rand(M) + im*rand(M)
    ω = rand(M)

    # Decide which functions to use
    ops = Operators(α, x_order, ω)

    # 3 allocs per loop iteration
    myloop(N, ψ, dx, ops)

end

function myloop(N, ψ, dx, ops)
    for i = 1:N-1
        @time @views ψ[:, i+1] .= ops.T̂(ψ[:, i], dx, ops) # 0 allocs from calling ops.T̂
    end
end

run(10)
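The effect of the myloop function barrier can be checked in isolation. This is a sketch under assumptions: kernel returns a precomputed vector instead of a memoized one (so it only mimics a cache hit), and BarrierOps, step!, barrier_loop! and measure are hypothetical names. Measured inside a function after a warm-up call, the loop should not allocate.

```julia
struct BarrierOps{F}
    K::F
end

# In-place step, a stand-in for T₂ˢ: multiply ψ elementwise by ops.K(dx).
function step!(ψ, dx, ops)
    ψ .= ops.K(dx) .* ψ
    return ψ
end

# Function barrier in the spirit of myloop: it is compiled for the concrete
# type of `ops`, so the calls inside it are statically dispatched.
function barrier_loop!(ψ, dx, ops, nsteps)
    for _ in 1:nsteps
        step!(ψ, dx, ops)
    end
    return ψ
end

# Hypothetical kernel: returns a precomputed vector, mimicking a cache hit.
const KVEC = cis.(rand(512))
kernel(dx) = KVEC

# Measure inside a function so that only the loop itself is counted.
measure(ψ, dx, ops, n) = @allocated barrier_loop!(ψ, dx, ops, n)

ψ = rand(ComplexF64, 512)
ops = BarrierOps(kernel)

measure(ψ, 1e-4, ops, 1)             # first call compiles everything
bytes = measure(ψ, 1e-4, ops, 100)
println("bytes allocated in 100 steps: ", bytes)
```

Swapping kernel for the memoized K̂ from the post would show how much the cache lookup itself contributes.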

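Selecting the memoized K̂ up front (the inner function K̂(α) returns either K_cubic or K_hirota, so the branch on α is taken once when ops is built) significantly reduced the allocations. We also use a function barrier in the form of myloop: myloop is compiled for the concrete type of ops, so the call to ops.T̂ inside it no longer goes through dynamic dispatch, which reduces the allocations further. There are still 3 allocations per loop iteration, which I think come from the lookups done by @memoize.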
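If the remaining allocations do come from @memoize, one thing worth trying is a hand-rolled cache with a concretely typed Dict, built inside a let block. This is a sketch under assumptions: it does not use Memoization.jl, it drops ifftshift to avoid the FFTW dependency, and K_cubic here is a plain closure rather than the function from the post.

```julia
ω = rand(512)

# Closure over a concretely typed cache: keys are dx values, values are the
# precomputed kernels, so the return type is known to be Vector{ComplexF64}.
K_cubic = let ω = ω, cache = Dict{Float64, Vector{ComplexF64}}()
    function (dx::Float64)
        get!(cache, dx) do
            println("Computing and caching K(dx = $dx)")
            cis.(dx .* ω .^ 2 ./ 2)
        end
    end
end

k1 = K_cubic(1e-4)   # computes, prints, and stores the result
k2 = K_cubic(1e-4)   # returns the stored vector without recomputing
println(k1 === k2)   # true
println(only(Base.return_types(K_cubic, (Float64,))))
```

Because the Dict's value type is fixed, the closure's return type is inferred as Vector{ComplexF64}; whether this also removes the 3 allocations per iteration would still need to be confirmed with @allocated or BenchmarkTools.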

Any other suggestions are still welcome, however.
