Caching metadata in custom broadcasting interface

I’m trying to create a QuantityArray type for DynamicQuantities.jl (on the arrays-2 branch). Original announcement here: [ANN] DynamicQuantities.jl: type stable physical quantities. cc @non-Jedi

The reason you would want a QuantityArray rather than simply an array of Quantity is that (1) it can subtype AbstractArray, and (2) you don’t need to re-compute the units for every element, since they are assumed to be constant for the entire array.

I basically have (1) down. I can broadcast calculations with a custom broadcast style. But for the life of me I just cannot seem to figure out how to avoid re-computing the units at every element as part of this custom broadcast.

My current plan is to:

  1. Materialize the broadcasted calculation for element 1, and record the physical units.
  2. Strip the units from all inputs to the broadcasted calculation.
  3. Perform a normal broadcasted array calculation without units.
  4. Add the units from (1) to the output array.

Here is what I have so far. The most relevant part is the custom broadcasting (changing some things for readability) which is shown below:

# Every `QuantityArray` type broadcasts under an `ArrayStyle` tagged with that
# same array type.
Base.BroadcastStyle(::Type{A}) where {A<:QuantityArray} = Broadcast.ArrayStyle{A}()

# When two `QuantityArray` styles meet, the resulting style is tagged with
# `promote_type` of the two array types.
function Base.BroadcastStyle(
    ::Broadcast.ArrayStyle{L}, ::Broadcast.ArrayStyle{R}
) where {L<:QuantityArray,R<:QuantityArray}
    Promoted = promote_type(L, R)
    return Broadcast.ArrayStyle{Promoted}()
end

# Allocate the destination of a broadcast whose style is tagged with a
# `QuantityArray` type. `find_q(bc)` evaluates the broadcast for one element;
# if that result is a quantity, its dimensions are attached to the freshly
# allocated storage, otherwise the plain storage array is returned.
# Note: `ElType` is accepted but not consulted; the storage always has the
# type given by `array_type(QA)`.
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{QA}}, ::Type{ElType}) where {QA<:QuantityArray,ElType}
    sample = find_q(bc)
    storage = similar(array_type(QA), axes(bc))
    sample isa AbstractQuantity || return storage
    return QuantityArray(storage, dimension(sample))
end

# https://discourse.julialang.org/t/defining-broadcast-for-custom-types-the-example-in-the-docs-fails/32291/2
# Basically, we want to solve a single element to find the output dimension. Then
# we can put results in the output `QuantityArray`.
#
# `find_q` evaluates a lazy broadcast tree for one element: every argument is
# reduced to its first element and each node applies its function to those.
find_q(bc::Base.Broadcast.Broadcasted) = bc.f(find_q.(bc.args)...)

# Scalars (quantities or anything else) stand for themselves.
find_q(x) = x
find_q(q::AbstractQuantity) = q

# Wrappers are unwrapped and their content is reduced.
find_q(r::Base.RefValue) = find_q(r.x)
find_q(x::Base.Broadcast.Extruded) = find_q(x.x)

# Collections contribute their first element. Searching a collection for its
# first *quantity* instead would run off the end of any tuple or array that
# holds only plain numbers (and copy the tail at every step), so the first
# element is taken directly. Empty collections have no first element.
find_q(args::Tuple) = find_q(first(args))
find_q(::Tuple{}) = error("Unexpected.")
find_q(args::AbstractArray) = isempty(args) ? error("Unexpected.") : find_q(first(args))
find_q(q::AbstractArray{Q}) where {Q<:AbstractQuantity} = find_q(first(q))
find_q(q::QuantityArray) = find_q(first(q))

I tried to base this on this section of the docs, as well as some discourse posts to fill in the holes (this one in particular).

Here is the full `QuantityArray` definition:
# Quantity type used for the elements when a constructor is not given one.
const DEFAULT_QUANTITY_TYPE = Quantity

# Array of quantities that stores a single `dimensions` value next to the raw
# values in `value`. Type parameters: `T` element type of `value`, `N` number of
# array axes, `D` type of `dimensions`, `Q` element (quantity) type, `V` type of
# the wrapped array. Subtypes `AbstractArray{Q,N}`.
struct QuantityArray{T,N,D<:AbstractDimensions,Q<:AbstractQuantity{T,D},V<:AbstractArray{T,N}} <: AbstractArray{Q,N}
    value::V
    dimensions::D

    # (values, dimensions): the element type is `DEFAULT_QUANTITY_TYPE{_T,_D}`.
    QuantityArray(v::_V, d::_D) where {_T,_N,_D<:AbstractDimensions,_V<:AbstractArray{_T,_N}} = new{_T,_N,_D,DEFAULT_QUANTITY_TYPE{_T,_D},_V}(v, d)
    # (values, dimensions, Q) where `_Q` already agrees with `_T` and `_D`: used as given.
    QuantityArray(v::_V, d::_D, ::Type{_Q}) where {_T,_N,_D<:AbstractDimensions,_Q<:AbstractQuantity{_T,_D},_V<:AbstractArray{_T,_N}} = new{_T,_N,_D,_Q,_V}(v, d)
    # `_Q` only constrained to value type `_T`: rebuilt as `constructor_of(_Q){_T,_D}` and forwarded.
    QuantityArray(v::_V, d::_D, ::Type{_Q}) where {_T,_N,_D<:AbstractDimensions,_Q<:AbstractQuantity{_T},_V<:AbstractArray{_T,_N}} = QuantityArray(v, d, constructor_of(_Q){_T,_D})
    # Any other `AbstractQuantity` type: `{_T,_D}` is applied to `_Q` directly and forwarded.
    QuantityArray(v::_V, d::_D, ::Type{_Q}) where {_T,_N,_D<:AbstractDimensions,_Q<:AbstractQuantity,_V<:AbstractArray{_T,_N}} = QuantityArray(v, d, _Q{_T,_D})
end

# Outer constructors.

# Plain values; keyword arguments are forwarded to `DEFAULT_DIM_TYPE`.
function QuantityArray(v::AbstractArray; kws...)
    dims = DEFAULT_DIM_TYPE(; kws...)
    return QuantityArray(v, dims)
end

# Construct with a Quantity (easier, as you can use the units): the values are
# scaled by the quantity's value and take its dimensions and type.
function QuantityArray(v::AbstractArray, q::AbstractQuantity)
    scaled = v .* ustrip(q)
    return QuantityArray(scaled, dimension(q), typeof(q))
end

# An existing array of quantities: all elements must share the same dimensions,
# otherwise a `DimensionError` is thrown.
function QuantityArray(v::QA) where {Q<:AbstractQuantity,QA<:AbstractArray{Q}}
    allequal(dimension.(v)) || throw(DimensionError(first(v), v))
    return QuantityArray(ustrip.(v), dimension(first(v)), Q)
end
# TODO: Should this check that the dimensions are the same?

# Field accessors: the raw value array and the shared dimensions.
ustrip(arr::QuantityArray) = arr.value
dimension(arr::QuantityArray) = arr.dimensions

# Type of the wrapped array for a `QuantityArray` type. The last method returns
# the `V` parameter when all five parameters are bound; the first two return
# `Array{T,1}` / `Array{T,N}` for types where only `T` (or `T` and `N`) is bound.
array_type(::Type{QA}) where {T,QA<:QuantityArray{T}} = Array{T,1}
array_type(::Type{QA}) where {T,N,QA<:QuantityArray{T,N}} = Array{T,N}
array_type(::Type{QA}) where {T,N,D,Q,V,QA<:QuantityArray{T,N,D,Q,V}} = V
array_type(x) = array_type(typeof(x))

# Element (quantity) type `Q` of a `QuantityArray` type or instance.
quantity_type(::Type{QA}) where {T,N,D,Q,QA<:QuantityArray{T,N,D,Q}} = Q
quantity_type(x) = quantity_type(typeof(x))

# Dimensions type `D`, or `DEFAULT_DIM_TYPE` when `D` is not bound in the type.
dim_type(::Type{QA}) where {QA<:QuantityArray} = DEFAULT_DIM_TYPE
dim_type(::Type{QA}) where {T,N,D,QA<:QuantityArray{T,N,D}} = D
dim_type(x) = dim_type(typeof(x))

# Shape queries are answered by the wrapped value array.
Base.size(A::QuantityArray) = size(ustrip(A))
Base.length(A::QuantityArray) = length(ustrip(A))
Base.axes(A::QuantityArray) = axes(ustrip(A))

# Index into the stored values and re-attach the shared dimensions.
# A scalar index yields one quantity of type `quantity_type(A)`. An index that
# selects several elements (e.g. a range) yields an array of values, which is
# wrapped in a new `QuantityArray` rather than being passed to the scalar
# quantity constructor.
function Base.getindex(A::QuantityArray, i...)
    selected = getindex(ustrip(A), i...)
    if selected isa AbstractArray
        return QuantityArray(selected, dimension(A), quantity_type(A))
    end
    return quantity_type(A)(selected, dimension(A))
end
# Store a quantity whose type equals the array's element type `Q`. The value is
# written only if its dimensions equal those of the array; otherwise a
# `DimensionError` is thrown.
function Base.setindex!(A::QuantityArray{T,N,D,Q}, v::Q, i...) where {T,N,D,Q<:AbstractQuantity}
    dimension(A) == dimension(v) || throw(DimensionError(A, v))
    return unsafe_setindex!(A, v, i...)
end

# Any other quantity type is rejected with an error.
function Base.setindex!(A::QuantityArray{T,N,D,Q}, v::AbstractQuantity, i...) where {T,N,D,Q<:AbstractQuantity}
    return error("Cannot set values in a quantity array with element type $(Q) with different element type: $(typeof(v)).")
end
# TODO: Should this dimension check be removed?
# TODO: This does not allow for efficient broadcasting; as the dimension calculation is repeated...

# Write the stripped value into the stripped array without any dimension check.
function unsafe_setindex!(A, v, i...)
    return setindex!(ustrip(A), ustrip(v), i...)
end

# Indexing style is inherited from the wrapped array type.
Base.IndexStyle(::Type{QA}) where {QA<:QuantityArray} = IndexStyle(array_type(QA))

# `similar` on an instance: allocate new storage from the wrapped values and
# keep the dimensions of the source array.
Base.similar(src::QuantityArray) = QuantityArray(similar(ustrip(src)), dimension(src))
Base.similar(src::QuantityArray, ::Type{E}) where {E} = QuantityArray(similar(ustrip(src), E), dimension(src))
Base.similar(src::QuantityArray, dims::Dims) = QuantityArray(similar(ustrip(src), dims), dimension(src))
Base.similar(src::QuantityArray, ::Type{E}, dims::Dims) where {E} = QuantityArray(similar(ustrip(src), E, dims), dimension(src))

# `similar` on a type: storage comes from `array_type(W)` and the dimensions are
# a default-constructed `dim_type(W)()`.
# NOTE(review): the methods taking an element type call `similar` on an array
# *type* together with `E`; confirm such a method exists for the backing array
# types in use.
Base.similar(::Type{W}) where {T,W<:QuantityArray{T}} = QuantityArray(similar(array_type(W)), dim_type(W)())
Base.similar(::Type{W}, ::Type{E}) where {T,W<:QuantityArray{T},E} = QuantityArray(similar(array_type(W), E), dim_type(W)())
Base.similar(::Type{W}, dims::Dims) where {T,W<:QuantityArray{T}} = QuantityArray(similar(array_type(W), dims), dim_type(W)())
Base.similar(::Type{W}, ::Type{E}, dims::Dims) where {T,W<:QuantityArray{T},E} = QuantityArray(similar(array_type(W), E, dims), dim_type(W)())

# Every `QuantityArray` type broadcasts under an `ArrayStyle` tagged with that
# same array type.
Base.BroadcastStyle(::Type{A}) where {A<:QuantityArray} = Broadcast.ArrayStyle{A}()

# Combine two `QuantityArray` styles that share the dimensions type `D` and the
# number of axes `N`. Value and storage types are promoted; the quantity type is
# rebuilt from the constructor of the first array's quantity type.
function Base.BroadcastStyle(
    ::Broadcast.ArrayStyle{QA1}, ::Broadcast.ArrayStyle{QA2}
) where {
    T1,T2,N,V1<:AbstractArray{T1,N},V2<:AbstractArray{T2,N},D<:AbstractDimensions,
    Q1<:AbstractQuantity{T1,D},Q2<:AbstractQuantity{T2,D},
    QA1<:QuantityArray{T1,N,D,Q1,V1},QA2<:QuantityArray{T2,N,D,Q2,V2}
}
    Tout = promote_type(T1, T2)
    Vout = promote_type(V1, V2)
    Qout = constructor_of(Q1){Tout,D}
    return Broadcast.ArrayStyle{QuantityArray{Tout,N,D,Qout,Vout}}()
end
# TODO: How can I ustrip after finding the output units?
# Allocate the destination of a broadcast whose style is tagged with a
# `QuantityArray` type. `find_q(bc)` evaluates the broadcast for one element;
# if that result is a quantity, its dimensions are attached to the freshly
# allocated storage, otherwise the plain storage array is returned.
# Note: `ElType` is accepted but not consulted; the storage always has the
# type given by `array_type(QA)`.
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{QA}}, ::Type{ElType}) where {QA<:QuantityArray,ElType}
    sample = find_q(bc)
    storage = similar(array_type(QA), axes(bc))
    sample isa AbstractQuantity || return storage
    return QuantityArray(storage, dimension(sample))
end

# https://discourse.julialang.org/t/defining-broadcast-for-custom-types-the-example-in-the-docs-fails/32291/2
# Basically, we want to solve a single element to find the output dimension. Then
# we can put results in the output `QuantityArray`.
#
# `find_q` evaluates a lazy broadcast tree for one element: every argument is
# reduced to its first element and each node applies its function to those.
find_q(bc::Base.Broadcast.Broadcasted) = bc.f(find_q.(bc.args)...)

# Scalars (quantities or anything else) stand for themselves.
find_q(x) = x
find_q(q::AbstractQuantity) = q

# Wrappers are unwrapped and their content is reduced.
find_q(r::Base.RefValue) = find_q(r.x)
find_q(x::Base.Broadcast.Extruded) = find_q(x.x)

# Collections contribute their first element. Searching a collection for its
# first *quantity* instead would run off the end of any tuple or array that
# holds only plain numbers (and copy the tail at every step), so the first
# element is taken directly. Empty collections have no first element.
find_q(args::Tuple) = find_q(first(args))
find_q(::Tuple{}) = error("Unexpected.")
find_q(args::AbstractArray) = isempty(args) ? error("Unexpected.") : find_q(first(args))
find_q(q::AbstractArray{Q}) where {Q<:AbstractQuantity} = find_q(first(q))
find_q(q::QuantityArray) = find_q(first(q))

# Print a compact description of a `QuantityArray` type: the wrapped array type
# followed by the quantity type.
function _print_array_type(io::IO, ::Type{A}) where {A<:QuantityArray}
    return print(io, "QuantityArray(::", array_type(A), ", ::", quantity_type(A), ")")
end

# Both the summary hook for instances and the plain-text display of the type
# itself use the compact description.
Base.showarg(io::IO, arr::QuantityArray, _) = _print_array_type(io, typeof(arr))
Base.show(io::IO, ::MIME"text/plain", ::Type{A}) where {A<:QuantityArray} = _print_array_type(io, A)

Basically, the find_q function materializes the first element of the broadcasted calculation. That gives us the required output type (and in this case, the physical units – which are constant across the input and output arrays). It then took a lot of manual debugging to figure out what sorts of types can appear in a broadcasted calculation, and to define find_q for the missing ones.

But, I can’t figure out how to ustrip the quantities before the rest of the broadcasted calculation takes place, or where I can add the units back after the calculation is done. Perhaps I need to be overloading the broadcasted(...) function instead? There’s unfortunately not a lot of docs to work off of on this, so I haven’t been able to figure this out myself. Any help is much appreciated.

Thanks!
Miles

1 Like

I wonder how far just the existence of QuantityArray already gets you some performance gain, due to loop-invariant code motion?

Out of interest, do you have a minimal example where loop-invariant code motion does not do the trick? i.e. the compiler is unable to recognize that the same units field of the QuantityArray is being used in the same way each iteration and does not lift the computation out of the loop, resulting in a performance loss compared to unitless arrays.

Here’s a quick benchmark:

julia> using DynamicQuantities, BenchmarkTools

julia> x = randn(10_000_000) .* u"km/s";  # Vector{Quantity}

julia> x_qarray = QuantityArray(x);

julia> f(v) = 1.5 .* v .* abs.(v ./ v) .^ 2 .* v;

julia> @btime f($x);
  138.926 ms (2 allocations: 381.47 MiB)

julia> @btime f($x_qarray);
  144.424 ms (35 allocations: 76.30 MiB)

julia> @btime f(normal_array) setup=(normal_array=randn(10_000_000));
  12.132 ms (2 allocations: 76.29 MiB)

So my custom array interface lowers the memory usage, as it only stores a single set of units for the whole vector. But it doesn’t even approach the speed of a normal array calculation. So I think stripping the units, broadcasting using a normal array, and then putting the units back in, is probably the fastest option. I just can’t figure out how to implement the interface.

I also tried broadcasting the entire statement, and it looks like something is wrong with my custom broadcast:

julia> f(v) = @. 1.5 * v * abs(v / v) ^ 2 * v;

julia> @btime f($x);
  136.529 ms (2 allocations: 381.47 MiB)

julia> @btime f($x_qarray);
  395.655 ms (19999507 allocations: 686.64 MiB)

julia> @btime f(normal_array) setup=(normal_array=randn(10_000_000));
  12.102 ms (2 allocations: 76.29 MiB)
1 Like

It does seem like it’s possible for LICM to help out in some cases. After profiling a for-loop sum I was able to hunt down a missing overload of Base.literal_pow for quantities, so that QuantityArrays are as fast as a unitless array in that case:

using BenchmarkTools
using DynamicQuantities 

# Sum of `x * x` over the elements of `arr`, accumulated in a plain loop.
# The accumulator starts at `zero` of the first product so that it has the same
# type as the terms being added. Indices come from `eachindex` and the seed
# from `first`, so the array need not be 1-based. `arr` must be non-empty.
function sum_square(arr)
    x1 = first(arr)
    S = zero(x1 * x1)
    for i in eachindex(arr)
        S += arr[i] * arr[i]
    end
    return S
end

# Same sum as a plain loop over `arr[i] * arr[i]`, with the loop marked `@simd`.
# The accumulator starts at `zero` of the first product so that it has the same
# type as the terms being added. Indices come from `eachindex` and the seed
# from `first`, so the array need not be 1-based. `arr` must be non-empty.
function sum_square_simd(arr)
    x1 = first(arr)
    S = zero(x1 * x1)
    @simd for i in eachindex(arr)
        S += arr[i] * arr[i]
    end
    return S
end

# `@simd` loop summing `arr[i]^2`, i.e. squaring through the literal power `^2`
# instead of an explicit product. The accumulator starts at `zero` of the first
# product so that it has the same type as the terms being added. Indices come
# from `eachindex` and the seed from `first`, so the array need not be 1-based.
# `arr` must be non-empty.
function sum_square_simd_pow(arr)
    x1 = first(arr)
    S = zero(x1 * x1)
    @simd for i in eachindex(arr)
        S += arr[i]^2
    end
    return S
end

## Benchmarking

x = rand(10_000)
x_arrq = x .* u"km/s";  # Vector{Quantity}
x_qarr = QuantityArray(x_arrq);

@btime sum_square($x); # 8.5 μs
@btime sum_square($x_arrq); # 10.8 μs
@btime sum_square($x_qarr); # 8.5 μs 

@btime sum_square_simd($x); # 540 ns
@btime sum_square_simd($x_arrq); # 10.7 μs
@btime sum_square_simd($x_qarr); # 540 ns (fast as possible)

@btime sum_square_simd_pow($x); # 540 ns
@btime sum_square_simd_pow($x_arrq); # 120 μs
@btime sum_square_simd_pow($x_qarr); # 113 μs (!!)

## now with the literal_pow fix (hard coded for power 2 here)

# Method for a quantity raised to the literal exponent 2 (selected via
# `Val{2}`): the result is built from the value multiplied by itself and the
# dimensions multiplied by themselves, using the same quantity type as `l`.
function Base.literal_pow(::typeof(^), l::AbstractQuantity{T,D}, ::Val{2}) where {T,R,D<:AbstractDimensions{R}}
    value = ustrip(l)
    dims = dimension(l)
    return DynamicQuantities.new_quantity(typeof(l), value * value, dims * dims)
end

@btime sum_square_simd_pow($x_arrq); # 10.6 μs
@btime sum_square_simd_pow($x_qarr); # 540 ns (fast as possible!)

(See https://github.com/JuliaLang/julia/blob/b99f251e86c7c09b957a1b362b6408dbba106ff0/base/intfuncs.jl#L332, https://github.com/JuliaDiff/ForwardDiff.jl/blob/e3670ce9055c66863f655d2bac2d6615c165d838/src/dual.jl#L578)

But the literal_pow is not the only missing scalar optimization for abstract quantities; even without LICM I imagine you’d expect at most a factor of 2 drop in performance? And it doesn’t totally close the gap on your more involved benchmarks, although it does help:)

Sorry for getting sidetracked from your original goal… and this adventure definitely does justify that goal; I imagine making quantities super speedy at a scalar level should be feasible by systematically performing all the necessary overloads, but overloading the broadcast machinery would get the speed without the headache for those operations.


Edit: with the literal_pow fix above, as well as removing the dimension check in https://github.com/SymbolicML/DynamicQuantities.jl/blob/0459fda9b1072dcb66b37d99abe2329bbf18f781/src/arrays.jl#L69 which it seems the compiler wasn’t smart enough to lift out, I get:

x = rand(1_000_000);
x_qarr = QuantityArray(x .* u"km/s");
g(v) = 1.5 * v * abs(v / v) ^ 2 * v; 

@btime g.($x); # 630 μs
@btime g.($x_qarr); # 940 μs

which isn’t too bad!

Thanks so much @tictaccat, this is very helpful. Indeed it does seem like adding a definition of literal_pow gets us much closer to the performance of a normal array!

I have added all of your suggestions to the array-2 branch.

Here's the updated interface, which has also been cleaned up a bit:
const DEFAULT_QUANTITY_TYPE = Quantity

"""
    QuantityArray{T,N,D<:AbstractDimensions,Q<:AbstractQuantity,V<:AbstractArray}

Array whose elements are quantities of type `Q` that all share a single
`dimensions` value of type `D`; the numeric values live in `value`, of type `V`.
Being a subtype of `AbstractArray{Q,N}`, it can be used in most places where a
normal array would be used.

# Fields

- `value`: The underlying array of values.
- `dimensions`: The dimensions shared by every element of the array.
"""
struct QuantityArray{T,N,D<:AbstractDimensions,Q<:AbstractQuantity{T,D},V<:AbstractArray{T,N}} <: AbstractArray{Q,N}
    value::V
    dimensions::D

    # The element type is rebuilt from the constructor of `_Q`, parameterised by
    # the element type of `v` and the type of `d`.
    QuantityArray(v::_V, d::_D, ::Type{_Q}) where {_T,_N,_D<:AbstractDimensions,_Q<:AbstractQuantity,_V<:AbstractArray{_T,_N}} =
        new{_T,_N,_D,constructor_of(_Q){_T,_D},_V}(v, d)
end

# Outer constructors.

# Plain values; keyword arguments are forwarded to `DEFAULT_DIM_TYPE`.
function QuantityArray(v::AbstractArray; kws...)
    dims = DEFAULT_DIM_TYPE(; kws...)
    return QuantityArray(v, dims)
end

# Plain values with explicit dimensions: elements use `DEFAULT_QUANTITY_TYPE`.
function QuantityArray(v::AbstractArray, d::AbstractDimensions)
    return QuantityArray(v, d, DEFAULT_QUANTITY_TYPE)
end

# Construct with a Quantity (easier, as you can use the units): the values are
# scaled by the quantity's value and take its dimensions and type.
function QuantityArray(v::AbstractArray, q::AbstractQuantity)
    scaled = v .* ustrip(q)
    return QuantityArray(scaled, dimension(q), typeof(q))
end

# An existing array of quantities: all elements must share the same dimensions,
# otherwise a `DimensionError` is thrown.
function QuantityArray(v::QA) where {Q<:AbstractQuantity,QA<:AbstractArray{Q}}
    allequal(dimension.(v)) || throw(DimensionError(first(v), v))
    return QuantityArray(ustrip.(v), dimension(first(v)), Q)
end
# TODO: Should this check that the dimensions are the same?

# Promotion of two `QuantityArray` types: the dimension, quantity, value and
# storage types are each promoted pairwise. The promoted quantity and storage
# types are asserted to be consistent with the promoted value and dimension
# types. If the two arrays differ in their number of axes, the result leaves
# that parameter free.
function Base.promote_rule(::Type{QA1}, ::Type{QA2}) where {QA1<:QuantityArray,QA2<:QuantityArray}
    D = promote_type(dim_type(QA1), dim_type(QA2))
    Q = promote_type(quantity_type(QA1), quantity_type(QA2))
    T = promote_type(value_type(QA1), value_type(QA2))
    V = promote_type(array_type(QA1), array_type(QA2))
    N = ndims(QA1)

    @assert(Q <: AbstractQuantity{T,D}, "Incompatible promotion rules.")
    @assert(V <: AbstractArray{T}, "Incompatible promotion rules.")

    N == ndims(QA2) && return QuantityArray{T,N,D,Q,V}
    return QuantityArray{T,_N,D,Q,V} where {_N}
end

# Field accessors: the raw value array and the shared dimensions.
@inline ustrip(arr::QuantityArray) = arr.value
@inline dimension(arr::QuantityArray) = arr.dimensions

# Type of the wrapped array for a `QuantityArray` type. The last method returns
# the `V` parameter when all five parameters are bound; the first two return
# `Array{T,1}` / `Array{T,N}` for types where only `T` (or `T` and `N`) is bound.
array_type(::Type{QA}) where {T,QA<:QuantityArray{T}} = Array{T,1}
array_type(::Type{QA}) where {T,N,QA<:QuantityArray{T,N}} = Array{T,N}
array_type(::Type{QA}) where {T,N,D,Q,V,QA<:QuantityArray{T,N,D,Q,V}} = V
array_type(x) = array_type(typeof(x))

# Element (quantity) type `Q` of a `QuantityArray` type or instance.
quantity_type(::Type{QA}) where {T,N,D,Q,QA<:QuantityArray{T,N,D,Q}} = Q
quantity_type(x) = quantity_type(typeof(x))

# Dimensions type `D`, or `DEFAULT_DIM_TYPE` when `D` is not bound in the type.
dim_type(::Type{QA}) where {QA<:QuantityArray} = DEFAULT_DIM_TYPE
dim_type(::Type{QA}) where {T,N,D,QA<:QuantityArray{T,N,D}} = D
dim_type(x) = dim_type(typeof(x))

# Value type `T` of a `QuantityArray` type or of a quantity type;
# `DEFAULT_VALUE_TYPE` when `T` is not bound in the array type.
value_type(::Type{QA}) where {QA<:QuantityArray} = DEFAULT_VALUE_TYPE
value_type(::Type{QA}) where {T,QA<:QuantityArray{T}} = T
value_type(::Type{QT}) where {T,QT<:AbstractQuantity{T}} = T
value_type(x) = value_type(typeof(x))

# Shape queries are answered by the wrapped value array.
Base.size(A::QuantityArray) = size(ustrip(A))
Base.length(A::QuantityArray) = length(ustrip(A))
Base.axes(A::QuantityArray) = axes(ustrip(A))

# Index into the stored values and re-attach the shared dimensions: an array
# result is wrapped in a new `QuantityArray`, anything else becomes a single
# quantity of type `quantity_type(A)`.
function Base.getindex(A::QuantityArray, i...)
    picked = getindex(ustrip(A), i...)
    picked isa AbstractArray && return QuantityArray(picked, dimension(A), quantity_type(A))
    return new_quantity(quantity_type(A), picked, dimension(A))
end
# Store a quantity whose type equals the array's element type `Q`. The value is
# written only if its dimensions equal those of the array; otherwise a
# `DimensionError` is thrown.
function Base.setindex!(A::QuantityArray{T,N,D,Q}, v::Q, i...) where {T,N,D,Q<:AbstractQuantity}
    dimension(A) == dimension(v) || throw(DimensionError(A, v))
    return unsafe_setindex!(A, v, i...)
end

# Any other quantity type is first converted to `Q` and then stored via the
# method above.
function Base.setindex!(A::QuantityArray{T,N,D,Q}, v::AbstractQuantity, i...) where {T,N,D,Q<:AbstractQuantity}
    converted = convert(Q, v)
    return setindex!(A, converted, i...)
end

# Write the stripped value into the stripped array without any dimension check.
function unsafe_setindex!(A, v, i...)
    return setindex!(ustrip(A), ustrip(v), i...)
end

# Indexing style is inherited from the wrapped array type.
Base.IndexStyle(::Type{QA}) where {QA<:QuantityArray} = IndexStyle(array_type(QA))

# `similar` on an instance: allocate new storage from the wrapped values and
# keep the dimensions and quantity type of the source array.
Base.similar(src::QuantityArray) = QuantityArray(similar(ustrip(src)), dimension(src), quantity_type(src))
Base.similar(src::QuantityArray, ::Type{E}) where {E} = QuantityArray(similar(ustrip(src), E), dimension(src), quantity_type(src))
Base.similar(src::QuantityArray, dims::Dims) = QuantityArray(similar(ustrip(src), dims), dimension(src), quantity_type(src))
Base.similar(src::QuantityArray, ::Type{E}, dims::Dims) where {E} = QuantityArray(similar(ustrip(src), E, dims), dimension(src), quantity_type(src))

# `similar` on a type: storage comes from `array_type(W)`, the dimensions are a
# default-constructed `dim_type(W)()`, and the quantity type is `quantity_type(W)`.
# NOTE(review): the methods taking an element type call `similar` on an array
# *type* together with `E`; confirm such a method exists for the backing array
# types in use.
Base.similar(::Type{W}) where {T,W<:QuantityArray{T}} = QuantityArray(similar(array_type(W)), dim_type(W)(), quantity_type(W))
Base.similar(::Type{W}, ::Type{E}) where {T,W<:QuantityArray{T},E} = QuantityArray(similar(array_type(W), E), dim_type(W)(), quantity_type(W))
Base.similar(::Type{W}, dims::Dims) where {T,W<:QuantityArray{T}} = QuantityArray(similar(array_type(W), dims), dim_type(W)(), quantity_type(W))
Base.similar(::Type{W}, ::Type{E}, dims::Dims) where {T,W<:QuantityArray{T},E} = QuantityArray(similar(array_type(W), E, dims), dim_type(W)(), quantity_type(W))

# Every `QuantityArray` type broadcasts under an `ArrayStyle` tagged with that
# same array type.
Base.BroadcastStyle(::Type{A}) where {A<:QuantityArray} = Broadcast.ArrayStyle{A}()

# When two `QuantityArray` styles meet, the resulting style is tagged with
# `promote_type` of the two array types.
function Base.BroadcastStyle(::Broadcast.ArrayStyle{L}, ::Broadcast.ArrayStyle{R}) where {L<:QuantityArray,R<:QuantityArray}
    Promoted = promote_type(L, R)
    return Broadcast.ArrayStyle{Promoted}()
end
# Allocate the destination of a broadcast whose style is tagged with a
# `QuantityArray` type. Storage uses the constructor of the wrapped array type
# with the (unwrapped) element type of the result. For quantity results the
# first output element is materialized to obtain the dimensions, and a warning
# is logged if its type differs from `ElType`.
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{QA}}, ::Type{ElType}) where {QA<:QuantityArray,ElType}
    storage_type = constructor_of(array_type(QA)){unwrap_quantity(ElType)}
    storage = similar(storage_type, axes(bc))

    # Non-quantity results carry no dimensions: hand back the plain array.
    ElType <: AbstractQuantity || return storage

    sample = materialize_first(bc)
    if typeof(sample) != ElType
        @warn (
            "Materialization of first element likely failed. "
            * "Please submit a bug report with information on "
            * "the function you are broadcasting."
        )
    end
    return QuantityArray(storage, dimension(sample), ElType)
end
# Value type of a quantity type; any other type is returned unchanged.
unwrap_quantity(::Type{QT}) where {V,QT<:AbstractQuantity{V}} = V
unwrap_quantity(::Type{Other}) where {Other} = Other

# Basically, we want to solve a single element to find the output dimension.
# Then we can put results in the output `QuantityArray`.
#
# `materialize_first` evaluates a lazy broadcast tree for one element: every
# argument is reduced to its first element and each node applies its function
# to those.
materialize_first(bc::Base.Broadcast.Broadcasted) = bc.f(materialize_first.(bc.args)...)

# Base cases
materialize_first(q::AbstractQuantity) = q
materialize_first(q::QuantityArray) = first(q)
materialize_first(q::AbstractArray{Q}) where {Q<:AbstractQuantity} = first(q)

# Derived calls
materialize_first(r::Base.RefValue) = materialize_first(r.x)
materialize_first(x::Base.Broadcast.Extruded) = materialize_first(x.x)
# Collections contribute their first element. Searching a collection for its
# first *quantity* instead would run off the end of any tuple or array that
# holds only plain numbers (and copy the tail at every step), so the first
# element is taken directly. Empty collections have no first element.
materialize_first(args::Tuple) = materialize_first(first(args))
materialize_first(args::AbstractArray) = isempty(args) ? error("Unexpected.") : materialize_first(first(args))
materialize_first(::Tuple{}) = error("Unexpected.")

# Everything else:
materialize_first(x) = x

# Print a compact description of a `QuantityArray` type: the wrapped array type
# followed by the quantity type.
function _print_array_type(io::IO, ::Type{A}) where {A<:QuantityArray}
    return print(io, "QuantityArray(::", array_type(A), ", ::", quantity_type(A), ")")
end

# Both the summary hook for instances and the plain-text display of the type
# itself use the compact description.
Base.showarg(io::IO, arr::QuantityArray, _) = _print_array_type(io, typeof(arr))
Base.show(io::IO, ::MIME"text/plain", ::Type{A}) where {A<:QuantityArray} = _print_array_type(io, A)

# Other array operations:

# Copy the wrapped values; dimensions and quantity type are carried over.
function Base.copy(A::QuantityArray)
    duplicated = copy(ustrip(A))
    return QuantityArray(duplicated, dimension(A), quantity_type(A))
end
# Concatenate quantity arrays along `dims`. All inputs must have equal
# dimensions, otherwise a `DimensionError` is thrown; the result takes the
# dimensions and quantity type of the first input.
function Base.cat(A::QuantityArray...; dims)
    allequal(dimension.(A)) || throw(DimensionError(A[begin], A[begin+1:end]))
    head = A[begin]
    joined = cat(ustrip.(A)...; dims=dims)
    return QuantityArray(joined, dimension(head), quantity_type(head))
end

# Horizontal and vertical concatenation are `cat` along axis 2 and axis 1.
Base.hcat(arrays::QuantityArray...) = cat(arrays...; dims=2)
Base.vcat(arrays::QuantityArray...) = cat(arrays...; dims=1)
# Fill a `QuantityArray` with copies of one quantity: the value is filled into
# a plain array, and the dimensions and type of `x` are attached.
# NOTE(review): `dims::Dims...` accepts size tuples only; a call with separate
# integer sizes does not reach this method — confirm that is intended.
function Base.fill(x::AbstractQuantity, dims::Dims...)
    filled = fill(ustrip(x), dims...)
    return QuantityArray(filled, dimension(x), typeof(x))
end

In particular, the key function, which I tried to make more readable, is:

# Allocate the destination of a broadcast whose style is tagged with a
# `QuantityArray` type. Storage uses the constructor of the wrapped array type
# with the (unwrapped) element type of the result. For quantity results the
# first output element is materialized to obtain the dimensions, and a warning
# is logged if its type differs from `ElType`.
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{QA}}, ::Type{ElType}) where {QA<:QuantityArray,ElType}
    storage_type = constructor_of(array_type(QA)){unwrap_quantity(ElType)}
    storage = similar(storage_type, axes(bc))

    # Non-quantity results carry no dimensions: hand back the plain array.
    ElType <: AbstractQuantity || return storage

    sample = materialize_first(bc)
    if typeof(sample) != ElType
        @warn (
            "Materialization of first element likely failed. "
            * "Please submit a bug report with information on "
            * "the function you are broadcasting."
        )
    end
    return QuantityArray(storage, dimension(sample), ElType)
end

so that it tries to materialize the first output element, take the dimensions of that to build a QuantityArray, and then use that array for output storage.

In retrospect, maybe this is the best we could possibly do. It does seem a bit dangerous to strip units from the rest of the calculation. For example, what if the user had written, e.g.,

f(x) = ustrip(x) > 0.5 ? x : x * u"km/s"
f.(x_qarr)

we would still want that dimension(array) == dimension(value) check to happen for every output element, rather than pre-emptively stripping units from x.

I guess the best we could hope for is that the compiler could be smart enough to avoid redundant dimension calculations if it sees the same dimensions object being used in an inner loop without interacting with other variables. What do you think?

1 Like

I think that makes sense! In theory I suppose the compiler should be able to prove it (after all, it’s able to do such reasoning at a typelevel, so the infrastructure is there). I am not a compiler expert, but I have seen things like Base.@constprop :aggressive bandied about and maybe these fancy annotations could help squeeze out the last drop of performance in cases where dimension calculations are a bottleneck. But I agree with your overall conclusion, glad that performance is much improved:)

Cool, I didn’t know this! I’ll give it a go.

I put this into the following PR where I’ll continue mucking about: https://github.com/SymbolicML/DynamicQuantities.jl/pull/33. Will be interesting to see if the perf can get better somehow.

1 Like