Macro to write function with many conditionals

A minimal working example: let's say I have a vector of mixed types:

x = [ isodd(i) ? rand(Int64) : rand(Float64) for i in 1:1000 ]

I want to sum these numbers, with:

function mysum(x)
  s = 0.
  for val in x
     s += val 
  end
  s
end

This turns out to be slow because of the type instability of the numbers and the associated runtime dispatch. The following version, which splits the calculation by type, runs much faster and does not allocate:

function mysum_fast(x)
  s = 0.
  for val in x
    if val isa Int64
      s += val
    elseif val isa Float64
      s += val
    end
  end
  s
end

Now, if instead of two types of numbers I had dozens of different types, writing that function would be very annoying. This seems like the place for a macro.

Given a list of the types:

types = [ Int64, Float64 ]

I guess a simple macro could generate the mysum_fast function for me. It seems like it should be simple, but I have a hard time understanding this part of the manual. In particular, I could not find how to concatenate two expressions, which I suspect is required to write such a macro (generate the expression for the beginning of the loop, loop over the types concatenating the conditionals, then concatenate the closing of the loop).
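On the narrower question of concatenating expressions, here is a minimal sketch of two standard ways to do it: appending to the args of a block expression, and splatting a collection of expressions into a quote. The variable names (body, stmts, spliced) and the toy statements are illustrative assumptions only, not the macro asked about.

```julia
# Start from an empty block and append statements one at a time.
body = Expr(:block)
push!(body.args, :(s = 0.0))
for t in (:Int64, :Float64)          # illustrative list of type names
    push!(body.args, :(s += one($t)))
end
push!(body.args, :s)                 # the block evaluates to its last statement

# Alternatively, splice a whole collection of expressions into a quote.
stmts = [:(a = 1), :(b = a + 1)]
spliced = quote
    $(stmts...)
    a + b
end

result_block  = eval(body)     # runs the appended statements in order
result_splice = eval(spliced)
```

Both approaches only build Expr objects; eval is used here just to confirm that the assembled code runs.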


This is a bit of an X-Y problem, tbh. An array like this won’t be as fast as it would be with a concrete eltype.

If you can re-think this problem in a way that has two different arrays with multiple dispatch, that would be best.


I agree with that. But this is what I came to understand was a reasonable solution to a problem in another thread. Anyway, is writing this macro simple? I would like to know that as well, to learn the macro syntax.

Great that that had been done. I will see if I understand that implementation.

To be honest, no, I don't think this is a good use case for a macro.

In your example, you say “given a list of types”, but macros can’t operate on the elements of an array, since those aren’t known at the time the code generated by the macro is created.

You would have to have something like

@dispatch begin 
    Int64 ~ s += val
    Float64 ~ s += val
end

i.e. you would still have to write the types out manually. This is really no improvement over the if conditionals.
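To make the point about macros and array contents concrete, here is a small sketch. The macro name @argkind is made up for illustration; it only reports the kind of object the macro receives at expansion time.

```julia
# Hypothetical macro: returns the type name of its (unevaluated) argument.
macro argkind(arg)
    return string(typeof(arg))
end

types = [Int64, Float64]

kind_of_name    = @argkind types              # the macro sees the Symbol :types
kind_of_literal = @argkind [Int64, Float64]   # the macro sees an Expr, not a Vector
```

In the first call the macro receives only the name types, so the elements of the array are not available to it. In the second call it receives the syntax of the literal, whose elements are the symbols :Int64 and :Float64.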


The set of types would be constant, in the sense that they could be hardcoded. The function's argument is the list of mixed-type elements. Maybe my difficulty in understanding the manual comes from a misinterpretation of what macros are best for.
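If the types are hardcoded at the call site, a macro can iterate over them, because a literal tuple arrives as an expression whose args are the type names. The following is only a sketch under that assumption; the names @splittypes, _split_literal and mysum_split are hypothetical and do not come from any package.

```julia
# Hypothetical helper: builds a nested if/else chain over literal type names.
function _split_literal(types, x, call)
    (types isa Expr && types.head === :tuple) ||
        error("expected a literal tuple of types, e.g. (Int64, Float64)")
    foldr((t, tail) -> :(if $x isa $t; $call else $tail end),
          types.args; init = Expr(:block))
end

macro splittypes(types, x, call)
    esc(_split_literal(types, x, call))
end

function mysum_split(x)
    s = 0.0
    for val in x
        # Expands to: if val isa Int64 ... else if val isa Float64 ... end end
        @splittypes (Int64, Float64) val s += val
    end
    s
end
```

The type list is still written out once, but the branch body is no longer repeated for every type.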

It is even simpler than that. If all the types are subtypes of some abstract type, you can use the subtypes function to generate the list of all subtypes. It is a little more involved for parametric types, so I am not sure whether it is possible to write a generic macro that handles them all, but if you know your hierarchy beforehand you can adapt it appropriately.

As an example for non-parametric types, here is an adaptation of the ManualDispatch.jl function:

using InteractiveUtils  # provides `subtypes`; loaded automatically in the REPL

function _unionsplit(type, x, call)
    thetypes = subtypes(getfield(@__MODULE__, type))
    first_type, rest_types = Iterators.peel(thetypes)
    code = :(if $x isa $first_type
                $call
             end)
    the_args = code.args
    for next_type in rest_types
        clause = :(if $x isa $next_type # use `if` so this parses, then change to `elseif`
                       $call
                   end)
        clause.head = :elseif
        push!(the_args, clause)
        the_args = clause.args
    end
    return code
end

macro unionsplit(type, x, call)
    quote
        $(esc(_unionsplit(type, x, call)))
    end
end

Now, with the following structure defined

abstract type Foo end

struct Bar1 <: Foo
    x::Int
end

struct Bar2 <: Foo
    x::Float64
end

struct Bar3 <: Foo
    x::UInt8
end

u = map(x -> x[1](x[2]), zip(rand([Bar1, Bar2], 100), rand(1:1000, 100)))

one can see that proper code is generated

julia> @macroexpand let res = 0.0
           for z in u
               @unionsplit Foo z res += z.x
           end
           res
       end

:(let res = 0.0
      #= REPL[110]:2 =#
      for z = u
          #= REPL[110]:3 =#
          begin
              #= REPL[91]:3 =#
              if z isa Bar1
                  #= REPL[93]:5 =#
                  res += z.x
              elseif z isa Bar2
                  #= REPL[93]:10 =#
                  res += z.x
              elseif z isa Bar3
                  #= REPL[93]:10 =#
                  res += z.x
              end
          end
      end
      #= REPL[110]:5 =#
      res
  end)

This is a very nice example. Writing such a macro is much more involved than I would have thought, and the example introduces many interesting concepts that I had not grasped from reading the manual. It will be very useful for understanding macros in general.

Notice that if you convert

y = convert(Vector{Union{Int,Float64}}, x)

then mysum runs pretty much as fast as mysum_fast:

Benchmark
julia> @benchmark mysum_fast(x)
BenchmarkTools.Trial: 
  memory estimate:  16 bytes
  allocs estimate:  1
  --------------
  minimum time:     990.417 ns (0.00% GC)
  median time:      994.083 ns (0.00% GC)
  mean time:        1.006 μs (0.00% GC)
  maximum time:     3.991 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     12

julia> @benchmark mysum(y)
BenchmarkTools.Trial: 
  memory estimate:  16 bytes
  allocs estimate:  1
  --------------
  minimum time:     1.001 μs (0.00% GC)
  median time:      1.229 μs (0.00% GC)
  mean time:        1.225 μs (0.00% GC)
  maximum time:     22.500 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     10

If you look at the typed code (@code_typed mysum_fast(x) and @code_typed mysum(y)) you’ll see that they are almost identical.

To be honest, I don’t understand why your code for x returns a Vector{Real} instead of Vector{Union{Int,Float64}}, which seems much more reasonable. I guess it’s because somewhere it uses typejoin(Int,Float64), which is Real.
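A short sketch of the type-level facts mentioned here, plus one way to get the Union element type directly with a typed comprehension instead of converting afterwards. The variable names joined and z are illustrative.

```julia
x = [isodd(i) ? rand(Int64) : rand(Float64) for i in 1:1000]

# The closest common supertype of the two element types.
joined = typejoin(Int64, Float64)

# Converting after the fact, as suggested above.
y = convert(Vector{Union{Int64,Float64}}, x)

# Requesting the element type up front with a typed comprehension.
z = Union{Int64,Float64}[isodd(i) ? rand(Int64) : rand(Float64) for i in 1:1000]
```

Both y and z have the small Union as their element type, which is the property that lets the compiler split the loop body by type.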


Also, the previous macro can be compressed down to just

function _unionsplit(type, x, call)
    thetypes = subtypes(getfield(@__MODULE__, type))
    foldr((t, tail) -> :(if $x isa $t; $call else $tail end), thetypes, init=Expr(:block))
end

macro unionsplit(type, x, call)
    esc(_unionsplit(type, x, call))
end

making apparent the recursive structure.

The generated Julia code is slightly different (it uses nested if/else), but once compiled they become identical (check with @code_typed for instance).
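To inspect the nested shape without defining any types, the same fold can be run on plain symbols. This is only an illustration: the names :z, :Bar1 and :Bar2 are placeholders, and the expression is never evaluated.

```julia
x, call = :z, :(res += z.x)

# Same fold as in the compact _unionsplit, applied to symbol placeholders.
ex = foldr((t, tail) -> :(if $x isa $t; $call else $tail end),
           [:Bar1, :Bar2]; init = Expr(:block))

outer_cond  = ex.args[1]            # condition of the outermost `if`
else_branch = ex.args[3]            # block holding the next `if`
inner_if    = else_branch.args[end]
```

Each else branch is a block whose last statement is the if for the next type, and the innermost else holds the empty block passed as init.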

Notice that in the macro I also removed the useless quote $(...) end.
