I’m trying to create a `QuantityArray` type for DynamicQuantities.jl (on the `arrays-2` branch). Original announcement here: [ANN] DynamicQuantities.jl: type stable physical quantities. cc @non-Jedi
The reason you would want a `QuantityArray`, rather than simply using `Quantity`, is (1) so that it can subtype `AbstractArray`, and (2) so that you don’t need to re-compute the units for every element, since they are assumed to be constant across the entire array.
I basically have (1) down. I can broadcast calculations with a custom broadcast style. But for the life of me I just cannot seem to figure out how to avoid re-computing the units at every element as part of this custom broadcast.
My current plan is to:
1. Materialize the broadcasted calculation for the first element, and record the physical units.
2. Strip the units from all inputs to the broadcasted calculation.
3. Perform a normal broadcasted array calculation without units.
4. Add the units recorded in step 1 to the output array.
Here is what I have so far. The most relevant part is the custom broadcasting code (with some things changed for readability), shown below:
# Give `QuantityArray` arguments their own broadcast style, keyed on the concrete
# array type, so broadcasts involving them dispatch to the custom `similar` method
# for `Broadcast.ArrayStyle{QA}`.
Base.BroadcastStyle(::Type{QA}) where {QA<:QuantityArray} = Broadcast.ArrayStyle{QA}()
# When two `QuantityArray` broadcast styles meet, use the style of their promoted
# array type.
# NOTE(review): for unrelated `QuantityArray` types `promote_type` may fall back to
# an abstract `typejoin`; confirm helpers such as `array_type` cope with that.
Base.BroadcastStyle(
    ::Broadcast.ArrayStyle{A}, ::Broadcast.ArrayStyle{B}
) where {A<:QuantityArray,B<:QuantityArray} = Broadcast.ArrayStyle{promote_type(A, B)}()
# Allocate the destination of a broadcast involving a `QuantityArray`.
#
# The broadcast kernel is first evaluated on one representative element of every
# argument (`find_q`). If that result is a quantity, the destination is a
# `QuantityArray` carrying its dimensions and quantity type. Otherwise it is a plain
# array of the element type `ElType` computed by Base (e.g. `Bool` for comparisons).
function Base.similar(
    bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{QA}}, ::Type{ElType}
) where {QA<:QuantityArray,ElType}
    q = find_q(bc)
    q isa AbstractQuantity || return similar(Array{ElType}, axes(bc))
    # The storage must hold the *output* value type, which can differ from the
    # inputs' (e.g. `Int` values divided by 2 become `Float64`). Reuse the input
    # storage type only when its element type already matches.
    T = typeof(ustrip(q))
    V = array_type(QA)
    value = eltype(V) === T ? similar(V, axes(bc)) : similar(Array{T}, axes(bc))
    # Keep the quantity type of the computed element so `setindex!` accepts the results.
    return QuantityArray(value, dimension(q), typeof(q))
end
# https://discourse.julialang.org/t/defining-broadcast-for-custom-types-the-example-in-the-docs-fails/32291/2
# Evaluate the broadcast kernel on a single representative element of every
# argument to find the output element (and hence the output dimension). Then we
# can put results in the output `QuantityArray`.
find_q(bc::Base.Broadcast.Broadcasted) = bc.f(map(find_q, bc.args)...)
# The remaining methods pick a representative scalar for each kind of broadcast
# argument, recursing into nested lazy broadcasts through the method above.
# Scalars (plain numbers, quantities, ...) represent themselves.
find_q(x) = x
find_q(q::AbstractQuantity) = q
# `Ref(x)` shields `x` from iteration, so `x` itself is the element.
find_q(r::Base.RefValue) = find_q(r.x)
# `Extruded` is Base's broadcast wrapper around an array; unwrap it.
find_q(x::Base.Broadcast.Extruded) = find_q(x.x)
# Collections (arrays of any element type, including `QuantityArray`s, and tuples)
# are represented by their first element. Searching them for a quantity instead
# fails on plain numeric arguments such as `qa .* [1.0, 2.0]`.
function find_q(args::Union{AbstractArray,Tuple})
    isempty(args) && throw(ArgumentError("Cannot infer the broadcast output from an empty argument."))
    return find_q(first(args))
end
I tried to base this on this section of the docs, as well as on some Discourse posts to fill in the gaps (this one in particular).
Here is the full `QuantityArray` definition:
# Quantity type used for the elements of a `QuantityArray` that is constructed
# without an explicit quantity type (see the two-argument inner constructor).
const DEFAULT_QUANTITY_TYPE = Quantity
"""
    QuantityArray{T,N,D,Q,V} <: AbstractArray{Q,N}

An `N`-dimensional array whose elements are quantities of type `Q`, all sharing one
set of dimensions. The raw values are stored in `value::V` (an `AbstractArray{T,N}`)
and the shared dimensions in the single field `dimensions::D`.
"""
struct QuantityArray{T,N,D<:AbstractDimensions,Q<:AbstractQuantity{T,D},V<:AbstractArray{T,N}} <: AbstractArray{Q,N}
    value::V
    dimensions::D

    # Default quantity type.
    QuantityArray(v::_V, d::_D) where {_T,_N,_D<:AbstractDimensions,_V<:AbstractArray{_T,_N}} =
        new{_T,_N,_D,DEFAULT_QUANTITY_TYPE{_T,_D},_V}(v, d)
    # Quantity type already matching the value type and the dimension type.
    QuantityArray(v::_V, d::_D, ::Type{_Q}) where {_T,_N,_D<:AbstractDimensions,_Q<:AbstractQuantity{_T,_D},_V<:AbstractArray{_T,_N}} =
        new{_T,_N,_D,_Q,_V}(v, d)
    # Quantity type with the right value type but an unset or different dimension type.
    QuantityArray(v::_V, d::_D, ::Type{_Q}) where {_T,_N,_D<:AbstractDimensions,_Q<:AbstractQuantity{_T},_V<:AbstractArray{_T,_N}} =
        QuantityArray(v, d, constructor_of(_Q){_T,_D})
    # Any other quantity type: a bare constructor such as `Quantity`, or a concrete
    # type whose value type differs from `eltype(v)`. Rebuild it from its base
    # constructor; applying `_Q{_T,_D}` directly fails for an already-parametrized `_Q`.
    QuantityArray(v::_V, d::_D, ::Type{_Q}) where {_T,_N,_D<:AbstractDimensions,_Q<:AbstractQuantity,_V<:AbstractArray{_T,_N}} =
        QuantityArray(v, d, constructor_of(_Q){_T,_D})
end
# Construct from a plain array, with the dimensions given as keywords:
QuantityArray(v::AbstractArray; kws...) = QuantityArray(v, DEFAULT_DIM_TYPE(; kws...))
# Construct with a Quantity (easier, as you can use the units):
QuantityArray(v::AbstractArray, q::AbstractQuantity) = QuantityArray(v .* ustrip(q), dimension(q), typeof(q))
# Construct from an array of quantities, which must all share the same dimensions.
function QuantityArray(v::QA) where {Q<:AbstractQuantity,QA<:AbstractArray{Q}}
    # The shared dimensions are read from the first element, so an empty array
    # carries no dimension information to build from.
    isempty(v) && throw(ArgumentError("Cannot construct a QuantityArray from an empty array of quantities."))
    d = dimension(first(v))
    # Compare element by element rather than materializing `dimension.(v)`.
    all(x -> isequal(dimension(x), d), v) || throw(DimensionError(first(v), v))
    return QuantityArray(ustrip.(v), d, Q)
end
# TODO: Should this check that the dimensions are the same?
# Unitless value storage of a `QuantityArray`.
function ustrip(A::QuantityArray)
    return A.value
end
# Dimensions shared by every element of a `QuantityArray`.
function dimension(A::QuantityArray)
    return A.dimensions
end
# Storage (value-array) type of a `QuantityArray` type. When the storage parameter
# `V` is not fixed, a plain `Array` is assumed, with one dimension if `N` is unknown.
array_type(::Type{A}) where {T,A<:QuantityArray{T}} = Array{T,1}
array_type(::Type{A}) where {T,N,A<:QuantityArray{T,N}} = Array{T,N}
array_type(::Type{A}) where {T,N,D,Q,V,A<:QuantityArray{T,N,D,Q,V}} = V
# Instance convenience method. Restricted to `QuantityArray` instances: an untyped
# fallback calls itself forever (`typeof` of a type is again a type) for any
# argument the methods above do not match, such as the bare `QuantityArray` type.
array_type(A::QuantityArray) = array_type(typeof(A))
# Element (quantity) type `Q` of a `QuantityArray` type whose first four
# parameters are known.
quantity_type(::Type{A}) where {T,N,D,Q,A<:QuantityArray{T,N,D,Q}} = Q
# Instance convenience method. Restricted to `QuantityArray` instances so that an
# unmatched argument raises a `MethodError` instead of recursing forever via `typeof`.
quantity_type(A::QuantityArray) = quantity_type(typeof(A))
# Dimension type of a `QuantityArray` type; `DEFAULT_DIM_TYPE` when `D` is not fixed.
dim_type(::Type{A}) where {A<:QuantityArray} = DEFAULT_DIM_TYPE
dim_type(::Type{A}) where {T,N,D,A<:QuantityArray{T,N,D}} = D
# Instance convenience method. Restricted to `QuantityArray` instances so that an
# unmatched argument raises a `MethodError` instead of recursing forever via `typeof`.
dim_type(A::QuantityArray) = dim_type(typeof(A))
# Shape queries are forwarded to the unitless value storage.
Base.size(A::QuantityArray) = size(ustrip(A))
Base.length(A::QuantityArray) = length(ustrip(A))
Base.axes(A::QuantityArray) = axes(ustrip(A))
# Index the value storage and re-attach the shared dimensions. Scalar indexing
# yields one quantity. Non-scalar indexing (ranges, colons, masks, ...) yields a new
# `QuantityArray` with the same dimensions and quantity type, instead of trying to
# build a single quantity from an array.
# NOTE(review): assumes the stored values are not themselves arrays.
Base.getindex(A::QuantityArray, i...) = _wrap_getindex(A, getindex(ustrip(A), i...))
# Wrap a sliced block of raw values as a `QuantityArray`.
_wrap_getindex(A::QuantityArray, v::AbstractArray) = QuantityArray(v, dimension(A), quantity_type(A))
# Wrap a single raw value as a quantity.
_wrap_getindex(A::QuantityArray, v) = quantity_type(A)(v, dimension(A))
# Store a quantity of exactly the array's quantity type `Q`, after checking that its
# dimensions match the dimensions shared by the whole array.
# TODO: Should this dimension check be removed?
function Base.setindex!(A::QuantityArray{T,N,D,Q}, v::Q, i...) where {T,N,D,Q<:AbstractQuantity}
    dimension(A) == dimension(v) || throw(DimensionError(A, v))
    unsafe_setindex!(A, v, i...)
    # Follow the Base convention that `setindex!` returns the collection itself,
    # not the inner value storage.
    return A
end
# A quantity of any other type is rejected rather than converted.
function Base.setindex!(A::QuantityArray{T,N,D,Q}, v::AbstractQuantity, i...) where {T,N,D,Q<:AbstractQuantity}
    error("Cannot set values in a quantity array with element type $(Q) with different element type: $(typeof(v)).")
end
# TODO: This does not allow for efficient broadcasting; as the dimension calculation is repeated...
# Write the raw value of `v` into the value storage without any dimension check.
unsafe_setindex!(A, v, i...) = setindex!(ustrip(A), ustrip(v), i...)
# A `QuantityArray` supports the same indexing style as its value storage type.
function Base.IndexStyle(::Type{QA}) where {QA<:QuantityArray}
    storage = array_type(QA)
    return IndexStyle(storage)
end
# Same element type: keep both the dimensions and the quantity type of `A`, so that
# elements of `A` can be stored into the result (e.g. by `copy`) without conversion.
Base.similar(A::QuantityArray) = QuantityArray(similar(ustrip(A)), dimension(A), quantity_type(A))
Base.similar(A::QuantityArray, dims::Dims) = QuantityArray(similar(ustrip(A), dims), dimension(A), quantity_type(A))
# New element type `S`. A `QuantityArray` can only hold quantities, so a
# non-quantity `S` gives a plain array of `S`, e.g. for `map(ustrip, A)`. A quantity
# `S` gives a `QuantityArray` storing raw values of its value type `T` and reusing
# `A`'s dimensions. It no longer nests quantities inside the value storage.
Base.similar(A::QuantityArray, ::Type{S}) where {S} = similar(ustrip(A), S)
Base.similar(A::QuantityArray, ::Type{S}) where {T,S<:AbstractQuantity{T}} = QuantityArray(similar(ustrip(A), T), dimension(A), S)
Base.similar(A::QuantityArray, ::Type{S}, dims::Dims) where {S} = similar(ustrip(A), S, dims)
Base.similar(A::QuantityArray, ::Type{S}, dims::Dims) where {T,S<:AbstractQuantity{T}} = QuantityArray(similar(ustrip(A), T, dims), dimension(A), S)
# Type-based variants, using dimensionless default dimensions.
# NOTE(review): Base appears to define `similar` on array *types* only with
# dimensions (`similar(::Type{<:AbstractArray}, dims)`), so the variants without
# `dims` or with `::Type{S}` may raise a `MethodError`. Confirm before relying on them.
Base.similar(::Type{QA}) where {T,QA<:QuantityArray{T}} = QuantityArray(similar(array_type(QA)), dim_type(QA)())
Base.similar(::Type{QA}, ::Type{S}) where {T,QA<:QuantityArray{T},S} = QuantityArray(similar(array_type(QA), S), dim_type(QA)())
Base.similar(::Type{QA}, dims::Dims) where {T,QA<:QuantityArray{T}} = QuantityArray(similar(array_type(QA), dims), dim_type(QA)())
Base.similar(::Type{QA}, ::Type{S}, dims::Dims) where {T,QA<:QuantityArray{T},S} = QuantityArray(similar(array_type(QA), S, dims), dim_type(QA)())
# Give `QuantityArray` arguments their own broadcast style, keyed on the concrete
# array type, so broadcasts involving them dispatch to the custom `similar` method
# for `Broadcast.ArrayStyle{QA}`.
Base.BroadcastStyle(::Type{QA}) where {QA<:QuantityArray} = Broadcast.ArrayStyle{QA}()
# Combine the broadcast styles of two `QuantityArray`s with the same number of
# dimensions and the same dimension type. The result promotes the value types,
# promotes the storage types, and reuses the quantity constructor of the first array.
function Base.BroadcastStyle(
    ::Broadcast.ArrayStyle{QA1}, ::Broadcast.ArrayStyle{QA2}
) where {
    T1,T2,N,V1<:AbstractArray{T1,N},V2<:AbstractArray{T2,N},D<:AbstractDimensions,
    Q1<:AbstractQuantity{T1,D},Q2<:AbstractQuantity{T2,D},
    QA1<:QuantityArray{T1,N,D,Q1,V1},QA2<:QuantityArray{T2,N,D,Q2,V2}
}
    T = promote_type(T1, T2)
    # Without a promotion rule (e.g. a range and a `Vector`), `promote_type` falls
    # back to an abstract `typejoin`. That may not be an `AbstractArray{T,N}`, and it
    # cannot be allocated by `similar`, so use a plain `Array{T,N}` in that case.
    V12 = promote_type(V1, V2)
    V = isconcretetype(V12) && V12 <: AbstractArray{T,N} ? V12 : Array{T,N}
    Q = constructor_of(Q1){T,D}
    return Broadcast.ArrayStyle{QuantityArray{T,N,D,Q,V}}()
end
# TODO: How can I ustrip after finding the output units?
# Allocate the destination of a broadcast involving a `QuantityArray`.
#
# The broadcast kernel is first evaluated on one representative element of every
# argument (`find_q`). If that result is a quantity, the destination is a
# `QuantityArray` carrying its dimensions and quantity type. Otherwise it is a plain
# array of the element type `ElType` computed by Base (e.g. `Bool` for comparisons).
function Base.similar(
    bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{QA}}, ::Type{ElType}
) where {QA<:QuantityArray,ElType}
    q = find_q(bc)
    q isa AbstractQuantity || return similar(Array{ElType}, axes(bc))
    # The storage must hold the *output* value type, which can differ from the
    # inputs' (e.g. `Int` values divided by 2 become `Float64`). Reuse the input
    # storage type only when its element type already matches.
    T = typeof(ustrip(q))
    V = array_type(QA)
    value = eltype(V) === T ? similar(V, axes(bc)) : similar(Array{T}, axes(bc))
    # Keep the quantity type of the computed element so `setindex!` accepts the results.
    return QuantityArray(value, dimension(q), typeof(q))
end
# https://discourse.julialang.org/t/defining-broadcast-for-custom-types-the-example-in-the-docs-fails/32291/2
# Evaluate the broadcast kernel on a single representative element of every
# argument to find the output element (and hence the output dimension). Then we
# can put results in the output `QuantityArray`.
find_q(bc::Base.Broadcast.Broadcasted) = bc.f(map(find_q, bc.args)...)
# The remaining methods pick a representative scalar for each kind of broadcast
# argument, recursing into nested lazy broadcasts through the method above.
# Scalars (plain numbers, quantities, ...) represent themselves.
find_q(x) = x
find_q(q::AbstractQuantity) = q
# `Ref(x)` shields `x` from iteration, so `x` itself is the element.
find_q(r::Base.RefValue) = find_q(r.x)
# `Extruded` is Base's broadcast wrapper around an array; unwrap it.
find_q(x::Base.Broadcast.Extruded) = find_q(x.x)
# Collections (arrays of any element type, including `QuantityArray`s, and tuples)
# are represented by their first element. Searching them for a quantity instead
# fails on plain numeric arguments such as `qa .* [1.0, 2.0]`.
function find_q(args::Union{AbstractArray,Tuple})
    isempty(args) && throw(ArgumentError("Cannot infer the broadcast output from an empty argument."))
    return find_q(first(args))
end
# Print a `QuantityArray` type as `QuantityArray(::<storage type>, ::<quantity type>)`.
_print_array_type(io::IO, ::Type{QA}) where {QA<:QuantityArray} = print(io, "QuantityArray(::", array_type(QA), ", ::", quantity_type(QA), ")")
# Used by `summary` when describing an instance.
Base.showarg(io::IO, v::QuantityArray, _) = _print_array_type(io, typeof(v))
# Custom REPL display only for fully specified types. Partially specified or bare
# `QuantityArray` types (e.g. typing `QuantityArray` at the REPL) keep Base's
# default display, because `array_type`/`quantity_type` cannot resolve their
# missing parameters. Previously, the bare type sent them into unbounded recursion.
Base.show(io::IO, ::MIME"text/plain", QA::Type{QuantityArray{T,N,D,Q,V}}) where {T,N,D,Q,V} = _print_array_type(io, QA)
Basically, the `find_q` function materializes the first element of the broadcasted calculation. That gives us the required output type (and in this case, the physical units, which are constant across the input and output arrays). It then took a lot of manual debugging to figure out which types can appear in a broadcasted calculation, and to define `find_q` for the missing ones.
But I can’t figure out how to `ustrip` the quantities before the rest of the broadcasted calculation takes place, or where I can add the units back after the calculation is done. Perhaps I need to overload the `broadcasted(...)` function instead? Unfortunately there aren’t many docs to work from on this, so I haven’t been able to figure it out myself. Any help is much appreciated.
Thanks!
Miles