How to define a custom AD rule in ForwardDiff

How do I define a custom AD rule in ForwardDiff using DiffRules? I have tried this:

using DiffRules
using ForwardDiff

f(xx) = NaN

DiffRules.@define_diffrule Main.f(xx) = :($xx[1] + $xx[2], $xx[1] - $xx[2])

julia> ForwardDiff.gradient(f, [1, 20])
2-element Vector{Float64}:
 0.0
 0.0

As you can see, it doesn’t pick up the rule. This one doesn’t work either:

g(x) = NaN

DiffRules.@define_diffrule Main.g(x) = :(5)

julia> ForwardDiff.derivative(g, 1)
0.0

ForwardDiff reads the rules from DiffRules only once; it doesn’t monitor for changes. So you probably need to write methods for g(::ForwardDiff.Dual) yourself.
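
A minimal sketch of what that means in practice, assuming current versions of DiffRules and ForwardDiff: the rule does land in DiffRules' rule table (checked here with `DiffRules.hasdiffrule`), but, as the reply says, ForwardDiff built its `Dual` methods from that table earlier, so `g` still has no `Dual` method and the derivative comes out as zero, as in the output above.

```julia
using DiffRules
using ForwardDiff

g(x) = NaN

# Register the rule g'(x) = 5 in DiffRules' global rule table.
DiffRules.@define_diffrule Main.g(x) = :(5)

# The rule is stored in DiffRules...
println(DiffRules.hasdiffrule(:Main, :g, 1))

# ...but no g(::ForwardDiff.Dual) method was generated from it, so g returns a
# plain NaN (not a Dual) and ForwardDiff reports a derivative of zero.
println(ForwardDiff.derivative(g, 1.0))
```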

Should I load DiffRules and define the rule before loading ForwardDiff? I’m trying to understand how it works.
EDIT: That doesn’t work either.

This also doesn’t work:

using ForwardDiff
using DualNumbers

g(x) = NaN
g(::Dual) = Dual(NaN, 5)

julia> ForwardDiff.derivative(g, 1)
0.0

Can you provide a short example?

Got it to work in the scalar case:

using ForwardDiff
const fd = ForwardDiff

g(x) = NaN
Main.g(::fd.Dual{T}) where {T} = fd.Dual{T}(NaN, 5)

julia> fd.derivative(g, 1)
5.0

But I’m having trouble defining custom gradients:

f(x) = x[1] + x[2]
Main.f(x::Vector{fd.Dual{T}}) where {T} = fd.Dual{T}(f(value(x)), [50., 50.])

julia> fd.gradient(f, [1, 1])
2-element Vector{Int64}:
 1
 1

You’ve preserved the tag (the first type parameter), which is good, but I think you need to allow for multiple partials:

julia> using ForwardDiff: Dual, partials, jacobian

julia> g(x::Dual{Z,T,N}) where {Z,T,N} = Dual{Z}(NaN, ntuple(i -> 5 * partials(x,i), N))
g (generic function with 2 methods)

julia> jacobian(x -> g.(x), [1,2])  # previously an error
2×2 Matrix{Float64}:
 5.0  0.0
 0.0  5.0
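
The reason to build the result from `partials(x, i)` rather than from a constant is the chain rule: the incoming partials carry the derivative of whatever was computed before `g`. A minimal comparison sketch, with `g'(x) = 5` as in the thread; the names `g_const` and `g_chain` are hypothetical.

```julia
using ForwardDiff
using ForwardDiff: Dual, partials, derivative

g_const(x) = NaN
g_chain(x) = NaN

# Hard-codes the partial to 5: only correct when g is applied directly to the
# input variable and there is a single partial (N == 1).
g_const(x::Dual{Z}) where {Z} = Dual{Z}(NaN, 5.0)

# Scales each incoming partial by g'(x) = 5, so it composes with other
# functions and works for any number of partials N.
g_chain(x::Dual{Z,T,N}) where {Z,T,N} = Dual{Z}(NaN, ntuple(i -> 5 * partials(x, i), N))

# d/dx g(3x) should be 5 * 3 = 15.
println(derivative(x -> g_const(3x), 1.0))  # 5.0, chain rule lost
println(derivative(x -> g_chain(3x), 1.0))  # 15.0
```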

If you want to call gradient on f, it must return a scalar: a single Dual number. But x::Vector{fd.Dual{T}} only matches a vector whose element type is the abstract type fd.Dual{T}; you want something like f(x::AbstractVector{Dual{Z,T,N}}) where ... = ... That said, defining such rules is unusual; normally it is enough to handle whichever scalar functions f calls. (If f in fact returns a vector, this method must return a vector of Dual numbers.)
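
A minimal sketch of such a vector-argument rule, assuming a two-element input and the hand-written gradient `(50, 50)` from the question. The output partials are built by the chain rule from each input's partials, so the rule should stay correct even if ForwardDiff evaluates the gradient in several chunks (each call then carries only some of the partials).

```julia
using ForwardDiff
using ForwardDiff: Dual, value, partials, gradient

f(x) = x[1] + x[2]

# Custom rule for a vector of Duals. The element type Dual{Z,T,N} is concrete,
# so AbstractVector{Dual{Z,T,N}} matches ForwardDiff's dual vector, whereas
# Vector{Dual{T}} (abstract element type) does not. Returns a single Dual.
function f(x::AbstractVector{Dual{Z,T,N}}) where {Z,T,N}
    y = f(value.(x))              # primal value from the plain method
    dfdx = (50.0, 50.0)           # hypothetical hand-written gradient
    # Chain rule: output partials = sum_i dfdx[i] * partials(x[i])
    p = dfdx[1] * partials(x[1]) + dfdx[2] * partials(x[2])
    return Dual{Z}(y, p)
end

println(gradient(f, [1.0, 1.0]))  # [50.0, 50.0]
```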

I see. Going back to the first post, just to round out my understanding: how can I define custom rules with DiffRules that can be used with ForwardDiff?

Perhaps more importantly for my application, how can I define custom rules for functions of several separate variables? For example:

using SpecialFunctions
using ForwardDiff: derivative, gradient, partials, value, Dual

SpecialFunctions.gamma(a, t::Dual{T}) where {T} = Dual{T}(gamma(a, value(t)), -exp(-value(t)) * value(t)^(a - 1))
SpecialFunctions.gamma(a::Dual{T}, t) where {T} = Dual{T}(gamma(value(a), t), NaN)

Is that how?

Yes, roughly, but you will probably need a method for (Dual, Dual) too.
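
A minimal sketch of the three methods for a two-argument function, using a hypothetical `myfun(a, t)` with hand-written partials `dfda` and `dfdt` (all three names are illustrative, not from any package). Each method multiplies the relevant partials by the known derivative, and the `(Dual, Dual)` method sums both contributions. It assumes both Duals carry the same tag; mixing tags (nested differentiation) would need more care.

```julia
using ForwardDiff
using ForwardDiff: Dual, value, partials, derivative, gradient

# Hypothetical function and its hand-written partial derivatives.
myfun(a, t) = a * t^2
dfda(a, t) = t^2
dfdt(a, t) = 2 * a * t

# Only t is being differentiated.
myfun(a::Real, t::Dual{T}) where {T} =
    Dual{T}(myfun(a, value(t)), dfdt(a, value(t)) * partials(t))

# Only a is being differentiated.
myfun(a::Dual{T}, t::Real) where {T} =
    Dual{T}(myfun(value(a), t), dfda(value(a), t) * partials(a))

# Both carry the same tag: sum the two chain-rule contributions.
# Without this method, calls with two Duals would be ambiguous.
function myfun(a::Dual{T}, t::Dual{T}) where {T}
    va, vt = value(a), value(t)
    return Dual{T}(myfun(va, vt), dfda(va, vt) * partials(a) + dfdt(va, vt) * partials(t))
end

println(derivative(t -> myfun(2.0, t), 3.0))           # 12.0
println(derivative(a -> myfun(a, 3.0), 2.0))           # 9.0
println(gradient(x -> myfun(x[1], x[2]), [2.0, 3.0]))  # [9.0, 12.0]
```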

For SpecialFunctions, it would be ideal to make a pull request to DiffRules. There are some rules here:
https://github.com/JuliaDiff/DiffRules.jl/blob/master/src/rules.jl#L65
And some here, although not for ForwardDiff:
https://github.com/JuliaMath/SpecialFunctions.jl/blob/master/src/chainrules.jl#L40
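
For reference, a DiffRules rule for a two-argument function is written with a tuple on the right-hand side, one quoted expression per argument, with `$` interpolating the argument names; this is the form DiffRules uses for its own two-argument rules. A minimal sketch with a hypothetical local function `myfun2`, querying the stored rule with `DiffRules.diffrule`:

```julia
using DiffRules

# Hypothetical two-argument function, for illustration only.
myfun2(a, t) = a * t^2

# Two arguments -> the right-hand side is a tuple (∂/∂a, ∂/∂t) of quoted
# expressions; $a and $t are replaced by whatever names the caller supplies.
DiffRules.@define_diffrule Main.myfun2(a, t) = :($t^2), :(2 * $a * $t)

println(DiffRules.hasdiffrule(:Main, :myfun2, 2))

# Instantiate the rule for argument names u and v.
da, dt = DiffRules.diffrule(:Main, :myfun2, :u, :v)
println(da)
println(dt)
```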

I don’t think it’s easy to use DiffRules to define rules for local use. Maybe ForwardDiff could be modified to detect when extra rules have been defined and update itself, but there is no such mechanism now. This loop runs only once, I guess when the package is being precompiled:

https://github.com/JuliaDiff/ForwardDiff.jl/blob/master/src/dual.jl#L390
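
One possible workaround, sketched under the assumption that you only need a few local one-argument rules: store the rule in DiffRules as usual, then generate the `Dual` method yourself with `@eval`, in the spirit of that loop. `h` is a hypothetical function; this does not hook into ForwardDiff's own code generation, it just writes an ordinary method from the stored rule.

```julia
using DiffRules
using ForwardDiff
using ForwardDiff: Dual, value, partials

h(x) = NaN  # primal whose body AD should not see

# Custom rule h'(x) = 3x^2, stored in DiffRules' table.
DiffRules.@define_diffrule Main.h(x) = :(3 * $x^2)

# Build the derivative expression in terms of a local variable `v` and splice
# it into a hand-written Dual method (chain rule: scale the incoming partials).
dexpr = DiffRules.diffrule(:Main, :h, :v)
@eval function h(d::Dual{T}) where {T}
    v = value(d)
    return Dual{T}(h(v), $dexpr * partials(d))
end

println(ForwardDiff.derivative(h, 2.0))  # 12.0
```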

I will make that PR to DiffRules; it seems very easy. ChainRules doesn’t seem to need it. But I am misunderstanding something:

using SpecialFunctions
using ForwardDiff: derivative, gradient, partials, value, Dual

SpecialFunctions.gamma(a::Number, t::Dual{T, V, 1} where {V}) where {T} = Dual{T, V, 1}(gamma(a, value(t)), -exp(-value(t)) * value(t)^(a - 1))

julia> derivative(t -> gamma(1, t), 10)
ERROR: MethodError: gamma(::Int64, ::Dual{ForwardDiff.Tag{var"#5#6", Int64}, Int64, 1}) is ambiguous. Candidates:
  gamma(a::Integer, x::Number) in SpecialFunctions at /Users/amrods/.julia/packages/SpecialFunctions/tqwrL/src/gamma_inc.jl:1042
  gamma(a::Number, t::Dual{T, V, 1} where V) where T in Main at /Users/amrods/test.jl:93
Possible fix, define
  gamma(::Integer, ::Dual{T, V, 1} where V) where T
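
The error points to the usual fix: also define the method on the intersection of the two signatures. A self-contained sketch using a hypothetical stand-in `q` that, like `SpecialFunctions.gamma` in the error above, already has an `Integer`-specialised method; the helper `_q_dual` is also hypothetical and keeps the two `Dual` methods from duplicating the rule. Unlike the rule above, it multiplies by `partials(x)`, and it does not refer to `V`, which is bound only inside the argument's type annotation there and so is not available in the method body.

```julia
using ForwardDiff
using ForwardDiff: Dual, value, partials, derivative

# Hypothetical stand-in for a library function with an Integer-specialised method.
q(a::Number, x::Number) = exp(-x) * x^a
q(a::Integer, x::Number) = exp(-x) * x^a   # e.g. a "fast path" for integer a

# Hand-written derivative with respect to x.
dqdx(a, x) = exp(-x) * x^(a - 1) * (a - x)

# Shared rule: chain rule on the incoming partials of x.
_q_dual(a, x::Dual{T}) where {T} =
    Dual{T}(q(a, value(x)), dqdx(a, value(x)) * partials(x))

# q(::Number, ::Dual) alone would be ambiguous with q(::Integer, ::Number)
# for calls like q(2, dual); the Integer method resolves that.
q(a::Number, x::Dual) = _q_dual(a, x)
q(a::Integer, x::Dual) = _q_dual(a, x)

println(derivative(x -> q(2, x), 1.0) ≈ exp(-1.0))  # true
```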