Efficient custom broadcasting challenge

I am trying to build optimized broadcasting behavior for factorized arrays of quantities in FlexUnits.jl. It turns out that if a matrix of quantities can be multiplied by a vector of quantities, the matrix units can be conveniently factored and described by two vectors and a scalar. I call this factorization a DimsMap.

@kwdef struct DimsMap{D<:AbstractDimensions, TI<:AbstractVector{D}, TO<:AbstractVector{D}} <: AbstractDimsMap{D}
    u_fac :: D
    u_in  :: TI
    u_out :: TO
end

This object is array-like but not an abstract array. It supports indexing, which calculates units on the fly, as well as size and axes.

Base.getindex(m::DimsMap, ii::Integer, jj::Integer) = m.u_out[ii]/m.u_in[jj]*m.u_fac
Base.axes(m::DimsMap) = (axes(m.u_out)[1], axes(m.u_in)[1])
Base.size(m::DimsMap) = (length(m.u_out), length(m.u_in))

I was surprised to find out that broadcasting automatically worked, sort of.

u1 = SA[u"lbf*ft", u"kW", u"rpm"]
u2 = SA[u"kg/s", u"m^3/hr", u"kW"]
d = DimsMap(u_in=dimension.(u1), u_out=dimension.(u2), u_fac=dimension(u""))

julia> d.*d
3×3 SMatrix{3, 3, Dimensions{FixRat32}, 9} with indices SOneTo(3)×SOneTo(3):
 s²/m⁴        s⁴/m⁴        kg²
 (m² s²)/kg²  (m² s⁴)/kg²  m⁶
 1/s²                      (m⁴ kg²)/s⁴

This produced a correct result, but it destroyed the factorization, which takes O(n^2) work to reconstruct. Moreover, solving units is about 6x as expensive as solving numbers, so it is much more efficient to solve the units on the factors separately: that takes M + N + 1 unit operations instead of M × N. Theoretically, I could try to overload the actual broadcasting to do something like this:

function Base.broadcast(f, args::DimsMap...) 
    return DimsMap(
        u_fac = broadcast(f, map(ufactor, args)...), 
        u_in  = broadcast(f, map(uinput, args)...), 
        u_out = broadcast(f, map(uoutput, args)...)
    )
end
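To spell out why the factor-wise call is enough for multiplication, here is the elementwise product written on the factors. The names $a$ and $b$ for the two DimsMap operands, and $M$, $N$ for the lengths of u_out and u_in, are notation introduced here for illustration; the first identity just restates the getindex definition above.

$$
A_{ij} = a_{\mathrm{fac}}\,\frac{a_{\mathrm{out},i}}{a_{\mathrm{in},j}},
\qquad
(A \odot B)_{ij} = A_{ij}\,B_{ij}
= \left(a_{\mathrm{fac}}\, b_{\mathrm{fac}}\right)
\frac{a_{\mathrm{out},i}\; b_{\mathrm{out},i}}{a_{\mathrm{in},j}\; b_{\mathrm{in},j}}
$$

So the product of two maps is again a map whose factors are the products of the corresponding factors: $M$ unit products for u_out, $N$ for u_in, and one for u_fac, compared with $M \times N$ unit products for the dense result.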

This works for explicit calls of broadcast like broadcast(*, d, d), but it doesn’t if I use the dot syntax d.*d.

Things get even more complicated when I combine it with a matrix of numbers to simulate a matrix of quantities (or more specifically, a linear mapping of them):

struct LinmapQuant{T, D<:AbstractDimensions, M<:AbstractMatrix{T}, U<:AbstractDimsMap{D}} <: AbstractMatrix{Quantity{T,D}}
    values :: M
    dims :: U
end

Again, redefining broadcast seems to work when using explicit calls, but fails with dot syntax.

function Base.broadcast(f, args::LinmapQuant...) 
    return LinmapQuant(
        broadcast(f, map(dstrip, args)...), 
        broadcast(f, map(dimension, args)...)
    )
end

How do I overload the dot syntax to use this form of broadcast? Is this the preferred way of using specialized versions of broadcast or do I need to go down one level?

There’s a system of “broadcast styles” by which the lazy broadcasting remembers what sort of object to materialise at the end. There are official docs, perhaps also this example is useful, and some earlier discourse questions: here and here.

I’ve been reading the official docs all morning and digging into the “broadcast.jl” source code. It turns out that overloading broadcasted is the simplest solution in this case, as simply outputting a DimsMap object triggers the materialize(x) = x line.
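The mechanism can be checked on plain arrays. The following is a small sketch, not FlexUnits.jl code, and the variable names are placeholders. It shows that dot syntax lowers to Base.broadcasted wrapped in Base.materialize (so Base.broadcast is never called), and that materialize hands back anything that is not a lazy Broadcasted object.

```julia
x = [1, 2, 3]
y = [10, 20, 30]

# Print the lowered form of a dotted call. It should contain calls to
# Base.broadcasted and Base.materialize rather than Base.broadcast.
lowered = Meta.lower(Main, :(x .* y))
println(lowered)

# Built by hand, the two steps give the same result as the dot syntax.
lazy = Base.broadcasted(*, x, y)   # a lazy Broadcasted object, nothing computed yet
@show lazy isa Base.Broadcast.Broadcasted
@show Base.materialize(lazy) == x .* y

# Anything that is not a Broadcasted is returned unchanged by materialize,
# which is why returning a finished object from broadcasted is sufficient.
@show Base.materialize(:done) === :done
```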

Moreover, I only want to overload functions that return quantities as output, and I only want this efficient method when all arguments are DimsMap (or LinmapQuant in the other example).

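# Dot syntax lowers to materialize(broadcasted(f, args...)), so overload broadcasted.
# Returning a DimsMap here means materialize passes the result through unchanged.
for op in (:+, :-, :*, :/, :\, :exp, :log)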
    @eval function Base.broadcasted(f::typeof($op), arg1::DimsMap, argN::DimsMap...)
        args = (arg1, argN...)
        return DimsMap(
            u_fac = broadcast(f, map(ufactor, args)...), 
            u_in  = broadcast(f, map(uinput, args)...), 
            u_out = broadcast(f, map(uoutput, args)...)
        )
    end
end
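For anyone who wants to try the pattern without FlexUnits.jl, here is a minimal self-contained sketch. ToyMap is a hypothetical stand-in for DimsMap in which plain rationals play the role of dimensions, and only * and / are overloaded; none of these names come from the package.

```julia
# Hypothetical stand-in for DimsMap: rationals act as the "units".
struct ToyMap
    fac::Rational{Int}
    uin::Vector{Rational{Int}}
    uout::Vector{Rational{Int}}
end

# Array-like interface, computed from the factors on the fly.
Base.getindex(m::ToyMap, i::Integer, j::Integer) = m.uout[i] / m.uin[j] * m.fac
Base.size(m::ToyMap) = (length(m.uout), length(m.uin))
Base.axes(m::ToyMap) = (axes(m.uout, 1), axes(m.uin, 1))

# Intercept the dotted call before any lazy Broadcasted object is built
# and apply the operation to each factor separately.
for op in (:*, :/)
    @eval function Base.broadcasted(f::typeof($op), a::ToyMap, b::ToyMap)
        return ToyMap(
            f(a.fac, b.fac),
            broadcast(f, a.uin, b.uin),
            broadcast(f, a.uout, b.uout),
        )
    end
end

a = ToyMap(2//1, [1//2, 3//1], [5//1, 7//3, 1//1])
b = ToyMap(1//3, [4//1, 1//5], [2//1, 1//1, 9//2])

c = a .* b   # dot syntax now dispatches to the method above
q = a ./ b

# The factorization survives, and every entry matches the entrywise result.
@show c isa ToyMap
@show all(c[i, j] == a[i, j] * b[i, j] for i in 1:3, j in 1:2)
@show all(q[i, j] == a[i, j] / b[i, j] for i in 1:3, j in 1:2)
```

Rationals are used so the final comparison is exact; with real dimension types the same check would compare units instead of numbers.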