ForwardDiff only reads the rules from DiffRules once; it doesn’t monitor for changes. So you probably need to write methods for g(::Dual) yourself.
You’ve preserved the tag (the first type parameter), which is good, but I think you need to allow for multiple partials:
julia> using ForwardDiff: Dual, partials, jacobian
julia> g(x::Dual{Z,T,N}) where {Z,T,N} = Dual{Z}(NaN, ntuple(i -> 5 * partials(x,i), N))
g (generic function with 2 methods)
julia> jacobian(x -> g.(x), [1,2]) # previously an error
2×2 Matrix{Float64}:
 5.0  0.0
 0.0  5.0
If you want to call gradient on f, it must return a scalar, i.e. one Dual number. But x::Vector{fd.Dual{T}} is a vector with an abstract element type, which will not match a vector of concrete Dual numbers; you want f(x::AbstractVector{Dual{Z,T,N}}) where ... = . Defining such rules is unusual, though; normally it is enough to treat whatever scalar functions are called. (If in fact f returns a vector, this method must return a vector of dual numbers.)
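To make the vector case concrete, here is a minimal sketch of such a method. The function f (a sum of squares) and its hand-written rule are made up for illustration and are not from the original question; only Dual, value, partials and gradient come from ForwardDiff.

```julia
using ForwardDiff: Dual, value, partials, gradient

# Hypothetical scalar-valued function, standing in for the real f.
f(x::AbstractVector{<:Real}) = sum(abs2, x)

# Hand-written rule: every element has the same concrete Dual type,
# so the tag Z and the number of partials N are known from the signature.
function f(x::AbstractVector{Dual{Z,T,N}}) where {Z,T,N}
    y = f(value.(x))                                    # primal value
    dy = sum(2 * value(xi) * partials(xi) for xi in x)  # chain rule, one Partials object
    return Dual{Z}(y, dy)                               # a single Dual, same tag
end

gradient(f, [1.0, 2.0])
```

The method returns one Dual carrying all N partials, which is what gradient expects from a scalar function.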
I see. Going back to the first post, just to complete the learning round, how can I define custom rules with DiffRules that can be used with ForwardDiff?
I don’t think it’s easy to use DiffRules to define rules for local use. Maybe ForwardDiff could be modified to know when extras have been defined & update itself, but there is no mechanism now. This loop runs once, I guess when the package is being precompiled:
I will make that PR on DiffRules, seems very easy. ChainRules doesn’t seem to need it. I am misunderstanding something:
using SpecialFunctions
using ForwardDiff: derivative, gradient, partials, value, Dual
SpecialFunctions.gamma(a::Number, t::Dual{T, V, 1} where {V}) where {T} = Dual{T, V, 1}(gamma(a, value(t)), -exp(-value(t)) * value(t)^(a - 1))
julia> derivative(t -> gamma(1, t), 10)
ERROR: MethodError: gamma(::Int64, ::Dual{ForwardDiff.Tag{var"#5#6", Int64}, Int64, 1}) is ambiguous. Candidates:
gamma(a::Integer, x::Number) in SpecialFunctions at /Users/amrods/.julia/packages/SpecialFunctions/tqwrL/src/gamma_inc.jl:1042
gamma(a::Number, t::Dual{T, V, 1} where V) where T in Main at /Users/amrods/test.jl:93
Possible fix, define
gamma(::Integer, ::Dual{T, V, 1} where V) where T
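The usual way out of this kind of ambiguity is to do what the error message suggests and add a method that is more specific than both candidates. Below is a rough sketch on a toy function so that it does not overwrite anything in SpecialFunctions. The name h, its definition and the helper h_dual are all hypothetical; h only mimics a library function that has a method specialised on an Integer first argument.

```julia
using ForwardDiff: Dual, value, partials, derivative

# Hypothetical stand-ins for the library's methods.
h(a::Number, x::Number) = a * sin(x)
h(a::Integer, x::Number) = a * sin(x)

# Hypothetical helper holding the rule in one place: d/dx h(a, x) = a * cos(x).
h_dual(a, t::Dual{T}) where {T} =
    Dual{T}(h(a, value(t)), a * cos(value(t)) * partials(t))

# On its own this method is ambiguous with h(::Integer, ::Number) for h(2, ::Dual) ...
h(a::Number, t::Dual) = h_dual(a, t)
# ... so also define the intersection of the two signatures.
h(a::Integer, t::Dual) = h_dual(a, t)

derivative(t -> h(2, t), 0.0)
```

Leaving the number of partials unconstrained, instead of fixing it to 1, also keeps the rule usable from gradient and jacobian.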