ForwardDiff through StatsBase's fit(Histogram, ...)

I have code that is trying to autodifferentiate through the histogram fitting of StatsBase using ForwardDiff, any tips on how to get this to work? I’m hitting TwicePrecision and it’s trying to convert to Float64.

Here’s a MWE:

using ForwardDiff
using StatsBase
import LinearAlgebra: normalize

testparams = [1.3, 0.3]
testsample = params[1] + params[2]*randn()

function pdf(h::Histogram, x::Real)
    xpos = searchsortedfirst(h.edges[1], x)
    return h.weights[xpos]
end

function testloss(p::AbstractVector{T}) where T<:Real
    a, b = p
    x = a .+ randn(T, 10000)
    hist_x = normalize(fit(Histogram, x; nbins=100), mode=:pdf)
    return -log(pdf(hist_x, testsample))
end

ForwardDiff.gradient(testloss, testparams)

Which returns

julia> ForwardDiff.gradient(testloss, testparams)
ERROR: MethodError: no method matching Float64(::Dual{ForwardDiff.Tag{typeof(testloss), Float64}, Float64, 2})
Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat at rounding.jl:200
  (::Type{T})(::T) where T<:Number at boot.jl:772
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number} at char.jl:50
  ...
Stacktrace:
  [1] Base.TwicePrecision{Float64}(x::Dual{ForwardDiff.Tag{typeof(testloss), Float64}, Float64, 2})                                   
    @ Base ./twiceprecision.jl:267
  [2] TwicePrecision
    @ ./twiceprecision.jl:225 [inlined]
  [3] histrange(lo::Dual{ForwardDiff.Tag{typeof(testloss), Float64}, Float64, 2}, hi::Dual{ForwardDiff.Tag{typeof(testloss), Float64}, Float64, 2}, n::Int64, closed::Symbol)         
    @ StatsBase ~/.julia/packages/StatsBase/XgjIN/src/hist.jl:99
  [4] histrange(v::Vector{Dual{ForwardDiff.Tag{typeof(testloss), Float64}, Float64, 2}}, n::
Int64, closed::Symbol)                                                                     
    @ StatsBase ~/.julia/packages/StatsBase/XgjIN/src/hist.jl:39
  [5] #145
    @ ~/.julia/packages/StatsBase/XgjIN/src/hist.jl:105 [inlined]
  [6] map
    @ ./tuple.jl:246 [inlined]
  [7] histrange
    @ ~/.julia/packages/StatsBase/XgjIN/src/hist.jl:104 [inlined]
  [8] fit(::Type{Histogram{Int64}}, vs::Tuple{Vector{Dual{ForwardDiff.Tag{typeof(testloss), Float64}, Float64, 2}}}; closed::Symbol, nbins::Int64)
    @ StatsBase ~/.julia/packages/StatsBase/XgjIN/src/hist.jl:356
  [9] #fit#156
    @ ~/.julia/packages/StatsBase/XgjIN/src/hist.jl:300 [inlined]
 [10] #fit#166
    @ ~/.julia/packages/StatsBase/XgjIN/src/hist.jl:407 [inlined]
 [11] testloss(p::Vector{Dual{ForwardDiff.Tag{typeof(testloss), Float64}, Float64, 2}})
  ...

Here’s the code from StatsBase (at StatsBase.jl/hist.jl at 43880cf37cc2b034a3e3d9aa34a255e34d648e21 · JuliaStats/StatsBase.jl · GitHub)

function histrange(lo::F, hi::F, n::Integer, closed::Symbol=:left) where F
    if hi == lo
        start = F(hi)
        step = one(F)
        divisor = one(F)
        len = one(F)
    else
        bw = (F(hi) - F(lo)) / n
        lbw = log10(bw)
        if lbw >= 0
            step = exp10(floor(lbw))
            r = bw / step
            if r <= 1.1
                nothing
            elseif r <= 2.2
                step *= 2
            elseif r <= 5.5
                step *= 5
            else
                step *= 10
            end
            divisor = one(F)
            start = step*floor(lo/step)
            len = ceil((hi - start)/step)
        else
            divisor = exp10(-floor(lbw))
            r = bw * divisor
            if r <= 1.1
                nothing
            elseif r <= 2.2
                divisor /= 2
            elseif r <= 5.5
                divisor /= 5
            else
                divisor /= 10
            end
            step = one(F)
            start = floor(lo*divisor)
            len = ceil(hi*divisor - start)
        end
    end
    # fix up endpoints
    if closed == :right #(,]
        while lo <= start/divisor
            start -= step
        end
        while (start + (len-1)*step)/divisor < hi
            len += one(F)
        end
    else
        while lo < start/divisor
            start -= step
        end
        while (start + (len-1)*step)/divisor <= hi
            len += one(F)
        end
    end
    StepRangeLen(Base.TwicePrecision{Float64}((start, divisor)),
                 Base.TwicePrecision{Float64}((step, divisor)),
                 Int(len))
end