I've started to like the lens-based approach! Now that lenses can change type parameters, we have a "differentiable lens"!
```julia
using ForwardDiff
using Parameters: @with_kw, @unpack
using Setfield: @lens, Lens, set, get
import Setfield

@with_kw struct Coordinate{X, Y, Z}
    x::X = 1.0
    y::Y = 1.0
    z::Z = 1.0
end

# Since `set` called via ForwardDiff passes in a dual number, we have to
# ignore the type parameters when reconstructing a `Coordinate`:
Setfield.constructor_of(::Type{<: Coordinate}) = Coordinate

function f(c::Coordinate)
    @unpack x, y, z = c
    return x^2 + y^2 + z^2
end

# Derivative of `f` with respect to the field focused by the lens `wrt`,
# evaluated at the point `at`:
derivative(f, at, wrt::Lens) =
    ForwardDiff.derivative(
        (x) -> f(set(wrt, at, x)),
        get(wrt, at))

@assert derivative(f, Coordinate(x=1.0), @lens _.x) ≈ 2
@assert derivative(f, Coordinate(x=2.0), @lens _.x) ≈ 4
```
This is actually super useful for building something like a bifurcation analysis tool, where you need a generic way to let the user specify a few parameters in a differentiable manner.
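Since the point is letting the user pick *a few* parameters, here is a minimal sketch of the multi-parameter version: a gradient with respect to several lenses at once via `ForwardDiff.gradient`. To keep it self-contained (only ForwardDiff needed), it uses a hypothetical `FieldLens` type with `lens_get`/`lens_set` helpers as a stand-in for Setfield lenses (these names are illustrative, not Setfield's API), a plain struct instead of `@with_kw`, and an illustrative `gradient_wrt` helper.

```julia
using ForwardDiff

# Hypothetical stand-in for a Setfield lens: focuses on one field, by name.
# (Not Setfield's API; it only mimics get/set for illustration.)
struct FieldLens{name} end

lens_get(::FieldLens{name}, obj) where {name} = getfield(obj, name)

# Return a copy of `obj` with the focused field replaced by `val`.
function lens_set(::FieldLens{name}, obj, val) where {name}
    T = typeof(obj)
    vals = map(n -> n === name ? val : getfield(obj, n), fieldnames(T))
    # Rebuild via the parameter-free type (`Coordinate`, not
    # `Coordinate{Float64,Float64,Float64}`) so a field may become a Dual,
    # the same trick as the `constructor_of` overload above.
    return Base.typename(T).wrapper(vals...)
end

struct Coordinate{X, Y, Z}
    x::X
    y::Y
    z::Z
end

f(c::Coordinate) = c.x^2 + c.y^2 + c.z^2

# Hypothetical helper: gradient of `f` at `at` w.r.t. the fields picked by `lenses`.
function gradient_wrt(f, at, lenses)
    x0 = [lens_get(l, at) for l in lenses]
    # Write each entry of `v` into its field, one lens at a time, then call `f`.
    g(v) = f(foldl((obj, (l, vi)) -> lens_set(l, obj, vi), zip(lenses, v); init = at))
    return ForwardDiff.gradient(g, x0)
end

c = Coordinate(1.0, 2.0, 3.0)
gradient_wrt(f, c, (FieldLens{:x}(), FieldLens{:z}()))  # [2.0, 6.0]
```

Each lens contributes one entry of the gradient, so the user-facing "which parameters?" question reduces to passing a tuple of lenses; the untouched field `y` stays a plain `Float64`.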