Writing a macro that generates a function

I would like to define a macro that lets me sum vector-like structs whose fields may have heterogeneous types. I tried something like the following inside a module:

"""
A heterogeneous vector. Actually the heterogeneity comes from the fields having different units from Unitful.jl, but this is a simplified example. I will have a few different kinds of types like these, and wish to reduce allocations that StaticArrays.FieldVector introduces due to the type heterogeneity.
"""
struct StateVector
    a :: Int
    b :: Float64
end

"""
This defines addition for FieldVectors.
"""
macro defineAdd(type)

    return quote

        T = $(esc(type))

        fields = fieldnames(T)

        sumExprs = [:(x.$f + y.$f) for f in fields]

        function Base.:+(x::T, y::T)

            T($(sumExprs...))

        end # function

    end # quote

end # macro

However, this does not work and I get the following error, where M is the module I am defining this in:

julia> M.@defineAdd M.StateVector
ERROR: LoadError: UndefVarError: `sumExprs` not defined in `M`
Suggestion: check for spelling errors or missing imports.
Stacktrace:
 [1] var"@defineAdd"(__source__::LineNumberNode, __module__::Module, type::Any)
   @ M ~/M.jl/src/M.jl:26
in expression starting at REPL[2]:1

How might I fix this error?

I know I could just define an abstract supertype for these different vectors and call propertynames or fieldnames on those, but this causes allocations. I then need to generate things like

function Base.:+(a::StateVector,b::StateVector)

    StateVector(
        a.a + b.a,
        a.b + b.b,
    )

end # function

programmatically at compile-time.

This sounds like a generated function?

Sure. But I think my main problem now comes from the macro being defined in a module. All the examples I have found are run in the REPL, so I’m not sure how to proceed.

I also wonder what would happen if I did “function piracy” like

@generated function Base.:+(a::T,b::T)
     ...
end # function

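Your immediate problem is that you interpolate sumExprs into the quote, but sumExprs is only defined inside the quote, not in the macro body before it. Interpolation happens when the macro is expanded, and at that point sumExprs does not exist yet.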

Your main problem is that macros don’t have access to values when they are expanded, so they cannot query the fields of the type you give as argument.
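
To illustrate the point about macros not seeing values, here is a minimal sketch. The macro name @argkind is made up for this example; it only reports what kind of object the macro receives at expansion time.

```julia
# Hypothetical helper: returns the type of the *syntax* the macro was given.
# The argument is never evaluated, so the names below need not even exist.
macro argkind(x)
    return string(typeof(x))   # computed when the macro is expanded
end

@argkind StateVector      # "Symbol": just a name, not the type itself
@argkind M.StateVector    # "Expr": the unevaluated expression M.StateVector
```

Since the macro only ever holds a Symbol or an Expr, calling fieldnames on its argument inside the macro body cannot work without evaluating that argument first.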

This can be solved with a generated function, which differs critically from a macro in that it has access to the types of its arguments.

Another approach that should be possible is to generate the add function together with the struct definition, with a macro

@struct_with_add struct StateVector
    a :: Int
    b :: Float64
end
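
A rough sketch of what such a macro could look like follows. The name @struct_with_add is taken from the suggestion above, but the body is only an assumption about one way to do it: it handles plain, non-parametric structs without a supertype and reads the field names straight from the struct syntax.

```julia
# Sketch only: wraps a struct definition and also emits a `+` method for it.
macro struct_with_add(def)
    (def isa Expr && def.head === :struct) ||
        error("@struct_with_add expects a struct definition")
    name = def.args[2]            # def.args == [is_mutable, name, body]
    name isa Symbol ||
        error("this sketch only handles plain struct names")
    # Collect field names from `a` or `a::T` entries; LineNumberNodes are skipped.
    fields = Symbol[]
    for ex in def.args[3].args
        if ex isa Symbol
            push!(fields, ex)
        elseif ex isa Expr && ex.head === :(::)
            push!(fields, ex.args[1])
        end
    end
    # One `x.f + y.f` expression per field, built in the macro body,
    # i.e. before the quote that interpolates them.
    sums = [:(x.$f + y.$f) for f in fields]
    return esc(quote
        $def
        Base.:+(x::$name, y::$name) = $name($(sums...))
    end)
end

@struct_with_add struct StateVector
    a :: Int
    b :: Float64
end

StateVector(1, 2.0) + StateVector(3, 4.0)   # StateVector(4, 6.0)
```

Because the macro is handed the struct definition itself, the field names are available as syntax and nothing has to be evaluated at expansion time.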

Any examples of passing struct definitions to macros? I think the generated function approach can be a bit problematic, because I essentially want to do this, if I took that approach.

Ah, I guess there is this to go on: julia/base/util.jl at 9615af0f269df4d371b8010e9507ed5bae86103b · JuliaLang/julia · GitHub.

GitHub - JuliaServices/AutoHashEquals.jl: A Julia macro to add == and hash() to composite types. is another classic example.

If you can’t figure out a way to constrain the types in the generated function you can of course do the workaround of

Base.:+(a::T, b::T) where {T <: MySuperType} = my_generated_function(a, b)
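
Spelled out, that workaround might look like the sketch below. MySuperType and my_generated_function are the placeholder names from the line above, and making StateVector a subtype of MySuperType is an assumption made for the example.

```julia
abstract type MySuperType end

struct StateVector <: MySuperType
    a :: Int
    b :: Float64
end

# Unconstrained generated helper. Inside the generator body, T is the concrete
# type, so fieldnames(T) is known when the method body is generated.
@generated function my_generated_function(a::T, b::T) where {T}
    # Build the per-field expressions outside the returned quote.
    sums = [:(a.$n + b.$n) for n in fieldnames(T)]
    return :(T($(sums...)))
end

# Only the thin, non-generated wrapper extends Base.:+, restricted to the supertype.
Base.:+(a::T, b::T) where {T <: MySuperType} = my_generated_function(a, b)

StateVector(1, 2.0) + StateVector(3, 4.0)   # StateVector(4, 6.0)
```

This keeps the method added to Base.:+ limited to your own types, while the helper stays generic.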

Would

Base.:+(x::StateVector, y::StateVector) = StateVector(map(+, getfields(x), getfields(y))...)

work for you? Should be fast even without macros/generated functions.
(getfields comes from ConstructionBase.jl; it also has getproperties)

For the record, defining

"""
Summation for StateVectors.
"""
@generated function Base.:+(left::T,right::T) where T <: FieldVector

    propNames = fieldnames(T)

    return :(
        T(
            [:(left.$p + right.$p) for p in propNames ]
        )
    )

end # function

results in an error

ERROR: The function body AST defined by this @generated function is not pure. This likely means it contains a closure, a comprehension or a generator.
Stacktrace:
 [1] top-level scope
   @ REPL[25]:1

The following results in a different error:

"""
Summation for StateVectors.
"""
@generated function Base.:+(left::T,right::T) where T <: FieldVector

    @show fieldNames = fieldnames(T)

    @show fieldAdditionExprs = [ :(left.$n + right.$n) for n in fieldNames ]

    return :(
        T($fieldAdditionExprs...)
    )

end

ERROR: MethodError: no method matching StateVector(::Expr, ::Expr)

So I guess now the question is: do I need to somehow define a constructor for StateVector that takes two expressions and evaluates them?

function StateVector(e1::Expr,e2::Expr)

    StateVector(
        eval(e1),
        eval(e2),
    )

end # function

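I think you just missed a pair of parentheses. You want T($(fieldAdditionExprs...)), which splices each expression into the call during interpolation, not T($fieldAdditionExprs...), which interpolates the whole vector and splats it only at run time, so the constructor receives Expr objects.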

julia> struct StateVector
           a :: Int
           b :: Float64
       end

julia> @generated function Base.:+(left::T,right::T) where T <: StateVector

           @show fieldNames = fieldnames(T)

           @show fieldAdditionExprs = [ :(left.$n + right.$n) for n in fieldNames ]

           return :(
               T($(fieldAdditionExprs...))
           )

       end

julia> StateVector(1,2.0) + StateVector(3,4.0)
fieldNames = fieldnames(T) = (:a, :b)
fieldAdditionExprs = [$(Expr(:quote, :(left.:($(Expr(:$, :n))) + right.:($(Expr(:$, :n)))))) for n = fieldNames] = Expr[:(left.a + right.a), :(left.b + right.b)]
StateVector(4, 6.0)

Edit: @aplavin's solution is certainly easier to read, though:

julia> using ConstructionBase

julia> Base.:*(x:: StateVector, y:: StateVector) = StateVector(map(*, getfields(x), getfields(y))...)

julia> StateVector(1,2.0) * StateVector(3,4.0)
StateVector(3, 8.0)

And this was indeed it. Thanks so much!

Now onto generating almost the entire AbstractVector interface like this (except for indexing)… :smiley:

I kind of prefer the lack of an additional dependency in this case. But ConstructionBase is good to know about.

I don’t know the scope of your project, but be aware that you might gain runtime speed at the cost of compilation time, which may or may not be a favorable tradeoff.

Note that ConstructionBase.jl is very light and extremely widely used – it’s a dependency of thousands of Julia packages.

But I just checked: the getfields implementation there is only a few lines:

function getfields(obj::T) where {T}
    fnames = fieldnames(T)
    NamedTuple{fnames}(getfield.((obj,), fnames))
end

You can just copy-paste it.
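
Put together without the dependency, that could look like the sketch below. It assumes the getfields copy above is pasted into your own module unchanged and reuses the StateVector struct from earlier in the thread.

```julia
struct StateVector
    a :: Int
    b :: Float64
end

# Local copy of the few lines quoted above, so ConstructionBase is not required.
function getfields(obj::T) where {T}
    fnames = fieldnames(T)
    NamedTuple{fnames}(getfield.((obj,), fnames))
end

# map over two NamedTuples with the same field names adds them field by field;
# splatting the resulting NamedTuple passes its values positionally.
Base.:+(x::StateVector, y::StateVector) =
    StateVector(map(+, getfields(x), getfields(y))...)

StateVector(1, 2.0) + StateVector(3, 4.0)   # StateVector(4, 6.0)
```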