Macro to write function with many conditionals

It is even simpler than that. If all the types are subtypes of some abstract type, you can use the `subtypes` function (from `InteractiveUtils`) to generate the list of its direct subtypes. It's a little more involved for parametric types, so I am not sure whether it is possible to write a generic macro that handles every case, but if you know your hierarchy beforehand you can bend it appropriately.

As an example for non-parametric types, here is an adaptation of the function from ManualDispatch.jl:

using InteractiveUtils: subtypes  # `subtypes` is not loaded by default outside the REPL

"""
    _leaf_types(T) -> Vector{Any}

Return the non-abstract types reachable from `T` by repeatedly applying `subtypes`,
in depth-first order. A non-abstract `T` is returned as `[T]`; an abstract type
without subtypes contributes nothing.
"""
function _leaf_types(T)
    isabstracttype(T) || return Any[T]
    leaves = Any[]
    for S in subtypes(T)
        # guard in case `subtypes` lists `T` itself (e.g. `supertype(Any) === Any`)
        S === T && continue
        append!(leaves, _leaf_types(S))
    end
    return leaves
end

"""
    _unionsplit(type, x, call, mod::Module = @__MODULE__) -> Expr

Build an `if x isa T1 … elseif x isa T2 … else … end` chain with one branch per
leaf type below `type`, each branch evaluating `call`. Inside a branch the compiler
knows the concrete type of `x`, so `call` can be dispatched statically.

# Arguments
- `type`: a `Type`, or a symbol/expression naming one, resolved in `mod`.
- `x`: the expression (normally a variable symbol) tested with `isa`.
- `call`: the expression evaluated in every branch.
- `mod`: module in which `type` is resolved; defaults to the module defining this function.

The trailing `else` branch evaluates `call` unchanged, so values outside the listed
types still execute `call` (dynamically dispatched) instead of being skipped.
If `type` has no leaf types, `call` itself is returned.
"""
function _unionsplit(type, x, call, mod::Module = @__MODULE__)
    root = type isa Type ? type : Core.eval(mod, type)
    leaves = _leaf_types(root)
    isempty(leaves) && return call
    first_type, rest_types = Iterators.peel(leaves)
    code = :(if $x isa $first_type
                 $call
             end)
    tail = code  # innermost `if`/`elseif` node; the next clause is attached to it
    for next_type in rest_types
        clause = :(if $x isa $next_type  # use `if` so this parses, then change to `elseif`
                       $call
                   end)
        clause.head = :elseif
        push!(tail.args, clause)
        tail = clause
    end
    push!(tail.args, call)  # fallback `else` branch keeps the split semantically transparent
    return code
end

"""
    @unionsplit type x call

Expand to the `if`/`elseif`/`else` chain built by `_unionsplit`, so that `call` is
compiled separately for each leaf type below `type`. `type` is resolved in the
module where the macro is used.
"""
macro unionsplit(type, x, call)
    return esc(_unionsplit(type, x, call, __module__))
end

Now, with the following types defined:

"Abstract supertype shared by the example structs `Bar1`, `Bar2` and `Bar3`."
abstract type Foo end

"`Foo` subtype holding a single `Int` field `x`."
struct Bar1 <: Foo; x::Int; end

"`Foo` subtype holding a single `Float64` field `x`."
struct Bar2 <: Foo; x::Float64; end

"`Foo` subtype holding a single `UInt8` field `x`."
struct Bar3 <: Foo; x::UInt8; end

# 100 random `Bar1`/`Bar2` values built from integers in 1:1000. `Bar3` is not sampled,
# presumably because its `UInt8` field cannot hold values above 255.
u = [T(v) for (T, v) in zip(rand([Bar1, Bar2], 100), rand(1:1000, 100))]

one can see that the proper code is generated:

julia> @macroexpand let res = 0.0
           for z in u
               @unionsplit Foo z res += z.x
           end
           res
       end

:(let res = 0.0
      #= REPL[110]:2 =#
      for z = u
          #= REPL[110]:3 =#
          begin
              #= REPL[91]:3 =#
              if z isa Bar1
                  #= REPL[93]:5 =#
                  res += z.x
              elseif z isa Bar2
                  #= REPL[93]:10 =#
                  res += z.x
              elseif z isa Bar3
                  #= REPL[93]:10 =#
                  res += z.x
              end
          end
      end
      #= REPL[110]:5 =#
      res
  end)
5 Likes