Type instability when mixing a Union{Float32, ForwardDiff.Dual} array with a KeyedArray — any hints?

The goal is to compute a scalar from two vector-like objects in a type-stable manner, without type assertions. This is part of a larger workflow that involves `ForwardDiff.gradient`, which runs, but slowly — hence the inquiry :smiley:

using Dates, ForwardDiff, AxisKeys
# Four random Float32 values with a NaN planted in the first slot,
# keyed by four consecutive dates starting today.
r = rand(Float32, 4)
r[1] = NaN32
ka = KeyedArray(r; i = today():(today() + Day(3)))
# NOTE: this eltype is abstract (`ForwardDiff.Dual` without parameters is not
# concrete), which is what the downstream inference results below reflect.
ar = Union{Float32, ForwardDiff.Dual}[1f0, 2f0, ForwardDiff.Dual(NaN32), 2f0]
"""
    idx_f(ar, ka)

Return an element-wise `Bool` mask that is `true` wherever the product
`ar[i] * ka[i]` is not NaN (so a NaN in either input yields `false`).
"""
idx_f(ar, ka) = @. !isnan(ar * ka)
# Global mask of positions where `ar .* ka` is not NaN (see `idx_f`).
idxs = idx_f(ar, ka)

Even this initial function already shows some instabilities:

# Inspect the inferred types of the mask computation.
@code_warntype(idx_f(ar, ka))
Arguments
#self#::Core.Const(idx_f)
ar::Vector{Union{Float32, ForwardDiff.Dual}}
ka::KeyedArray{Float32, 1, NamedDimsArray{(:i,), Float32, 1, Vector{Float32}}, Base.RefValue{StepRange{Date, Day}}}
Body::KeyedArray{_A, 1, _B, Base.RefValue{StepRange{Date, Day}}} where {_A, _B}
1 ─ %1 = Main.:!::Core.Const(!)
│ %2 = Main.isnan::Core.Const(isnan)
│ %3 = Base.broadcasted(Main.:*, ar, ka)::Base.Broadcast.Broadcasted{AxisKeys.KeyedStyle{NamedDims.NamedDimsStyle{Base.Broadcast.DefaultArrayStyle{1}}}, Nothing, typeof(*), T}
│ %4 = Base.broadcasted(%2, %3)::Base.Broadcast.Broadcasted{AxisKeys.KeyedStyle{NamedDims.NamedDimsStyle{Base.Broadcast.DefaultArrayStyle{1}}}, Nothing, typeof(isnan), NT}
│ %5 = Base.broadcasted(%1, %4)::Base.Broadcast.Broadcasted{AxisKeys.KeyedStyle{NamedDims.NamedDimsStyle{Base.Broadcast.DefaultArrayStyle{1}}}, Nothing, typeof(!), NT}
│ %6 = Base.materialize(%5)::KeyedArray{_A, 1, _B, Base.RefValue{StepRange{Date, Day}}} where {_A, _B}
└── return %6

Now let’s say the target function is:

"""
    ka_ar_unstable(ka, ar, idxs)

Select the entries flagged by the mask `idxs` from both `ka` and `ar`, and
return the element-wise squared differences `abs2(ka[i] - ar[i])` of the
selected entries.
"""
function ka_ar_unstable(ka, ar, idxs)
    picked_ka = ka[idxs]
    picked_ar = ar[idxs]
    return @. abs2(picked_ka - picked_ar)
end

# Inspect inference for the three-argument method using the precomputed mask.
@code_warntype(ka_ar_unstable(ka, ar, idxs))
Arguments
#self#::Core.Const(ka_ar_unstable)
ka::KeyedArray{Float32, 1, NamedDimsArray{(:i,), Float32, 1, Vector{Float32}}, Base.RefValue{StepRange{Date, Day}}}
ar::Vector{Union{Float32, ForwardDiff.Dual}}
idxs::KeyedArray{Bool, 1, NamedDimsArray{(:i,), Bool, 1, BitVector}, Base.RefValue{StepRange{Date, Day}}}
Body::KeyedArray{_A, 1, _B, Base.RefValue{Vector{Date}}} where {_A, _B}
1 ─ %1 = Main.abs2::Core.Const(abs2)
│ %2 = Main.:-::Core.Const(-)
│ %3 = Base.getindex(ka, idxs)::KeyedArray{Float32, 1, NamedDimsArray{(:i,), Float32, 1, Vector{Float32}}, Base.RefValue{Vector{Date}}}
│ %4 = Base.getindex(ar, idxs)::Vector{Union{Float32, ForwardDiff.Dual}}
│ %5 = Base.broadcasted(%2, %3, %4)::Base.Broadcast.Broadcasted{AxisKeys.KeyedStyle{NamedDims.NamedDimsStyle{Base.Broadcast.DefaultArrayStyle{1}}}, Nothing, typeof(-), T}
│ %6 = Base.broadcasted(%1, %5)::Base.Broadcast.Broadcasted{AxisKeys.KeyedStyle{NamedDims.NamedDimsStyle{Base.Broadcast.DefaultArrayStyle{1}}}, Nothing, typeof(abs2), NT}
│ %7 = Base.materialize(%6)::KeyedArray{_A, 1, _B, Base.RefValue{Vector{Date}}} where {_A, _B}
└── return %7

And all together:

"""
    all_pack(ar, ka)

Compute the mask `idxs = idx_f(ar, ka)` (positions where `ar .* ka` is not
NaN), then return `sum(ka_ar_unstable(ka, ar, idxs))`. That is the sum of
squared differences between `ka` and `ar` over the unmasked positions.
"""
function all_pack(ar, ka)
    # The mask must be recomputed here because `ar` and `ka` come from an
    # outer loop. The local has to be named `idxs`; a misspelled name would
    # silently pick up an untyped global `idxs` instead of this mask.
    idxs = idx_f(ar, ka)
    vals = ka_ar_unstable(ka, ar, idxs)
    return sum(vals)
end
# Inspect inference for the combined mask + loss computation.
@code_warntype(all_pack(ar, ka))
Arguments
#self#::Core.Const(all_pack)
ar::Vector{Union{Float32, ForwardDiff.Dual}}
ka::KeyedArray{Float32, 1, NamedDimsArray{(:i,), Float32, 1, Vector{Float32}}, Base.RefValue{StepRange{Date, Day}}}
Locals
vals::Any
idxs::KeyedArray{_A, 1, _B, Base.RefValue{StepRange{Date, Day}}} where {_A, _B}
Body::Any
1 ─ (idxs = Main.idx_f(ar, ka))
│ (vals = Main.ka_ar_unstable(ka, ar, idxs))
│ %3 = Main.sum(vals)::Any
└── return %3

Computing the indices inside the function instead leads directly to `Base.getindex` instabilities:

"""
    ka_ar_unstable(ka, ar)

Build a mask of positions where `ar .* ka` is not NaN. Return the
element-wise squared differences `abs2(ka[i] - ar[i])` for the kept positions.
"""
function ka_ar_unstable(ka, ar)
    keep = @. !isnan(ar * ka)
    lhs, rhs = ka[keep], ar[keep]
    return @. abs2(lhs - rhs)
end


# Pass the arguments in signature order `(ka, ar)`. The listing below was
# captured with them swapped, which is why it labels the Vector as `ka`.
@code_warntype ka_ar_unstable(ka, ar)

Arguments
#self#::Core.Const(ka_ar_unstable)
ka::Vector{Union{Float32, ForwardDiff.Dual}}
ar::KeyedArray{Float32, 1, NamedDimsArray{(:i,), Float32, 1, Vector{Float32}}, Base.RefValue{StepRange{Date, Day}}}
Locals
idxs::KeyedArray{_A, 1, _B, Base.RefValue{StepRange{Date, Day}}} where {_A, _B}
Body::Any
1 ─ %1 = Main.:!::Core.Const(!)
│ %2 = Main.isnan::Core.Const(isnan)
│ %3 = Base.broadcasted(Main.:*, ar, ka)::Base.Broadcast.Broadcasted{AxisKeys.KeyedStyle{NamedDims.NamedDimsStyle{Base.Broadcast.DefaultArrayStyle{1}}}, Nothing, typeof(*), T}
│ %4 = Base.broadcasted(%2, %3)::Base.Broadcast.Broadcasted{AxisKeys.KeyedStyle{NamedDims.NamedDimsStyle{Base.Broadcast.DefaultArrayStyle{1}}}, Nothing, typeof(isnan), NT}
│ %5 = Base.broadcasted(%1, %4)::Base.Broadcast.Broadcasted{AxisKeys.KeyedStyle{NamedDims.NamedDimsStyle{Base.Broadcast.DefaultArrayStyle{1}}}, Nothing, typeof(!), NT}
│ (idxs = Base.materialize(%5))
│ %7 = Main.abs2::Core.Const(abs2)
│ %8 = Main.:-::Core.Const(-)
│ %9 = Base.getindex(ka, idxs)::Any
│ %10 = Base.getindex(ar, idxs)::Any
│ %11 = Base.broadcasted(%8, %9, %10)::Any
│ %12 = Base.broadcasted(%7, %11)::Any
│ %13 = Base.materialize(%12)::Any
└── return %13

Any ideas or hints that could help solve this issue would be greatly appreciated :smiley:. Maybe it is something trivial for some :slight_smile:.

pkg> status Dates ForwardDiff AxisKeys
Status `~/Project.toml`
  [94b1ba4f] AxisKeys v0.2.13
  [f6369f11] ForwardDiff v0.10.35

I think the main reason for the instability is that `ForwardDiff.Dual` is not a concrete type, so the element type of your pre-allocated array `ar` cannot be inferred. Note that the type instability is gone when you just define your array as

ar = [1f0,2f0, ForwardDiff.Dual(NaN32), 2f0]

However, I guess you need to pre-allocate the array so it can be reused. One option would be to use PreallocationTools.jl, as in this example: GitHub - SciML/PreallocationTools.jl: Tools for building non-allocating pre-cached functions in Julia, allowing for GC-free usage of automatic differentiation in complex codes.