One type with branches or a union of types?

I have types that basically split into two subtypes based on parameter values. The current implementation goes like this:

struct A{T}
    x::T
    s::Bool
    function A(x::Real)
        xf = float(x)
        return new{typeof(xf)}(xf, xf > 0)
    end
end
f1(a::A) = a.s ? 2a.x : throw("Not applicable for negative parameters")
f2(a::A) = a.s ? throw("Not applicable for positive parameters") : 3a.x
f3(a::A) = a.s ? f1(a) : 1/f2(a)

The three methods are useful and must surface to the user, when available. I wonder if it would be worth it to trade the runtime switch for a construction-time switch by using something like:

struct A{T,sign}
    x::T
    function A(x::Real)
        xf = float(x)
        return new{typeof(xf), xf>0}(xf)
    end
end
f1(a::A{T,true}) where T = 2a.x
f2(a::A{T,false}) where T = 3a.x
f3(a::A{T,true}) where T = f1(a)
f3(a::A{T,false}) where T = 1/f2(a)

The final behavior is actually the same, except that once the object is constructed, everything else will be type stable. I see a tradeoff in the strain on the compiler, but maybe the first version is simple enough that the compiler actually removes the branch?
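
(This is how I would inspect it, a rough sketch using the introspection macros from InteractiveUtils:)

using InteractiveUtils  # loaded by default in the REPL

a = A(0.2)
@code_warntype f3(a)   # does inference see through the Bool field?
@code_llvm f3(a)       # does LLVM keep the runtime branch?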

What is the most idiomatic way of doing things here?

Just my personal preference: I’ve come to prefer fewer type parameters whenever possible. It can be great for performance sometimes, but not necessarily (see the last part of this post). Most of the time I found it easier for readability and maintainability to keep struct definitions “simple”.

In particular about using “values as type parameters”: I like that it’s possible in principle, and for Array{T, 1} it feels very intuitive now, but in my experience it can also get confusing and harder to modify/debug, so I try to not overuse that feature :sweat_smile:
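
(For reference: Vector is exactly this pattern, an alias that fixes the dimension value to 1:)

julia> Vector{Float64} === Array{Float64, 1}   # the 1 is a value used as a type parameter
true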

Another option/compromise would be this:

abstract type A end

struct PositiveA{T} <: A
    x::T
    function PositiveA(x::Real)
        xf = float(x)
        if xf ≤ 0
            throw("...")
        end
        return new{typeof(xf)}(xf)
    end
end

struct NegativeA{T} <: A
    x::T
    function NegativeA(x::Real)
        xf = float(x)
        if xf ≥ 0
            throw("...")
        end
        return new{typeof(xf)}(xf)
    end
end

f1(a::PositiveA) = 2a.x
f2(a::NegativeA) = 3a.x
f3(a::PositiveA) = f1(a)
f3(a::NegativeA) = 1/f2(a)

It looks a lot more verbose, but if your real use case is positive/negative signs, then PositiveA{T} immediately signals what type this is, whereas A{T, true} is not so obvious, e.g. if it appears in a stack trace.
This approach also combines the benefit of checking the sign only once at the creation of the object (vs. every time you call your functions) with simpler method definitions (you don’t need to define methods with where T, and for the cases that don’t make sense you just get a MethodError saying the method is not defined for (..., NegativeA) or whatever).

julia> f2(a)
ERROR: MethodError: no method matching f2(::PositiveA{Float64})

Closest candidates are:
  f2(::NegativeA)

For convenience you could also add a constructor like

A(x) = x > 0 ? PositiveA(x) : x < 0 ? NegativeA(x) : throw("...")

so that the PositiveA and NegativeA types don’t need to be constructed manually.
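
For example (sketch of a REPL session with the definitions above):

julia> A(2.0)
PositiveA{Float64}(2.0)

julia> A(-1.5)
NegativeA{Float64}(-1.5)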

The main downside of this approach compared to your first approach (just struct A{T} without the sign) is probably that iterating over collections of type A is not type stable anymore. If you have large collections of A objects and need to call f1 etc. repeatedly, then it might be best to just use a single struct. That’s a tradeoff you always have if you move data from the fields into the type itself.
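
To illustrate (sketch): a mixed array only infers the abstract supertype as its element type, so calls on the elements are dynamically dispatched:

julia> [A(1.0), A(-1.0)]
2-element Vector{A}:
 PositiveA{Float64}(1.0)
 NegativeA{Float64}(-1.0)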

Would love to hear other opinions as well :slight_smile:

Thanks, I had considered this option as well, but I thought it was equivalent to the type parameter approach. For readability, we could use A{T, :Positive} and A{T, :Negative} instead of true/false.

Regarding dispatch performance: which one will be better handled by the compiler, Union{PositiveA{T}, NegativeA{T}} or A{T, Sign} where Sign?

Use cases are more “one A object doing a lot of things” than “large vectors of A objects”, so I am not concerned about type instability in this sense.

This is true until you create a collection of As (including storing them in a struct), e.g.

as = [A(randn()) for _ in 1:100]

The element type of the vector as will be abstract if you have the sign as a parameter, but concrete if the sign is a Bool inside A. And abstract element types are very bad for performance. Much worse than the simple tests on the Bool.
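
You can check this directly (sketch; SignField/SignParam are hypothetical names for the Bool-field and type-parameter definitions above):

julia> isconcretetype(eltype([SignField(randn()) for _ in 1:100]))   # sign stored as a Bool field
true

julia> isconcretetype(eltype([SignParam(randn()) for _ in 1:100]))   # sign lifted into the type
false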

Edit: I see your use case is: One A object doing a lot of things. Then the parameter variant is a bit faster, except for compilation time.

Ok so I have three options:

struct A{T}
    x::T
    s::Bool
    function A(x::Real)
        xf = float(x)
        return new{typeof(xf)}(xf, xf > 0)
    end
end
f1(a::A) = a.s ? 2a.x : throw("Not applicable for negative parameters")
f2(a::A) = a.s ? throw("Not applicable for positive parameters") : 3a.x
f3(a::A) = a.s ? f1(a) : 1/f2(a)

struct B{T,sign}
    x::T
    function B(x::Real)
        xf = float(x)
        return new{typeof(xf), xf>0}(xf)
    end
end
f1(a::B{T,true}) where T = 2a.x
f2(a::B{T,false}) where T = 3a.x
f3(a::B{T,true}) where T = f1(a)
f3(a::B{T,false}) where T = 1/f2(a)


struct PositiveC{T}
    x::T
    function PositiveC(x::Real)
        xf = float(x)
        if xf ≤ 0
            throw("...")
        end
        return new{typeof(xf)}(xf)
    end
end

struct NegativeC{T}
    x::T
    function NegativeC(x::Real)
        xf = float(x)
        if xf ≥ 0
            throw("...")
        end
        return new{typeof(xf)}(xf)
    end
end

f1(a::PositiveC) = 2a.x
f2(a::NegativeC) = 3a.x
f3(a::PositiveC) = f1(a)
f3(a::NegativeC) = 1/f2(a)
C(x) = x > 0 ? PositiveC(x) : x < 0 ? NegativeC(x) : throw("...")



main() = f3(A(-0.2)), f3(B(-0.2)), f3(C(-0.2)), f3(A(0.2)), f3(B(0.2)), f3(C(0.2))

main()

And this simply gives:

julia> @code_native main()
        .text
        .file   "main"
        .section        .ltext,"axl",@progbits
        .globl  julia_main_12644                # -- Begin function julia_main_12644
        .p2align        4, 0x90
        .type   julia_main_12644,@function
julia_main_12644:                       # @julia_main_12644
; Function Signature: main()
; ┌ @ Untitled-1:56 within `main`
        .cfi_startproc
# %bb.0:                                # %top
        push    rbp
        .cfi_def_cfa_offset 16
        .cfi_offset rbp, -16
        mov     rbp, rsp
        .cfi_def_cfa_register rbp
        mov     rax, rcx
        movabs  rcx, offset ".L_j_const#1"+16
; │ @ Untitled-1 within `main`
        vmovups ymm0, ymmword ptr [rcx]
        vmovups ymmword ptr [rax + 16], ymm0
        movabs  rcx, offset ".L_j_const#1"
        vmovups ymm0, ymmword ptr [rcx]
        vmovups ymmword ptr [rax], ymm0
        pop     rbp
        vzeroupper
        ret
.Lfunc_end0:
        .size   julia_main_12644, .Lfunc_end0-julia_main_12644
        .cfi_endproc
; └
                                        # -- End function
        .type   ".L_j_const#1",@object          # @"_j_const#1"
        .section        .lrodata,"al",@progbits
        .p2align        3, 0x0
".L_j_const#1":
        .quad   0xbffaaaaaaaaaaaaa              # double -1.6666666666666665
        .quad   0xbffaaaaaaaaaaaaa              # double -1.6666666666666665
        .quad   0xbffaaaaaaaaaaaaa              # double -1.6666666666666665
        .quad   0x3fd999999999999a              # double 0.40000000000000002
        .quad   0x3fd999999999999a              # double 0.40000000000000002
        .quad   0x3fd999999999999a              # double 0.40000000000000002
        .size   ".L_j_const#1", 48

.set ".L+Core.Tuple#12646.jit", 140703350835968
        .size   ".L+Core.Tuple#12646.jit", 8
        .section        ".note.GNU-stack","",@progbits

julia> 

Which means that the compiler is removing the whole nonsense we are doing and going straight to the result in all cases. So, at least in this simple case, the three are actually equivalent at runtime. I do not know how to measure compile time to see if one is better than the others…
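
(Edit: one rough proxy for compile time, since @time on a first call in a fresh session reports the share of time spent compiling:)

julia> @time f3(B(0.2))   # first call: the report includes "% compilation time"

julia> @time f3(B(0.2))   # second call: compilation cost is gone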

These are not actually equivalent at runtime. In your case this is an artifact of inlining/constant folding because you hard-coded the numeric values. Try slapping a couple of @noinline on there or put the values as parameters.
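
E.g. the second experiment could look like this (sketch; test_a/test_b are hypothetical helpers, using the A/B names from your list):

test_a(x) = f3(A(x))   # x is a runtime value now, not a hard-coded literal
test_b(x) = f3(B(x))

@code_native test_a(-0.2)   # the branch on a.s survives in the generated code
@code_native test_b(-0.2)   # the sign is a type parameter here, so this hits dynamic dispatch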

In terms of implementations: B and C are basically the same. You could define const PositiveB = B{T, true} where T and similar for negative and then it is very apparent. Version A is distinct because it actually is a single type and does the branching at runtime.
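
I.e. something like (sketch):

const PositiveB = B{T, true} where T
const NegativeB = B{T, false} where T

f1(a::PositiveB) = 2a.x   # dispatches exactly like f1(a::B{T,true}) where T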

You have a design decision to make here: If you need absolutely maximal performance, choose B or C. The downside is that you need to guarantee type stability, or else performance will be much worse than A. If in doubt, choose A: it does not have performance pitfalls, and branch prediction etc. will likely make it not much slower than the alternatives.

Yes, this is what I meant by equivalent at runtime: the compiler can constant-fold the three cases. Therefore keeping the conditionals in the methods for A is not that much of an issue in terms of runtime; e.g. a function like this:

f4(a::A, x) = a.s ? branch1(x) : branch2(x)

will not incur additional runtime when calling e.g. Base.Fix1(f4, a).(randn(10000)), since the condition is folded at compile time before broadcasting.

Thus I think I will keep this condition as a value instead of lifting it to the type system, to mitigate compiler burden.

I don’t think that’s how it works if a is not a compile-time constant. If you are lucky, the compiler is smart enough, but this is definitely not guaranteed. Broadcasting in particular is very complex machinery, and unless it is inlined completely, I don’t think the loop and the condition are exchanged. So imo the assembly will contain the branch inside the loop in some fashion, unless everything is completely inlined and LLVM realises it can exchange the order. However, I think in practice this should not incur much of a performance penalty due to branch prediction/speculative execution. This means that the CPU will guess the branch correctly virtually every time, which is similar in effect to compiling it away.
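
If you want to verify, a benchmark along these lines would show it (sketch, assuming BenchmarkTools.jl is installed and with stub branch implementations):

using BenchmarkTools

branch1(x) = 2x   # stand-ins for the real branch bodies
branch2(x) = 3x
f4(a::A, x) = a.s ? branch1(x) : branch2(x)

a = A(0.2)
xs = randn(10_000)
@btime Base.Fix1(f4, $a).($xs)   # branch inside the broadcast
@btime branch1.($xs)             # branch-free baseline for comparison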
