Promote field of abstract type

I’d like to simplify the following code so that there is a single multiplication method that applies to all subtypes of AbstractTerm:

abstract type AbstractTerm{T} end

struct ATerm{T} <: AbstractTerm{T}
    data::T
end

struct BTerm{T} <: AbstractTerm{T}
    data::T
end

function Base.:*(x::Number, aterm::ATerm{T}) where T
    return ATerm{promote_type(typeof(x), T)}(x * aterm.data)
end

function Base.:*(x::Number, bterm::BTerm{T}) where T
    return BTerm{promote_type(typeof(x), T)}(x * bterm.data)
end

a = ATerm(3)
println(4.1 * a)

I first had this single method instead of the two:

function Base.:*(x::Number, term::AbstractTerm)
    # rebuild the term using the same type as the input
    return typeof(term)(x * term.data)
end

But that fails with an error for the example given. (Also, this is a simplified version of my actual type hierarchy.)

What I’d like is something like ...term::T{V}) where {V, T<:AbstractTerm}, but that, of course, is not valid Julia.
I run into similar problems quite often; for example, I write identical constructors for all subtypes of an abstract type.

Shouldn’t Base.:*(x::Number, term::T) where T<:AbstractTerm = T(x * term.data) work?
(Sorry, I’m on a phone right now.)

No, that is equivalent to the last method given in my original post. For the example in the OP, 4.1 * ATerm(3), your method gives:

ERROR: LoadError: InexactError: Int64(12.299999999999999)

The method for ATerm in the OP correctly gives ATerm{Float64}(12.299999999999999), and likewise for the method for BTerm. But I can’t find a way to write a single method that does this for all AbstractTerm types; for each subtype, I have to write an identical method.
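A quick check of why the one-method version fails, as a minimal sketch reusing the question's ATerm definition: typeof(term) returns the fully parameterized type (ATerm{Int64} on a 64-bit system), and that constructor converts its argument to Int64, while promote_type reports the type the result actually needs.

```julia
abstract type AbstractTerm{T} end

struct ATerm{T} <: AbstractTerm{T}
    data::T
end

a = ATerm(3)
println(typeof(a))                    # ATerm{Int64} on a 64-bit system: the parameter is part of the type
println(promote_type(Float64, Int))   # Float64: the type that 4.1 * a.data produces

# Same call as the one-method version: the parameterized constructor
# tries convert(Int64, 12.299999999999999) and throws.
result = try
    typeof(a)(4.1 * a.data)
catch err
    err
end
println(result isa InexactError)      # true
```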

Oh, okay, now I see. I didn’t catch that the problem is that your AbstractTerm’s type parameter (Int) doesn’t match the promoted type of the multiplication (Float64). You could maybe strip the type parameter, but I think this is actually discouraged(?).

strip(::ATerm) = ATerm
strip(::BTerm) = BTerm
Base.:*(x::Number, term::AbstractTerm) = strip(term)(x * term.data)

Also, you would have to define the strip methods for each concrete type anyway.

I think the better way is to use metaprogramming to generate all the necessary operations for all your types.

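For comparison, a minimal sketch of the metaprogramming route: an @eval loop that stamps out the per-type method from the question once for each listed type. The tuple of type names is an assumption; you would list every concrete subtype you have.

```julia
abstract type AbstractTerm{T} end

struct ATerm{T} <: AbstractTerm{T}
    data::T
end

struct BTerm{T} <: AbstractTerm{T}
    data::T
end

# Generate one `*` method per concrete subtype; `$Term` splices in the type name,
# so each generated method is identical to the hand-written one in the question.
for Term in (:ATerm, :BTerm)
    @eval function Base.:*(x::Number, term::$Term{T}) where T
        return $Term{promote_type(typeof(x), T)}(x * term.data)
    end
end

println(4.1 * ATerm(3))   # ATerm{Float64}(12.299999999999999)
println(2 * BTerm(1.5))   # BTerm{Float64}(3.0)
```

The generated methods still have to be listed per type, but the body is written only once.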

Nice trick; I had not thought of strip(::ATerm) = ATerm. Defining strip for each type is economical, and I really want to avoid metaprogramming if possible. strip is not a super-elegant solution, but it may be the best one, and I would actually use it in many places.

I think that based on the discussion in this post:

you might achieve what you want by defining the strip function parametrically like this:

strip(::T) where {T} = (isempty(T.parameters) ? T : T.name.wrapper)
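A minimal sketch of how the generic helper plugs into a single `*` method for the question's types. It is renamed here to a hypothetical `unparam`, on the assumption that you may not want a new `strip` function in your module shadowing Base.strip; the body is the one from the post.

```julia
abstract type AbstractTerm{T} end

struct ATerm{T} <: AbstractTerm{T}
    data::T
end

struct BTerm{T} <: AbstractTerm{T}
    data::T
end

# Hypothetical name `unparam` (called `strip` in the post).
# For a parametric type, T.name.wrapper is the UnionAll, e.g. ATerm{Int64} -> ATerm.
unparam(::T) where {T} = (isempty(T.parameters) ? T : T.name.wrapper)

# One method for every subtype: the unparameterized constructor infers the
# new type parameter from the (already promoted) product.
Base.:*(x::Number, term::AbstractTerm) = unparam(term)(x * term.data)

println(4.1 * ATerm(3))   # ATerm{Float64}(12.299999999999999)
println(4.1 * BTerm(3))   # BTerm{Float64}(12.299999999999999)
```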

Yes, this

strip(::T) where {T} = (isempty(T.parameters) ? T : T.name.wrapper)

seems to work. However, @btime shows that it does not do all the work at compile time, as the hardcoded case above does.
EDIT:
After further testing, I find that the solution using T.name.wrapper is very slow compared to strip(::ATerm) = ATerm. I don’t recommend it.

I do not know what the implications are, so perhaps someone else can comment on this.

But it seems that if you drop the initial isempty check from the function I posted above, you get the same speed as the type-specific version.

using BenchmarkTools

abstract type AbstractTerm{T} end

struct ATerm{T} <: AbstractTerm{T}
    data::T
end

struct BTerm{T} <: AbstractTerm{T}
    data::T
end

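# generic: returns T itself if it has no parameters, otherwise its unparameterized wrapper
strip(::T) where {T} = (isempty(T.parameters) ? T : T.name.wrapper)
# generic, without the check: always returns the wrapper type
stripshort(::T) where T = T.name.wrapper
# hard-coded for ATerm only
stripspec(::ATerm) = ATerm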

a = ATerm(3)

prod1(x::Number, term::T) where T <: AbstractTerm = strip(term)(x * term.data)
prod2(x::Number, term::T) where T <: AbstractTerm = stripspec(term)(x * term.data)
prod3(x::Number, term::T) where T <: AbstractTerm = stripshort(term)(x * term.data)

@benchmark prod1.(x, Ref($a)) setup=(x = rand(10000))
@benchmark prod2.(x, Ref($a)) setup=(x = rand(10000))
@benchmark prod3.(x, Ref($a)) setup=(x = rand(10000))

Results for prod1, prod2 and prod3, in that order:

BenchmarkTools.Trial: 
  memory estimate:  78.67 KiB
  allocs estimate:  13
  --------------
  minimum time:     224.400 μs (0.00% GC)
  median time:      293.000 μs (0.00% GC)
  mean time:        333.178 μs (2.12% GC)
  maximum time:     10.016 ms (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1

BenchmarkTools.Trial: 
  memory estimate:  78.20 KiB
  allocs estimate:  2
  --------------
  minimum time:     6.975 μs (0.00% GC)
  median time:      16.725 μs (0.00% GC)
  mean time:        25.221 μs (15.13% GC)
  maximum time:     1.249 ms (97.55% GC)
  --------------
  samples:          10000
  evals/sample:     4

BenchmarkTools.Trial: 
  memory estimate:  78.20 KiB
  allocs estimate:  2
  --------------
  minimum time:     7.240 μs (0.00% GC)
  median time:      15.900 μs (0.00% GC)
  mean time:        25.617 μs (15.96% GC)
  maximum time:     1.229 ms (96.18% GC)
  --------------
  samples:          10000
  evals/sample:     5
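A small check that may explain why the shorter version is enough: for a type without parameters, T.name.wrapper is already T itself, so the isempty(T.parameters) branch does not change the result. `Plain` is a hypothetical non-parametric struct added only for this comparison; the sketch says nothing about the timing difference itself.

```julia
abstract type AbstractTerm{T} end

struct ATerm{T} <: AbstractTerm{T}
    data::T
end

struct Plain            # hypothetical non-parametric type, for comparison only
    data::Float64
end

stripshort(::T) where {T} = T.name.wrapper

println(stripshort(Plain(1.0)) === Plain)   # true: the wrapper of a non-parametric type is the type itself
println(stripshort(ATerm(3)) === ATerm)     # true: the type parameter is dropped
```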