Performance and allocation issue with arrays of functions (v1.5)

I have an ODE-type problem where I generate a sequence of values $x_n$ that, in principle, may lie in a comparatively high-dimensional space. Often I don't need to record the full values, but rather just some scalar-valued functions $f_k(x_n)$, $k = 1, \ldots, K$, where $K$ is much smaller than the dimension of $x_n$. I wrote the following code to illustrate a performance issue I've noticed:

"""
    compute_values(x₀, Δt, nΔt, f)

Starting from `x₀`, take `nΔt` explicit steps of `x += 0.5 * Δt * x` and, after each
step, evaluate every function in the collection `f` at the new state.

Return a `length(f) × nΔt` `Float64` matrix whose entry `[k, n]` is `f[k]` applied to
the state reached after step `n`.
"""
function compute_values(x₀, Δt, nΔt, f)
    out = zeros(length(f), nΔt)
    state = x₀
    # One column per time step.
    for col in axes(out, 2)
        state += 0.5 * Δt * state
        # One row per observable; with a heterogeneous container each call below is
        # resolved at run time, which is the slowdown this thread is about.
        for (row, fk) in enumerate(f)
            out[row, col] = fk(state)
        end
    end
    return out
end

# Simulation parameters: initial state, step size, number of steps.
x₀, Δt, nΔt = 1.0, 0.5, 10^2;

# First observable: the sine of the state.
f1(x) = sin(x)

# Second observable: the square of the state.
f2(x) = x^2

If I now do:

# Benchmarks of compute_values with one and with two observables.
# NOTE(review): x₀, Δt, nΔt are non-constant globals passed without `$` interpolation,
# so @btime also measures global-variable access here.
# `[f1]` and `[f2]` each contain a single function type, whereas `[f1,f2]` mixes two
# different function types (every Julia function has its own type).
@btime compute_values(x₀, Δt, nΔt, [f1]);
@btime compute_values(x₀, Δt, nΔt, [f2]);
@btime compute_values(x₀, Δt, nΔt, [f1,f2]);

it returns:

  259.131 ns (2 allocations: 976 bytes)
  269.878 ns (2 allocations: 976 bytes)
  15.101 μs (303 allocations: 6.56 KiB)

which is not what I anticipated. Both the run time and the memory allocation have gone up by over an order of magnitude when the computation uses both functions, even though there is no trouble with either function on its own.

Try a tuple instead of an array. With an array, the compiler doesn't know which type it will get when it pulls out an element: each function has its own type, so `[f1, f2]` ends up with an abstract element type.

1 Like

That doesn’t seem to help very much:

# Same benchmark with a tuple of functions. The inner loop still indexes `f[k]` with a
# run-time `k`, so presumably the call type is still not known at compile time.
@btime compute_values(x₀, Δt, nΔt, (f1,f2));
  19.653 μs (401 allocations: 8.02 KiB)

This implementation speeds up the last case by 10X. Note that `x *= 1 + 0.5Δt` is mathematically equivalent to `x += 0.5 * Δt * x` (the floating-point results can differ in the last bit), but it looks nicer to me and, surprisingly, gives a slight speedup. Also, don't forget to interpolate variables (`$`) when using `@btime`.

"""
    compute_values2(x₀, Δt, nΔt, f)

Starting from `x₀`, take `nΔt` steps of `x *= 1 + 0.5Δt` and return a
`length(f) × nΔt` `Float64` matrix whose entry `[i, n]` is `f[i]` applied to the state
after step `n`.

The trajectory is first copied into every row. Each row is then transformed in place by
its own function, so `f[i]` is looked up once per row rather than once per time step.
"""
function compute_values2(x₀, Δt, nΔt, f)
    x = x₀
    n_f = length(f)
    f_vals = zeros(n_f, nΔt)
    growth = 1 + 0.5Δt  # loop-invariant step factor
    for n in 1:nΔt
        x *= growth
        f_vals[:, n] .= x
    end
    for i in 1:n_f
        row = view(f_vals, i, :)
        # Broadcasting is defined to be safe when the destination is also the source;
        # `map!` documents that aliased arguments can behave unexpectedly.
        row .= f[i].(row)
    end
    return f_vals
end

# Same simulation parameters as in the question.
x₀, Δt, nΔt = 1.0, 0.5, 10^2;

# The two observables, sin(x) and x², written in long form.
function f1(x)
    return sin(x)
end

function f2(x)
    return x^2
end

# Arguments are `$`-interpolated so that @btime does not also time access to the
# non-constant globals x₀, Δt and nΔt.
@btime compute_values2($x₀, $Δt, $nΔt, $[f1])
@btime compute_values2($x₀, $Δt, $nΔt, $[f2])
@btime compute_values2($x₀, $Δt, $nΔt, $(f1,f2))
  1.820 μs (1 allocation: 896 bytes)
  220.359 ns (1 allocation: 896 bytes)
  2.022 μs (5 allocations: 1.95 KiB)
2 Likes

Instead of `map!` you could also use a generated function with `Base.Cartesian.@nexprs` to manually unroll the loop and make it type-stable:

# Tuple method of compute_values: starting from x₀, take nΔt steps of
# x += 0.5 * Δt * x and store f[k](x) for every function in the tuple `f`.
# Returns a K × nΔt Float64 matrix (K = length(f)).
#
# The per-function loop is unrolled at compile time, one statement per tuple element.
# Every f[k] therefore uses a literal index, so its function type is known to the
# compiler. `NTuple{K,Any}` matches any K-element tuple and replaces the deprecated
# `Tuple{Vararg{<:Any,K}}` spelling.
@generated function compute_values(x₀, Δt, nΔt, f::NTuple{K,Any}) where {K}
    # In the generator, K is a compile-time integer; splice it into the generated code.
    quote
        x = x₀
        f_vals = zeros($K, nΔt)
        for n in 1:nΔt
            x += 0.5 * Δt * x
            # Expands to f_vals[1, n] = f[1](x); …; f_vals[K, n] = f[K](x)
            Base.Cartesian.@nexprs $K k -> f_vals[k, n] = (f[k])(x)
        end
        return f_vals
    end
end
2 Likes

Isn’t using @generated considered “bad?”

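Sure, but @generated is easy and it performs well. Comparing the above compute_values2 with the generated version (invoked as `compute_values_generated` in the session below, i.e. the generated method defined under a separate name so it doesn't clash with the original `compute_values`):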

julia> @btime compute_values2($x₀, $Δt, $nΔt, $[f1])
  1.665 μs (1 allocation: 896 bytes)
 1×100 Matrix{Float64}:
 0.948985  0.999966  0.927798  0.64436  0.0897141  -0.623416  -0.998433  -0.317148  0.919731  …  0.265737  -0.944011  0.0275612  -0.999406  0.342549  0.939933  -0.676239  0.599191

julia> @btime compute_values2($x₀, $Δt, $nΔt, $[f2])
  231.457 ns (1 allocation: 896 bytes)
 1×100 Matrix{Float64}:
 1.5625  2.44141  3.8147  5.96046  9.31323  14.5519  22.7374  35.5271  55.5112  86.7362  135.525  …  1.65608e18  2.58763e18  4.04317e18  6.31746e18  9.87103e18  1.54235e19  2.40992e19

julia> @btime compute_values2($x₀, $Δt, $nΔt, $(f1,f2))
  1.726 μs (1 allocation: 1.77 KiB)
 2×100 Matrix{Float64}:
 0.948985  0.999966  0.927798  0.64436  0.0897141  -0.623416  -0.998433  -0.317148   0.919731  …  -0.944011    0.0275612   -0.999406    0.342549    0.939933    -0.676239    0.599191
 1.5625    2.44141   3.8147    5.96046  9.31323    14.5519    22.7374    35.5271    55.5112        1.65608e18  2.58763e18   4.04317e18  6.31746e18  9.87103e18   1.54235e19  2.40992e19

julia> @btime compute_values_generated($x₀, $Δt, $nΔt, $(f1,))
  1.506 μs (1 allocation: 896 bytes)
 1×100 Matrix{Float64}:
 0.948985  0.999966  0.927798  0.64436  0.0897141  -0.623416  -0.998433  -0.317148  0.919731  …  0.265737  -0.944011  0.0275612  -0.999406  0.342549  0.939933  -0.676239  0.599191

julia> @btime compute_values_generated($x₀, $Δt, $nΔt, $(f2,))
  220.222 ns (1 allocation: 896 bytes)
 1×100 Matrix{Float64}:
 1.5625  2.44141  3.8147  5.96046  9.31323  14.5519  22.7374  35.5271  55.5112  86.7362  135.525  …  1.65608e18  2.58763e18  4.04317e18  6.31746e18  9.87103e18  1.54235e19  2.40992e19

julia> @btime compute_values_generated($x₀, $Δt, $nΔt, $(f1,f2))
  1.550 μs (1 allocation: 1.77 KiB)
 2×100 Matrix{Float64}:
 0.948985  0.999966  0.927798  0.64436  0.0897141  -0.623416  -0.998433  -0.317148   0.919731  …  -0.944011    0.0275612   -0.999406    0.342549    0.939933    -0.676239    0.599191
 1.5625    2.44141   3.8147    5.96046  9.31323    14.5519    22.7374    35.5271    55.5112        1.65608e18  2.58763e18   4.04317e18  6.31746e18  9.87103e18   1.54235e19  2.40992e19

Although if you’re unhappy with laziness as an excuse, with a little more work, we can use dispatch instead of Base.Cartesian.@nexprs to unroll our expressions:

# fmap(fs, x): apply every function in the tuple `fs` to `x` and return the results as
# a tuple. The recursion peels off one function per call, which lets the compiler
# specialize on each function's type instead of looping over a heterogeneous container.
@inline fmap(fs::Tuple, x) = (first(fs)(x), fmap(Base.tail(fs), x)...)
# Base case: no functions left, so no results. Without it, an empty tuple would reach
# `first(())` and throw.
@inline fmap(::Tuple{}, x) = ()

"""
    compute_values_map(x₀, Δt, nΔt, f::Tuple)

Starting from `x₀`, take `nΔt` steps of `x += 0.5 * Δt * x` and, after each step,
evaluate every function in the tuple `f` at the new state via `fmap`.

Return a `length(f) × nΔt` `Float64` matrix whose entry `[k, n]` is `f[k]` applied to
the state after step `n`. An empty tuple yields a `0 × nΔt` matrix.
"""
function compute_values_map(x₀, Δt, nΔt, f::NTuple{K,Any}) where {K}
    x = x₀
    f_vals = zeros(K, nΔt)
    for n in 1:nΔt
        x += 0.5 * Δt * x
        fv = fmap(f, x)  # tuple of the K observable values at this step
        for k in 1:K
            f_vals[k, n] = fv[k]
        end
    end
    return f_vals
end

Result:

julia> @btime compute_values_map($x₀, $Δt, $nΔt, $(f1,f2))
  1.562 μs (1 allocation: 1.77 KiB)
 2×100 Matrix{Float64}:
 0.948985  0.999966  0.927798  0.64436  0.0897141  -0.623416  -0.998433  -0.317148   0.919731  …  -0.944011    0.0275612   -0.999406    0.342549    0.939933    -0.676239    0.599191
 1.5625    2.44141   3.8147    5.96046  9.31323    14.5519    22.7374    35.5271    55.5112        1.65608e18  2.58763e18   4.04317e18  6.31746e18  9.87103e18   1.54235e19  2.40992e19
4 Likes