Zygote gradients around closures and objects with custom properties

I have the following function:

function make_automatic_xi(g_b)
    # Return a closure computing ξ = -½ × (Zygote gradient of g_b w.r.t. its
    # first argument). `trajectory`, `tlist`, and `n` are held fixed for each
    # derivative; only the state argument is differentiated.
    function automatic_xi(Ψ, trajectory, tlist, n)
        g_of_state = ϕ -> g_b(ϕ, trajectory, tlist, n)
        ∇g = first(Zygote.gradient(g_of_state, Ψ))
        # Zygote reports `nothing` when the output does not depend on Ψ.
        return isnothing(∇g) ? zero(Ψ) : -0.5 * ∇g
    end
    return automatic_xi
end

Applying this as

function g_b(Ψ, traj, tlist, n)
    # Real part of ⟨Ψ|D|Ψ⟩, with the operator D read from `traj.D`;
    # `tlist` and `n` are accepted for interface compatibility but unused.
    DΨ = traj.D * Ψ
    return real(dot(Ψ, DΨ))
end
# Build the ξ-closure for this g_b, then evaluate it for the first
# state/trajectory pair. NOTE(review): `Ψ`, `trajectories`, and `tlist` are
# assumed to be defined earlier in the session; they are not shown here.
xi_auto = make_automatic_xi(g_b)
ξ1 = xi_auto(Ψ[1], trajectories[1], tlist, 1)

I get `ERROR: UndefRefError: access to undefined reference`.

Using g_b functions that don’t use traj work fine (but aren’t sufficiently general).

It did occur to me while I was writing this that this might be vaguely “problematic” in that I’m calling g_b with varying traj / n parameters in my program, not just varying Ψ — although for a particular derivative, traj, tlist, and n are constants.

I’ve long had other, “similar” code like this:

function make_gate_chi(J_T_U, trajectories; kwargs...)

    # Return a closure that computes χ for a gate functional via Zygote.
    # The outer `trajectories` argument is not used by the closure; the closure
    # receives its own `trajectories` argument on every call.
    function zygote_gate_chi(Ψ, trajectories)
        # Negated gate functional; `kwargs` are forwarded to `J_T_U`.
        _J_T(U) = -J_T_U(U; kwargs...)
        N = length(trajectories)
        χ = Vector{eltype(Ψ)}(undef, N)
        # We assume that the initial states of the trajectories are the
        # logical basis states
        U = [trajectories[i].initial_state ⋅ Ψ[j] for i = 1:N, j = 1:N]
        ∇J = Zygote.gradient(_J_T, U)[1]
        if isnothing(∇J)
            # Zygote returns `nothing` (not a zero array) when J_T_U does not
            # depend on U; indexing into it would throw, so return zero states.
            for k = 1:N
                χ[k] = zero(Ψ[k])
            end
            return χ
        end
        for k = 1:N
            # Generator avoids materializing a temporary array of N states.
            χ[k] = 0.5 * sum(∇J[i, k] * trajectories[i].initial_state for i = 1:N)
        end
        return χ
    end

    return zygote_gate_chi

end

which always seems to work fine. The only difference I see is that whenever I call `zygote_gate_chi`, it is with the same `trajectories` variable, which also happens to be the same `trajectories` variable that was originally passed to `make_gate_chi`. That is “coincidental”, though: `make_gate_chi` only receives `trajectories` because of some API requirements; it's not actually used.

So I can’t quite put my finger on what exactly the difference is from Zygote’s perspective, or what approach I should take to make the original code work.

Apologies for these code snippets not being MWEs, but I was hoping someone might have some insights into Zygote’s internals based on the general structure of the code.

Alright, I was barking up the wrong tree. Sorry for the noise! Turns out Zygote got confused because the type of the trajectory here had a custom getproperty method.

Apparently, I need a custom `ChainRulesCore.rrule`. Here is a full MWE, including the fix:

using Zygote: Zygote
using LinearAlgebra: dot, norm
using Random: rand
using Test

"""
    Trajectory(initial_state, generator; target_state=nothing, weight=1.0, kwargs...)

Container for one trajectory. `weight` is stored as `Float64`; any additional
keyword arguments are collected into the `kwargs` field as a `Dict{Symbol,Any}`.
"""
struct Trajectory{ST,GT}
    initial_state::ST
    generator::GT
    target_state::Union{Nothing,ST}
    weight::Float64
    kwargs::Dict{Symbol,Any}

    function Trajectory(
        initial_state::ST,
        generator::GT;
        target_state::Union{Nothing,ST} = nothing,
        weight = 1.0,
        kwargs...
    ) where {ST,GT}
        # Store the extra keywords as a plain Dict (the field type).
        extra = Dict{Symbol,Any}(kwargs)
        return new{ST,GT}(initial_state, generator, target_state, weight, extra)
    end

end


# Property access for `Trajectory`: the four declared fields are returned
# directly; any other name is looked up among the extra keyword arguments held
# in the `kwargs` field (so `traj.D` works for `Trajectory(Ψ, H; D)`).
# Note that `:kwargs` itself is not in the declared list, so `traj.kwargs` is
# also looked up inside that Dict rather than returning the field.
function Base.getproperty(traj::Trajectory, name::Symbol)
    declared = (:initial_state, :generator, :target_state, :weight)
    name ∈ declared && return getfield(traj, name)
    extra = getfield(traj, :kwargs)
    haskey(extra, name) || error("type Trajectory has no property $name")
    return extra[name]
end


function make_automatic_xi(g_b)
    # Return a closure computing ξ = -½ × (Zygote gradient of g_b w.r.t. its
    # first argument). `trajectory`, `tlist`, and `n` are held fixed for each
    # derivative; only the state argument is differentiated.
    function automatic_xi(Ψ, trajectory, tlist, n)
        g_of_state = ϕ -> g_b(ϕ, trajectory, tlist, n)
        ∇g = first(Zygote.gradient(g_of_state, Ψ))
        # Zygote reports `nothing` when the output does not depend on Ψ.
        return isnothing(∇g) ? zero(Ψ) : -0.5 * ∇g
    end
    return automatic_xi
end

function g_b(Ψ, traj, tlist, n)
    # Real part of ⟨Ψ|D|Ψ⟩, with the operator D read from `traj.D`;
    # `tlist` and `n` are accepted for interface compatibility but unused.
    DΨ = traj.D * Ψ
    return real(dot(Ψ, DΨ))
end


function test()
    N = 4
    A = rand(ComplexF64, N, N)
    D = A * A' / N          # Hermitian by construction
    H = nothing             # generator is not needed for this check
    tlist = [0.0, 1.0]
    Ψ = rand(ComplexF64, N)
    # Normalize in place; `Ψ ./ norm(Ψ)` would compute a copy and discard it.
    Ψ ./= norm(Ψ)
    traj = Trajectory(Ψ, H; D)
    xi = make_automatic_xi(g_b)
    ξ = xi(Ψ, traj, tlist, 1)
    @show ξ
    # For g_b(Ψ) = Re⟨Ψ|D|Ψ⟩ with Hermitian D, the expected gradient is 2DΨ,
    # so ξ = -½ ∇g_b = -DΨ.
    ξ_expected = -D * Ψ
    @test norm(ξ - ξ_expected) < 1e-14
end

# The following definitions are required to fix the problem:

using ChainRulesCore: ChainRulesCore, NoTangent

# Reverse rule for `getproperty` on `Trajectory`, needed because the custom
# `getproperty` serves some names from the `kwargs` Dict. For a declared field,
# the cotangent is placed on that field of the structural `Tangent`. For any
# other name, it is placed in a `Dict` under the `kwargs` field, keyed by `name`.
# `getproperty` itself and the `name` symbol receive `NoTangent()`.
function ChainRulesCore.rrule(::typeof(getproperty), traj::Trajectory, name::Symbol)
    is_field = name in (:initial_state, :generator, :target_state, :weight)
    T = typeof(traj)
    function getproperty_pullback(Δ)
        ∂traj = if is_field
            ChainRulesCore.Tangent{T}(; (name => Δ,)...)
        else
            ChainRulesCore.Tangent{T}(; kwargs = Dict{Symbol,Any}(name => Δ))
        end
        return NoTangent(), ∂traj, NoTangent()
    end
    return getproperty(traj, name), getproperty_pullback
end


# Run the MWE check defined in `test()` (prints ξ and tests it against -DΨ).
test()

Nay for Julia’s impenetrable stacktraces, and yay for Claude being able to go through them to figure out that it was getting hung up on `getproperty`… and, after a lot of back-and-forth, writing an rrule that seems somewhat sensible.

Feels like a bit of a bug in Zygote though. I wonder if Zygote is correctly distinguishing between fields and properties.

P.S.: Followup issue: Zygote gets confused about structs with custom get_properties · Issue #1610 · FluxML/Zygote.jl · GitHub