Help with derivatives of matrix functions

Hm, there might be a way to do it somewhat faster in your special case by rolling your own AD and using @elrod's lazy_e:

using StaticArrays, Random, Test, BenchmarkTools

"""
    Dual{T<:Number, G<:Number} <: Number

Number that stores a value `val::T` together with a gradient part `grad::G`.

- `Dual{T, G}(val)` stores `val` and sets the gradient part to `zero(G)`.
- `Dual{T, G}(val, grad)` stores both arguments; `new` converts them to the field
  types `T` and `G`.
"""
struct Dual{T<:Number, G<:Number} <: Number
    val::T
    grad::G

    function Dual{T, G}(val) where {T<:Number, G<:Number}
        # The default gradient is the zero of the gradient type `G`. Building it from
        # the value type `T` would only work when `zero(T)` is convertible to `G`.
        return new{T, G}(val, zero(G))
    end

    function Dual{T, G}(val, grad) where {T<:Number, G<:Number}
        return new{T, G}(val, grad)
    end
end

# Outer constructors that take the type parameters from the arguments.
# With a single argument the gradient type is the same as the value type.
function Dual(val::T) where {T<:Number}
    return Dual{T, T}(val)
end

function Dual(val::T, grad::G) where {T<:Number, G<:Number}
    return Dual{T, G}(val, grad)
end

# Converting a Dual to its own value type `T` returns the stored value; the gradient
# part is not carried over.
function Base.convert(::Type{T}, d::Dual{T, G}) where {T<:Number, G<:Number}
    return d.val
end

import Base.+
# Dual + Dual: values and gradient parts are added separately.
function Base.:+(
    d1::Dual{T, G1}, d2::Dual{T, G2}
) where {T<:Number, G1<:Number, G2<:Number}
    total = d1.val + d2.val
    slope = d1.grad + d2.grad
    return Dual(total, slope)
end

# Dual + plain value: only the value changes, the gradient part is kept as is.
function Base.:+(d::Dual{T, G}, t::T) where {T<:Number, G<:Number}
    shifted = d.val + t
    return Dual(shifted, d.grad)
end

# Plain value + Dual: forwarded to the method above with the arguments swapped.
function Base.:+(t::T, d::Dual{T, G}) where {T<:Number, G<:Number}
    return d + t
end

import Base.-
# Unary minus: negate both the value and the gradient part.
function Base.:-(d::Dual{T, G}) where {T<:Number, G<:Number}
    return Dual(-d.val, -d.grad)
end

# Dual - Dual: expressed as addition of the negated right-hand side.
function Base.:-(
    d1::Dual{T, G1}, d2::Dual{T, G2}
) where {T<:Number, G1<:Number, G2<:Number}
    negated = -d2
    return d1 + negated
end

# Dual - plain value: add the negated plain value.
function Base.:-(d::Dual{T, G}, t::T) where {T<:Number, G<:Number}
    negated = -t
    return d + negated
end

# Plain value - Dual: add the negated Dual.
function Base.:-(t::T, d::Dual{T, G}) where {T<:Number, G<:Number}
    negated = -d
    return t + negated
end

import Base.*
# Dual * Dual: the values are multiplied; the gradient part follows the product rule.
function Base.:*(
    d1::Dual{T, G1}, d2::Dual{T, G2}
) where {T<:Number, G1<:Number, G2<:Number}
    value = d1.val * d2.val
    slope = d1.val * d2.grad + d2.val * d1.grad
    return Dual(value, slope)
end

# Dual * plain value: value and gradient part are both scaled by `t`.
function Base.:*(d::Dual{T, G}, t::T) where {T<:Number, G<:Number}
    value = d.val * t
    slope = t * d.grad
    return Dual(value, slope)
end

# Plain value * Dual: forwarded to the method above with the arguments swapped.
function Base.:*(t::T, d::Dual{T, G}) where {T<:Number, G<:Number}
    return d * t
end

import Base.inv
# Reciprocal of a Dual: the value is 1 / val and the gradient part is -grad / val^2.
function Base.inv(d::Dual{T, G}) where {T<:Number, G<:Number}
    squared = d.val * d.val
    value = one(T) / d.val
    slope = -d.grad / squared
    return Dual(value, slope)
end

import Base./
# Division is computed directly from the fields instead of as `x * inv(y)`.
# Going through `inv` changes the value type for integer-valued duals (`inv(2)` is a
# Float64), and the `*` methods for Dual require both operands to share the same value
# type `T`, so the product form failed for e.g. `Dual(1, 1) / 2`.

# Dual / Dual: quotient rule, (g1 * v2 - v1 * g2) / v2^2 for the gradient part.
function Base.:/(
    d1::Dual{T, G1}, d2::Dual{T, G2}
) where {T<:Number, G1<:Number, G2<:Number}
    value = d1.val / d2.val
    slope = (d1.grad * d2.val - d1.val * d2.grad) / (d2.val * d2.val)
    return Dual(value, slope)
end

# Dual / plain value: value and gradient part are both divided by `t`.
function Base.:/(d::Dual{T, G}, t::T) where {T<:Number, G<:Number}
    return Dual(d.val / t, d.grad / t)
end

# Plain value / Dual: the gradient part is -t * grad / val^2.
function Base.:/(t::T, d::Dual{T, G}) where {T<:Number, G<:Number}
    value = t / d.val
    slope = -(t * d.grad) / (d.val * d.val)
    return Dual(value, slope)
end

"""
    lazy_e{T} <: AbstractVector{T}

Vector of length `n` whose entries are computed on demand: `one(T)` at position `i`
and `zero(T)` at every other position. Only the two integers `i` and `n` are stored.
"""
struct lazy_e{T} <: AbstractVector{T}
    i::Int
    n::Int
end

# Convenience constructor; the element type defaults to `Int`.
lazy_e(i::Int, n::Int, ::Type{T}=Int) where {T} = lazy_e{T}(i, n)

# Scalar indexing. The index is typed as `Int` so that ranges, colons and other index
# types go through the generic AbstractArray fallbacks instead of landing here.
# Out-of-range indices (below 1 or above `n`) throw a `BoundsError`.
@inline function Base.getindex(e::lazy_e{T}, i::Int) where {T}
    @boundscheck checkbounds(e, i)
    return ifelse(i == e.i, one(T), zero(T))
end

Base.size(e::lazy_e) = (e.n,)

# Derivative of `f` at `x`: evaluate `f` on a Dual whose gradient part is seeded with
# `one(T)` and read the gradient part of the result.
function derivative(f, x::T) where {T}
    seeded = Dual(x, one(T))
    return f(seeded).grad
end
# Partial derivative of `f` with respect to component `i` of `x`: every component is
# wrapped in a Dual whose gradient part comes from the unit vector `lazy_e{T}(i, N)`,
# then the gradient part of `f`'s result is returned.
function partial(f, i, x::SVector{N, T}) where {N, T}
    seed = lazy_e{T}(i, N)
    lifted = map((xk, ek) -> Dual(xk, ek), x, seed)
    return f(lifted).grad
end

# Returns the function x -> x[i]^2 * x[j]^2 for the given component indices.
function p(i, j)
    monomial(x) = x[i]^2 * x[j]^2
    return monomial
end
# Returns a function computing the partial derivative of `p(i, j)` with respect to
# component `k`.
function dp(i, j, k)
    return function (x)
        monomial = p(i, j)
        return partial(monomial, k, x)
    end
end

# Returns a function computing the partial derivative of `dp(i, j, k)` with respect to
# component `l`, i.e. a second-order partial derivative of `p(i, j)`.
function d2p(i, j, k, l)
    return function (x)
        first_order = dp(i, j, k)
        return partial(first_order, l, x)
    end
end

# Time the first partial derivative of x[1]^2 * x[1]^2 with respect to x[1],
# evaluated at (3.0, 2.0, 1.0).
# NOTE(review): the argument is a literal written inside the benchmarked expression
# rather than a `$`-interpolated value, so the compiler may be able to fold part of the
# work at compile time — confirm before comparing these timings with other approaches.
@btime dp(1, 1, 1)(SVector(3.0, 2.0, 1.0))
# Same measurement for the second partial derivative (both with respect to x[1]).
@btime d2p(1, 1, 1, 1)(SVector(3.0, 2.0, 1.0))

yielding

  1.200 ns (0 allocations: 0 bytes)
  1.300 ns (0 allocations: 0 bytes)

which looks a bit like compile time AD to me :wink:

1 Like