I'm no Haskell user, but here's an example implementing forward-mode autodiff using a recursive type (as far as I understand the term):
"""
    Dual{T<:Number} <: Number

Dual number representing `value + partial*ε` with `ε^2 = 0`, used for forward-mode
differentiation: `partial` carries the derivative alongside the primal `value`.
`T` is only required to be a `Number`, so `T` may itself be a `Dual`
(e.g. `Dual{Dual{Float64}}`); that nesting is what enables higher-order derivatives.
"""
struct Dual{T<:Number} <: Number  # Notice Dual can hold a Dual!
    value::T    # primal value
    partial::T  # derivative (tangent) coefficient
end
"""
    derivative(f, x)

Return the derivative of `f` at `x` by forward-mode AD: evaluate `f` on
`Dual(x, one(x))` (value `x`, partial seeded with `one(x)`) and read off the partial
part of the result. If `x` is itself a `Dual{T}`, the argument becomes a
`Dual{Dual{T}}` and the result is a `Dual{T}`, which is what makes `f''` work.
If `f` returns a plain (non-`Dual`) number, e.g. a constant function, the derivative
is `zero` of that number rather than a field-access error.
"""
derivative(f, x) = _partial(f(Dual(x, one(x))))

# Extract the derivative part of `f`'s output; a non-`Dual` output carries no
# perturbation, so its derivative is zero.
_partial(y::Dual) = y.partial
_partial(y::Number) = zero(y)
# Make postfix `'` (which lowers to `adjoint`) mean "derivative of" for functions, so
# `f'(x) == derivative(f, x)` and `f''`, `f'''` compose.
# NOTE(review): this is type piracy (neither `adjoint` nor `Function` is owned here),
# so it changes `'` for every function in the session. Acceptable for a demo only.
function Base.adjoint(f::Function)
    return t -> derivative(f, t)
end
# The dual unit ϵ: value 0 and partial 1, so `x + ϵ` seeds a derivative by hand.
# Stored as a `Dual{Bool}` so that mixing it with other numbers promotes to their type.
# (The name was previously mis-encoded as `Ļµ`.)
const ϵ = Dual(false, true)
import Base: convert, promote_rule
# Conversions into `Dual{T}`:
#   * already a `Dual{T}`: returned unchanged;
#   * another `Dual`: both components converted to `T`;
#   * a plain number: lifted as a constant (partial part `zero(T)`).
convert(::Type{Dual{T}}, d::Dual{T}) where {T<:Number} = d

function convert(::Type{Dual{T}}, d::Dual) where {T<:Number}
    v = convert(T, d.value)
    p = convert(T, d.partial)
    return Dual{T}(v, p)
end

function convert(::Type{Dual{T}}, c::Number) where {T<:Number}
    return Dual{T}(convert(T, c), zero(T))
end
# Promotion: combining a `Dual` with another `Dual` or with a plain number yields a
# `Dual` over the promoted element type (nested duals promote recursively).
# The third, mirrored rule is redundant because `promote_type` consults both argument
# orders, but it is harmless.
promote_rule(::Type{Dual{A}}, ::Type{Dual{B}}) where {A<:Number,B<:Number} =
    Dual{promote_type(A, B)}
promote_rule(::Type{Dual{A}}, ::Type{B}) where {A<:Number,B<:Number} =
    Dual{promote_type(A, B)}
promote_rule(::Type{B}, ::Type{Dual{A}}) where {A<:Number,B<:Number} =
    Dual{promote_type(B, A)}
# Construct a `Dual` from two numbers of different types by promoting them to a common
# type; same-typed arguments go straight to the default `Dual(::T, ::T)` constructor.
# Restricted to `Number`: for non-numbers `promote` leaves the arguments unchanged, and
# an unrestricted signature would call itself forever (StackOverflowError). Such calls
# now raise a MethodError instead.
Dual(x::Number, y::Number) = Dual(promote(x, y)...)
Now we can teach various functions how to deal with `Dual`s:
import Base: +, -
# Sum and difference rules: values and partials combine linearly; a plain number
# only shifts the value and leaves the partial untouched.
(+)(x::Dual, y::Dual) = Dual(x.value + y.value, x.partial + y.partial)
(+)(x::Dual, c::Number) = Dual(x.value + c, x.partial)
(+)(c::Number, x::Dual) = x + c                 # addition commutes
(-)(x::Dual) = Dual(-x.value, -x.partial)
(-)(x::Dual, y::Dual) = x + (-y)
(-)(x::Dual, c::Number) = Dual(x.value - c, x.partial)
(-)(c::Number, x::Dual) = Dual(c - x.value, -x.partial)
import Base: *, inv, /
# Product rule: (u + u'ε)(v + v'ε) = uv + (u'v + uv')ε, since ε^2 = 0.
function (*)(x::Dual, y::Dual)
    v = x.value * y.value
    p = (x.partial * y.value) + (x.value * y.partial)
    return Dual(v, p)
end
(*)(x::Dual, c::Number) = Dual(x.value * c, x.partial * c)
(*)(c::Number, x::Dual) = Dual(c * x.value, c * x.partial)

# Quotient via the reciprocal.
(/)(x::Dual, y::Dual) = x * inv(y)

# Reciprocal rule: d(1/v) = -dv / v^2.
function inv(x::Dual)
    v = x.value
    return Dual(1 / v, -x.partial / v^2)
end
import Base: ^, exp, log, sqrt
# Chain rule for elementary functions: f(v + pε) = f(v) + f'(v)·p·ε.

# d/dv exp(v) = exp(v); evaluate exp once and reuse it for the derivative.
function exp(a::Dual)
    ev = exp(a.value)
    return Dual(ev, a.partial * ev)
end

# d/dv log(v) = 1/v.
log(a::Dual) = Dual(log(a.value), a.partial / a.value)

# d/dv sqrt(v) = 1 / (2 sqrt(v)). Uses `sqrt` (the previous text called a mis-encoded,
# undefined `ā`) and evaluates the root once.
function sqrt(x::Dual)
    r = sqrt(x.value)
    return Dual(r, x.partial / (2 * r))
end
# General power via a^b = exp(b log a); inherits `log`'s domain (positive values for reals).
(^)(a::Dual, b::Dual) = exp(b * log(a))

# Integer power rule: d(v^b) = b v^(b-1) dv.
function (^)(a::Dual, b::Integer)
    # d/dv v^0 = 0. Special-casing b == 0 avoids evaluating value^(-1), which throws a
    # DomainError for integer values (e.g. Dual(2, 1)^0) and yields NaN (0 * Inf)
    # when the value is 0.0.
    iszero(b) && return Dual(a.value^b, zero(a.partial))
    return Dual(a.value^b, a.partial * b * a.value^(b - 1))
end
Now we can test it out at the REPL:
julia> f(x) = 2x - 3^x
f (generic function with 1 method)
julia> f'(2.0)
-7.887510598012989
julia> 2 - 3^2 * log(3)
-7.887510598012987
julia> f''(2.0)
-10.86254064731324
julia> - 3^2.0 * log(3)^2
-10.862540647313239
julia> f'''(2.0)
-11.933720641295169
julia> - 3^2.0 * log(3)^3
-11.933720641295167
As you can see, this works for higher-order derivatives. The reason is that we wrote the `Dual` type so that it can store any number, including another `Dual`:
# (Repeated from above.) T is only constrained to Number, so T can itself be a Dual.
struct Dual{T <: Number} <: Number
value::T    # primal value
partial::T  # derivative (tangent) part, same type T as the value
end
and the `derivative` function
derivative(f, x) = f(Dual(x, one(x))).partial  # seed the partial with one(x), read it back
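as creating a dual whose differential part is `one(x)`. If `x` is already a `Dual{T}`, this creates a `Dual{Dual{T}}`, and extracting the partial part yields a `Dual{T}`. This is what makes repeated differentiation such as `f''` and `f'''` work: each application of `'` adds its own level of nesting, so the perturbations of the different levels are kept apart. Note, however, that `Dual` carries no tag recording *which* differentiation a perturbation belongs to. Nesting `derivative` calls over *different* variables that end up at the same nesting depth can therefore still suffer from perturbation confusion. For example, `derivative(x -> x * derivative(y -> x + y, 1), 1)` returns `2` instead of the correct `1`.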