Custom ForwardDiff rule for ternary function

I’m trying to define a custom ForwardDiff.jl differentiation rule for the bivariate normal cumulative distribution function BvN(h, k, r). What I have so far works for first derivatives, but not for mixed higher-order derivatives. This is my current code:

```julia
import ForwardDiff: Dual, value, partials, derivative, ≺, @define_ternary_dual_op

# BvN and its partial derivatives ∂hBvN, ∂kBvN, ∂rBvN are defined elsewhere (not shown).
@inline function calc_BvN(h, k, r, ::Type{T}) where T
    vh = value(h)
    vk = value(k)
    vr = value(r)
    val = BvN(vh, vk, vr)
    p = ∂hBvN(vh, vk, vr) * partials(h) + ∂kBvN(vh, vk, vr) * partials(k) + ∂rBvN(vh, vk, vr) * partials(r)
    return Dual{T}(val, p)
end

@define_ternary_dual_op(
    BvN,
    calc_BvN(x, y, z, Txyz),
    calc_BvN(x, y, z, Txy),
    calc_BvN(x, y, z, Txz),
    calc_BvN(x, y, z, Tyz),
    calc_BvN(x, y, z, Tx),
    calc_BvN(x, y, z, Ty),
    calc_BvN(x, y, z, Tz),
)

derivative(h -> BvN(h, .3, .5), .6) # gives the correct result
derivative(h -> derivative(k -> BvN(h, k, .4), .3), .6) # gives 0, which is wrong
```
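One way to confirm independently that 0 is wrong is a central finite-difference estimate of the mixed partial ∂²BvN/∂h∂k. This is only a sanity check, not part of the original code: `mixed_partial_fd` and the test function `g` are hypothetical names, and the step size δ = 1e-4 is a heuristic choice (larger steps add truncation error, smaller ones amplify rounding error).

```julia
# Central finite-difference estimate of the mixed partial ∂²f/∂h∂k at (h, k, r).
# δ is a heuristic step size, not a tuned value.
function mixed_partial_fd(f, h, k, r; δ = 1e-4)
    return (f(h + δ, k + δ, r) - f(h + δ, k - δ, r) -
            f(h - δ, k + δ, r) + f(h - δ, k - δ, r)) / (4 * δ^2)
end

# Hypothetical test function with a known mixed partial:
# ∂²g/∂h∂k = (1 + h*k) * exp(h*k) * sin(r)
g(h, k, r) = exp(h * k) * sin(r)

fd_estimate = mixed_partial_fd(g, 0.6, 0.3, 0.4)
analytic = (1 + 0.6 * 0.3) * exp(0.6 * 0.3) * sin(0.4)
println((fd_estimate, analytic))
```

Applied to the real function, `mixed_partial_fd(BvN, 0.6, 0.3, 0.4)` should approximate what the nested `derivative` call above ought to return.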

Can you explain what kind of error you get with higher-order derivatives?

Thanks, @gdalle. No error is thrown, but the higher-order derivatives are wrong (see above). I think the problem is this line:

```julia
p = ∂hBvN(vh, vk, vr) * partials(h) + ∂kBvN(vh, vk, vr) * partials(k) + ∂rBvN(vh, vk, vr) * partials(r)
```

because the perturbations aren’t passed on to the custom partial-derivative functions. However, if I write

```julia
p = ∂hBvN(h, k, r) * partials(h) + ∂kBvN(h, k, r) * partials(k) + ∂rBvN(h, k, r) * partials(r)
```

I get this error instead:

```
julia> derivative(h -> derivative(k -> BvN(h, k, .4), .3), .6)
ERROR: Cannot determine ordering of Dual tags ForwardDiff.Tag{var"#317#319", Float64} and ForwardDiff.Tag{var"#318#320"{Dual{ForwardDiff.Tag{var"#317#319", Float64}, Float64, 1}}, Float64}
Stacktrace:
 [1] partials
   @ ~/.julia/packages/ForwardDiff/UBbGT/src/dual.jl:115 [inlined]
 [2] extract_derivative
   @ ~/.julia/packages/ForwardDiff/UBbGT/src/derivative.jl:87 [inlined]
 [3] derivative(f::var"#317#319", x::Float64)
   @ ForwardDiff ~/.julia/packages/ForwardDiff/UBbGT/src/derivative.jl:14
 [4] top-level scope
   @ REPL[237]:1
```
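To see where the 0 comes from, it helps to look at what a ternary function actually receives in the nested call. The sketch below uses a hypothetical stand-in `probe` (not BvN) that just records its argument types. It suggests that `h` and `k` arrive as two separate first-order `Dual{…, Float64, 1}` numbers with different tags, neither nested inside the other. `calc_BvN` is only called with one of those tags as `T`, and `value` strips any Dual regardless of its tag, so `val` and `p` end up as plain `Float64` quantities and nothing in the result carries the outer tag. That is consistent with the outer derivative seeing a constant and returning 0.

```julia
using ForwardDiff: Dual, derivative

# Hypothetical stand-in for BvN that records the types of its arguments.
const seen = Ref{Any}(nothing)

function probe(h, k, r)
    seen[] = (typeof(h), typeof(k), typeof(r))
    return h * k * r  # mixed partial ∂²/∂h∂k is r
end

# Same nesting as in the question: outer derivative in h, inner in k.
result = derivative(h -> derivative(k -> probe(h, k, 0.4), 0.3), 0.6)

Th, Tk, Tr = seen[]
println(Th)  # a Dual tagged with the outer closure, Float64 value type
println(Tk)  # a Dual tagged with the inner closure, Float64 value type
println(Tr)  # Float64
```

The same picture would also account for the error in the second attempt, although this is an interpretation rather than something checked against the actual ∂hBvN: evaluating ∂hBvN(h, k, r) on the two differently tagged Duals yields a Dual with the inner tag on the outside, and wrapping that in another `Dual{T}` with the same inner tag leaves the inner tag outermost in the value handed back to the outer derivative. ForwardDiff then has to extract outer-tag partials from a number whose outermost tag is the inner one, which matches the "Cannot determine ordering of Dual tags" message.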

I spent some time reading the documentation, other threads here and the source code of ForwardDiff.jl, but unfortunately I couldn’t figure out how to solve this…
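A possible fix, sketched under assumptions rather than tested against the real BvN: make the rule strip only the Dual layer that carries the tag `T` it was called with, and pass every other argument through untouched, including a Dual with a different tag. BvN and its partials are then evaluated on numbers that still carry the outer perturbation, so the result is a properly nested Dual. The plain BvN method can stay `Float64`-only (calls on Duals with other tags re-enter the rule), but ∂hBvN, ∂kBvN and ∂rBvN must accept Dual arguments or have rules of their own. Because BvN and its partials are not shown in the thread, the sketch uses a hypothetical stand-in `toy(h, k, r) = exp(h * k) * sin(r)` with hand-written partials; the helpers `tagvalue`, `scaledpartials` and `addp` are likewise invented here, and `@define_ternary_dual_op` is called exactly as in the original post.

```julia
import ForwardDiff: Dual, value, partials, derivative, ≺, @define_ternary_dual_op

# Hypothetical stand-in for BvN: the plain method only accepts Float64,
# like a black-box implementation would.
toy(h::Float64, k::Float64, r::Float64) = exp(h * k) * sin(r)
# Hand-written partials, written generically so they also accept Duals
# (needed for nested, higher-order derivatives).
∂htoy(h, k, r) = k * exp(h * k) * sin(r)
∂ktoy(h, k, r) = h * exp(h * k) * sin(r)
∂rtoy(h, k, r) = exp(h * k) * cos(r)

# Strip the Dual layer only if it carries tag T; anything else (a plain number
# or a Dual with a different tag) is treated as a constant with respect to T.
@inline tagvalue(::Type{T}, x::Dual{T}) where {T} = value(x)
@inline tagvalue(::Type{T}, x) where {T} = x

# d * (partials of x) if x carries tag T, otherwise `nothing` (no contribution).
@inline scaledpartials(::Type{T}, d, x::Dual{T}) where {T} = d * partials(x)
@inline scaledpartials(::Type{T}, d, x) where {T} = nothing

# Add contributions, skipping the `nothing`s.
@inline addp(a, b) = a + b
@inline addp(a, ::Nothing) = a
@inline addp(::Nothing, b) = b
@inline addp(::Nothing, ::Nothing) = nothing

@inline function calc_toy(h, k, r, ::Type{T}) where {T}
    vh, vk, vr = tagvalue(T, h), tagvalue(T, k), tagvalue(T, r)
    # vh, vk, vr may still be Duals with another (outer) tag, so these calls
    # go back through the rule or through ForwardDiff's built-in rules.
    val = toy(vh, vk, vr)
    p = addp(addp(scaledpartials(T, ∂htoy(vh, vk, vr), h),
                  scaledpartials(T, ∂ktoy(vh, vk, vr), k)),
             scaledpartials(T, ∂rtoy(vh, vk, vr), r))
    return Dual{T}(val, p)
end

@define_ternary_dual_op(
    toy,
    calc_toy(x, y, z, Txyz),
    calc_toy(x, y, z, Txy),
    calc_toy(x, y, z, Txz),
    calc_toy(x, y, z, Tyz),
    calc_toy(x, y, z, Tx),
    calc_toy(x, y, z, Ty),
    calc_toy(x, y, z, Tz),
)

# Mixed second derivative ∂²toy/∂h∂k; analytically (1 + h*k) * exp(h*k) * sin(r).
mixed = derivative(h -> derivative(k -> toy(h, k, 0.4), 0.3), 0.6)
println(mixed)
```

To adapt it, replace `toy` and its partials with BvN, ∂hBvN, ∂kBvN and ∂rBvN. Returning `nothing` for arguments that don't carry tag `T` avoids having to build zero partials of the right length; the price is that all three partials are evaluated even when some of them are not needed.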