How to cast function input parameters

In the code below, I have a Float32 vector to which I would like to add a number. To ensure that the computation happens in single precision, I force the input parameters to be Float32. Is it possible to let the function also accept Int or Float64 arguments, but cast them directly to Float32, without having to write explicit converts inside the function?

function add_float32!(v::Vector{Float32}, f::Float32)
    v .+= f
end

a = rand(Float32, 10) 
b = 123.
c = 4

add_float32!(a, b)
add_float32!(a, c)

You can do it with an explicit convert without adding any complexity to your code.

function add_float32!(v::Vector{Float32}, f)
    v .+= convert(Float32, f)
end
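
For example, with the a, b, and c from the question (b is a Float64 literal and c an Int), both calls now work and a keeps its Float32 elements:

add_float32!(a, b)   # b = 123. is converted to 123.0f0 before the broadcast
add_float32!(a, c)   # c = 4 is converted to 4.0f0
eltype(a)            # still Float32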

You don’t need this. Since you are assigning into a Vector{Float32}, convert happens automatically. Just write

function add_float32!(v::Vector{Float32}, f)
    v .+= f
end
a = rand(Float32, 10) 
b = 123.0
add_float32!(a, b)
10-element Vector{Float32}:
 123.8254
 123.764946
 123.354645
 123.38994
 123.01592
 123.38142
 123.62168
 123.07243
 123.63172
 123.220024

Edit: It does seem, though, that there is a performance difference: converting the second argument to Float32 up front is faster when it is a ‘wider’ type like Float64.

A more generic solution could be

function add_all!(v::AbstractArray{T}, f) where {T}
    v .+= T(f)
end

This works for most number types, and for several array types as well.
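
A quick sketch of a few calls it accepts (assuming the add_all! definition above):

add_all!(rand(Float32, 5), 1.0)    # T = Float32, so 1.0 becomes 1.0f0 before the broadcast
add_all!(rand(Float64, 5), 1)      # T = Float64, so the Int becomes 1.0
add_all!(rand(Float32, 3, 3), 2)   # a Matrix{Float32} works too, since v is an AbstractArray{T}

For integer element types and a fractional f (say a Vector{Int} and 1.5), T(f) throws an InexactError, just as the assignment back into the array would.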


It is still a bit unclear to me. I wrote a benchmark with a very large array, and I do not see a performance difference between a Float32 and a Float64 f argument when I use code without casting. When does Julia cast the computation to Float64 and then back to Float32, and when not?

By “code without casting” I mean:

function add_float32!(v::Vector{Float32}, f)
    v .+= f
end

If you leave it like that (by the way, you should), then Julia will use the promotion rules for whatever f ends up being, for example:

julia> a = Float32[1,2,3]

julia> a .+= Float64(1)
3-element Vector{Float32}:
 2.0
 3.0
 4.0

julia> a .+= 1
3-element Vector{Float32}:
 3.0
 4.0
 5.0

julia> a .+= UInt(1)
3-element Vector{Float32}:
 4.0
 5.0
 6.0

It all works because Julia knows how to add a Float32 to a Float64/Int/UInt, etc.
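
You can query the promotion rule directly with promote_type, which shows the common type the addition itself is computed in (the result is then converted back to Float32 when it is written into the array):

julia> promote_type(Float32, Float64)
Float64

julia> promote_type(Float32, Int)
Float32

julia> promote_type(Float32, UInt)
Float32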

I misread this sentence initially, thinking you were saying you were seeing the performance drop, hence the edit.

Strange, I also see the performance difference mentioned by @DNF:

julia> function add_all_noconvert!(v::AbstractArray{T}, f) where {T}
           v .+= f
       end
add_all_noconvert! (generic function with 1 method)

julia> function add_all!(v::AbstractArray{T}, f) where {T}
           v .+= T(f)
       end
add_all! (generic function with 1 method)

julia> using BenchmarkTools

julia> a = rand(Float32, 10000); @btime add_all_noconvert!(a, 1.0);
  1.971 μs (0 allocations: 0 bytes)

julia> a = rand(Float32, 10000); @btime add_all!(a, 1.0);
  1.255 μs (0 allocations: 0 bytes)

I want to point out that this is not a (generally) legal optimization because rounding could be different:

julia> a = rand(Float32, 10^5); b = rand(Float64, 10^5);

julia> foldl(+, a .+ b)
100169.7271276231

julia> foldl(+, a .+ Float32.(b))
100170.37f0

Although in this case using sum(), which adds in a better order (pairwise?), mitigates the issue, this should convince you that rounding before adding can lead to different results in general.

Yes, it uses pairwise summation.
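
A minimal sketch of why the summation order matters when accumulating in Float32: a strict left-to-right foldl eventually stops making progress, while the pairwise sum stays close to the true value.

x = fill(1.0f0, 3 * 10^7)   # 3e7 ones; the exact sum is 3.0e7

foldl(+, x)   # left-to-right: the Float32 accumulator stalls at 2^24 = 1.6777216f7
sum(x)        # pairwise summation: close to 3.0f7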


The difference in performance compared to the raw broadcast is, I guess, due to SIMD opportunities: with just a .+= x, if a is an array of single-precision floats and x is double precision, the sum is done in double precision and converted to single on write. With a .+= T(x), the sum is computed in single precision, meaning more array elements can be updated in parallel using SIMD operations.
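
One way to check this hypothesis (a sketch; how much of the broadcast machinery gets inlined, and which vector widths appear, depends on your CPU and Julia version) is to compare the generated code for the two variants from the benchmark above and look for packed <N x float> operations:

using InteractiveUtils   # exports @code_llvm; already available in the REPL

v = rand(Float32, 10_000)

@code_llvm debuginfo=:none add_all_noconvert!(v, 1.0)   # mixed Float32/Float64 loop
@code_llvm debuginfo=:none add_all!(v, 1.0)             # pure Float32 loop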

If the difference is due to SIMD, then it depends on whether the data can be fetched from memory at a high enough rate. With large arrays, you may be benchmarking memory speed rather than computation. On my computer, there is a difference with an array of 1M Float32s which vanishes at 10M.
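
A sketch of how one could check that size dependence (assuming the add_all!/add_all_noconvert! definitions above; the sizes are just examples and absolute timings will differ):

using BenchmarkTools

for n in (10_000, 1_000_000, 10_000_000)
    v = rand(Float32, n)
    t_plain   = @belapsed add_all_noconvert!($v, 1.0)
    t_convert = @belapsed add_all!($v, 1.0)
    println("n = $n: without convert $(round(t_plain * 1e6, digits=1)) μs, with convert $(round(t_convert * 1e6, digits=1)) μs")
end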


I see this drop in my actual code (not the minimal example), which is why I got confused.

Indeed! I meant it was strange that we didn’t both see the difference. My initial reply (that I edited away) explained exactly this!
