The goal is to compute a single number from two vector-like objects in a type-stable manner, without type assertions. This is part of a larger workflow that involves `ForwardDiff.gradient`, which runs, but slowly — hence this question.
using Dates, ForwardDiff, AxisKeys
r = rand(Float32, 4)
r[1] = NaN32
# Capture the start date once: calling `today()` twice could straddle midnight and
# build a 5-element range that no longer matches `length(r)`. `today()` already
# returns a `Date`, so no `Date(...)` conversion is needed.
d0 = today()
ka = KeyedArray(r; i = d0:(d0 + Day(3)))
# NOTE(review): `ForwardDiff.Dual` written without its type parameters is not a
# concrete type, so this Vector's element type is not concrete either; the
# `@code_warntype` outputs (`_A`, `_B`, `Any`) are consistent with that. Inside
# `ForwardDiff.gradient` the inputs are typically a vector with a single concrete
# `Dual{Tag, V, N}` element type — confirm against the real workflow before
# optimizing for this toy `ar`.
ar = Union{Float32, ForwardDiff.Dual}[1f0, 2f0, ForwardDiff.Dual(NaN32), 2f0]
# Bool mask: `true` wherever the elementwise product `ar * ka` is not NaN.
# Every operation is dotted, so the whole expression fuses into one broadcast.
idx_f(ar, ka) = @. !isnan(ar * ka)
# Top-level mask of positions where `ar .* ka` is not NaN.
idxs = idx_f(ar, ka)
For this initial function we already see some type instabilities:
# Inspect inferred types of the mask construction.
@code_warntype(idx_f(ar, ka))
```
Arguments
  #self#::Core.Const(idx_f)
  ar::Vector{Union{Float32, ForwardDiff.Dual}}
  ka::KeyedArray{Float32, 1, NamedDimsArray{(:i,), Float32, 1, Vector{Float32}}, Base.RefValue{StepRange{Date, Day}}}
Body::KeyedArray{_A, 1, _B, Base.RefValue{StepRange{Date, Day}}} where {_A, _B}
1 ─ %1 = Main.:!::Core.Const(!)
│   %2 = Main.isnan::Core.Const(isnan)
│   %3 = Base.broadcasted(Main.:*, ar, ka)::Base.Broadcast.Broadcasted{AxisKeys.KeyedStyle{NamedDims.NamedDimsStyle{Base.Broadcast.DefaultArrayStyle{1}}}, Nothing, typeof(*), T}
│   %4 = Base.broadcasted(%2, %3)::Base.Broadcast.Broadcasted{AxisKeys.KeyedStyle{NamedDims.NamedDimsStyle{Base.Broadcast.DefaultArrayStyle{1}}}, Nothing, typeof(isnan), NT}
│   %5 = Base.broadcasted(%1, %4)::Base.Broadcast.Broadcasted{AxisKeys.KeyedStyle{NamedDims.NamedDimsStyle{Base.Broadcast.DefaultArrayStyle{1}}}, Nothing, typeof(!), NT}
│   %6 = Base.materialize(%5)::KeyedArray{_A, 1, _B, Base.RefValue{StepRange{Date, Day}}} where {_A, _B}
└──      return %6
```
And let's say that the target function is:
"""
    ka_ar_unstable(ka, ar, idxs)

Return the elementwise squared differences `abs2(ka[i] - ar[i])`, taken only at the
positions selected by the Boolean mask `idxs`.
"""
function ka_ar_unstable(ka, ar, idxs)
    ka_sel = ka[idxs]
    ar_sel = ar[idxs]
    return @. abs2(ka_sel - ar_sel)
end
# Inspect inferred types of the three-argument method with a precomputed mask.
@code_warntype(ka_ar_unstable(ka, ar, idxs))
```
Arguments
  #self#::Core.Const(ka_ar_unstable)
  ka::KeyedArray{Float32, 1, NamedDimsArray{(:i,), Float32, 1, Vector{Float32}}, Base.RefValue{StepRange{Date, Day}}}
  ar::Vector{Union{Float32, ForwardDiff.Dual}}
  idxs::KeyedArray{Bool, 1, NamedDimsArray{(:i,), Bool, 1, BitVector}, Base.RefValue{StepRange{Date, Day}}}
Body::KeyedArray{_A, 1, _B, Base.RefValue{Vector{Date}}} where {_A, _B}
1 ─ %1 = Main.abs2::Core.Const(abs2)
│   %2 = Main.:-::Core.Const(-)
│   %3 = Base.getindex(ka, idxs)::KeyedArray{Float32, 1, NamedDimsArray{(:i,), Float32, 1, Vector{Float32}}, Base.RefValue{Vector{Date}}}
│   %4 = Base.getindex(ar, idxs)::Vector{Union{Float32, ForwardDiff.Dual}}
│   %5 = Base.broadcasted(%2, %3, %4)::Base.Broadcast.Broadcasted{AxisKeys.KeyedStyle{NamedDims.NamedDimsStyle{Base.Broadcast.DefaultArrayStyle{1}}}, Nothing, typeof(-), T}
│   %6 = Base.broadcasted(%1, %5)::Base.Broadcast.Broadcasted{AxisKeys.KeyedStyle{NamedDims.NamedDimsStyle{Base.Broadcast.DefaultArrayStyle{1}}}, Nothing, typeof(abs2), NT}
│   %7 = Base.materialize(%6)::KeyedArray{_A, 1, _B, Base.RefValue{Vector{Date}}} where {_A, _B}
└──      return %7
```
And all together:
"""
    all_pack(ar, ka)

Return the sum of squared differences between `ka` and `ar`, taken only over the
positions where `ar .* ka` is not NaN.
"""
function all_pack(ar, ka)
    # The mask must be recomputed here because `ar` and `ka` come from an outer
    # loop. Bind it to the same name that is used below: referring to a global
    # `idxs` instead would be stale for new inputs and, as a non-const global,
    # type-unstable.
    idxs = idx_f(ar, ka)
    vals = ka_ar_unstable(ka, ar, idxs)
    return sum(vals)
end
# Inspect inferred types of the full pipeline (mask + squared differences + sum).
@code_warntype(all_pack(ar, ka))
```
Arguments
  #self#::Core.Const(all_pack)
  ar::Vector{Union{Float32, ForwardDiff.Dual}}
  ka::KeyedArray{Float32, 1, NamedDimsArray{(:i,), Float32, 1, Vector{Float32}}, Base.RefValue{StepRange{Date, Day}}}
Locals
  vals::Any
  idxs::KeyedArray{_A, 1, _B, Base.RefValue{StepRange{Date, Day}}} where {_A, _B}
Body::Any
1 ─      (idxs = Main.idx_f(ar, ka))
│        (vals = Main.ka_ar_unstable(ka, ar, idxs))
│   %3 = Main.sum(vals)::Any
└──      return %3
```
Computing the indices inside the function instead makes `Base.getindex` infer as `Any` directly:
"""
    ka_ar_unstable(ka, ar)

Build the mask of positions where `ar .* ka` is not NaN, then return the
elementwise squared differences `abs2(ka[i] - ar[i])` at those positions.
"""
function ka_ar_unstable(ka, ar)
    keep = @. !isnan(ar * ka)
    return @. abs2(ka[keep] - ar[keep])
end
# Pass the arguments in the order of the signature `ka_ar_unstable(ka, ar)`;
# swapping them binds the Vector to `ka` and the KeyedArray to `ar`.
@code_warntype ka_ar_unstable(ka, ar)
Arguments
#self#::Core.Const(ka_ar_unstable)
ka::Vector{Union{Float32, ForwardDiff.Dual}}
ar::KeyedArray{Float32, 1, NamedDimsArray{(:i,), Float32, 1, Vector{Float32}}, Base.RefValue{StepRange{Date, Day}}}
Locals
idxs::KeyedArray{_A, 1, _B, Base.RefValue{StepRange{Date, Day}}} where {_A, _B}
Body::Any
1 ─ %1  = Main.:!::Core.Const(!)
│   %2  = Main.isnan::Core.Const(isnan)
│   %3  = Base.broadcasted(Main.:*, ar, ka)::Base.Broadcast.Broadcasted{AxisKeys.KeyedStyle{NamedDims.NamedDimsStyle{Base.Broadcast.DefaultArrayStyle{1}}}, Nothing, typeof(*), T}
│   %4  = Base.broadcasted(%2, %3)::Base.Broadcast.Broadcasted{AxisKeys.KeyedStyle{NamedDims.NamedDimsStyle{Base.Broadcast.DefaultArrayStyle{1}}}, Nothing, typeof(isnan), NT}
│   %5  = Base.broadcasted(%1, %4)::Base.Broadcast.Broadcasted{AxisKeys.KeyedStyle{NamedDims.NamedDimsStyle{Base.Broadcast.DefaultArrayStyle{1}}}, Nothing, typeof(!), NT}
│         (idxs = Base.materialize(%5))
│   %7  = Main.abs2::Core.Const(abs2)
│   %8  = Main.:-::Core.Const(-)
│   %9  = Base.getindex(ka, idxs)::Any
│   %10 = Base.getindex(ar, idxs)::Any
│   %11 = Base.broadcasted(%8, %9, %10)::Any
│   %12 = Base.broadcasted(%7, %11)::Any
│   %13 = Base.materialize(%12)::Any
└──       return %13
Any ideas or hints that could help solve this issue would be greatly appreciated. Maybe it is something trivial for some of you.
pkg> status Dates ForwardDiff AxisKeys
Status `~/Project.toml`
[94b1ba4f] AxisKeys v0.2.13
[f6369f11] ForwardDiff v0.10.35