Taking a derivative of a nested object using lenses from Setfield.jl



This is my response to @ChrisRackauckas’s comment in the “New Package PseudoArcLengthContinuation” topic. I don’t want to hijack the announcement topic for @rveltz’s awesome package PseudoArcLengthContinuation.jl, so I’m opening a new topic.


Here is a toy example showing that you can take a derivative w.r.t. a “deep location” in a nested Julia object.

```julia
using Setfield
using Parameters
import ForwardDiff

xyz = (x = 1.0, y = 1.0, z = 1.0)
pair = (xyz, (x = -1.0, y = -1.0, z = -1.0))

function f(xyz)
    @unpack x, y, z = xyz
    return x^2 + y^2 + z^2
end

function g(pair)
    a, b = pair
    return f((x = a.x - b.x,
              y = a.y - b.y,
              z = a.z - b.z))
end

# Basic lens usage:
lens_x = @lens _.x
lens_1 = @lens _[1]
lens_2 = @lens _[2]

@assert pair[1].x ==
    get(pair, @lens _[1].x) ==
    get(pair, lens_1 ∘ lens_x)

@assert set(pair, lens_1 ∘ lens_x, "new value") ==
    ((x = "new value", y = 1.0, z = 1.0), pair[2])

# We can take a derivative w.r.t. a lens:
derivative(f, at, wrt::Lens) =
    ForwardDiff.derivative(
        (x) -> f(set(at, wrt, x)),
        get(at, wrt))

@assert derivative(f, xyz, lens_x) ≈ 2
@assert derivative(f, (@set xyz.x = 2.0), lens_x) ≈ 4

@assert derivative(g, pair, lens_1 ∘ lens_x) ≈ 4
@assert derivative(g, pair, lens_2 ∘ lens_x) ≈ -4
```

The last two examples show that we can take a derivative w.r.t. a nested value. They also show that the “location” of the nested value can be specified by composing lenses, as in `lens_1 ∘ lens_x`.
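To make the composition point a bit more concrete, here is a small sketch with one more level of nesting (a named tuple holding a tuple of named tuples). The object `obj`, the objective `h`, and the lens names are made up for illustration, and the `derivative` helper is copied from above so the block runs on its own. It checks that building a lens piece by piece with `∘` addresses the same value, and gives the same derivative, as writing the whole path in a single `@lens`.

```julia
using Setfield
import ForwardDiff

# Same helper as above: differentiate f along the value addressed by `wrt`.
derivative(f, at, wrt::Lens) =
    ForwardDiff.derivative(x -> f(set(at, wrt, x)), get(at, wrt))

# Hypothetical three-level object.
obj = (params = ((k = 2.0,), (k = 3.0,)),)

# Hypothetical objective depending on both inner values: k1^2 * k2.
h(o) = o.params[1].k^2 * o.params[2].k

# Build the lens to params[2].k piece by piece...
lens_params = @lens _.params
lens_second = @lens _[2]
lens_k = @lens _.k
composed = (lens_params ∘ lens_second) ∘ lens_k

# ...and write the same path in one macro call.
direct = @lens _.params[2].k

@assert get(obj, composed) == get(obj, direct) == 3.0
# dh/dk2 = k1^2 = 4
@assert derivative(h, obj, composed) ≈ derivative(h, obj, direct) ≈ 4.0
```

Swapping `lens_second` for `@lens _[1]` reuses `lens_params` and `lens_k` unchanged and gives dh/dk1 = 2 * k1 * k2 = 12.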

Some detail: I said “virtually any Julia object”, but that was actually an exaggeration (sorry). For this to work, the object must allow the type parameter corresponding to the location specified by the lens to change, because that location temporarily has to store a ForwardDiff `Dual` instead of (say) a `Float64`. Also, for performance, those objects probably have to be immutable to avoid heap allocation.
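A minimal sketch of the type-parameter restriction, assuming two made-up structs: `FixedPoint`, whose field is hard-coded to `Float64`, and `ParamPoint{T}`, whose field type is a type parameter. With the parametric struct, `set` can rebuild the object as `ParamPoint{<:ForwardDiff.Dual}`; with the fixed one, the constructor has to convert the `Dual` to `Float64`, which ForwardDiff does not allow, so the derivative throws. The `derivative` helper is repeated from above.

```julia
using Setfield
import ForwardDiff

derivative(f, at, wrt::Lens) =
    ForwardDiff.derivative(x -> f(set(at, wrt, x)), get(at, wrt))

# Hypothetical struct with a concrete field type: cannot hold a Dual.
struct FixedPoint
    a::Float64
end

# Hypothetical struct with a parametric field type: can hold a Dual.
struct ParamPoint{T}
    a::T
end

sq(p) = p.a^2

# Works: set() rebuilds ParamPoint with a Dual in field `a`.
@assert derivative(sq, ParamPoint(3.0), @lens _.a) ≈ 6.0

# Fails: FixedPoint(dual) would need to convert the Dual to Float64.
fixed_failed = try
    derivative(sq, FixedPoint(3.0), @lens _.a)
    false
catch
    true
end
@assert fixed_failed
```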


Now imagine that you have some ODE models and also a composite one that couples them in some way (type parameters omitted):

```julia
struct ModelA
    alpha
    beta
    ...
end

struct ModelB
    ...
end

struct CoupledModel
    A::ModelA
    B::ModelB
end
```

Then you can specify the bifurcation parameter for `ModelA` as `@lens _.alpha` and for `CoupledModel` as `@lens _.A.alpha`. Functions working with `CoupledModel` can also be agnostic about which bifurcation parameter (axis) of `ModelA` is used: such a function can receive `@lens _.alpha` or `@lens _.beta` and then just prefix it with `@lens _.A`, i.e. form `(@lens _.A) ∘ lens`.
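Here is a runnable sketch of that idea, assuming concrete stand-ins for the models above with the type parameters spelled out (one per field, so `set` can put a `Dual` into one field while the others stay `Float64`). The `output` function and the name `sensitivity_wrt_A` are hypothetical, chosen only for illustration; the point is that `sensitivity_wrt_A` never needs to know which `ModelA` field it is differentiating along.

```julia
using Setfield
import ForwardDiff

derivative(f, at, wrt::Lens) =
    ForwardDiff.derivative(x -> f(set(at, wrt, x)), get(at, wrt))

# Hypothetical concrete versions of the models, one type parameter per field.
struct ModelA{Ta,Tb}
    alpha::Ta
    beta::Tb
end

struct ModelB{Tg}
    gamma::Tg
end

struct CoupledModel{TA,TB}
    A::TA
    B::TB
end

# Hypothetical scalar output of the coupled model.
output(m::CoupledModel) = m.A.alpha * m.B.gamma + m.A.beta^2

# Only knows that `axis` points somewhere inside ModelA;
# it prefixes the path to the `A` component by composition.
sensitivity_wrt_A(f, m::CoupledModel, axis::Lens) =
    derivative(f, m, (@lens _.A) ∘ axis)

m = CoupledModel(ModelA(2.0, 3.0), ModelB(5.0))

@assert get(m, (@lens _.A) ∘ (@lens _.alpha)) == get(m, @lens _.A.alpha)
@assert sensitivity_wrt_A(output, m, @lens _.alpha) ≈ 5.0  # = gamma
@assert sensitivity_wrt_A(output, m, @lens _.beta) ≈ 6.0   # = 2 * beta
```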

There are likely other ways to do similar things, but I find that lenses from Setfield.jl provide a very elegant way to express the bifurcation parameter (or, in general, the “axes” along which you want to change something or compute a derivative).