Reusing variables with Zygote?

I’ve written a module to perform some physics computations, and have been trying to make it autodifferentiable so I can use it with methods like Hamiltonian Monte Carlo. The computation is specifically about perturbation theory, which means that the values at order n are computed from those at the previous orders, so it is absolutely essential that the code reuses previously computed variables.
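To illustrate the kind of recursion I mean, here is a toy sketch with plain numbers instead of my actual structs. The function name `toy_orders` and the recurrence itself are made up for illustration; only `Zygote.Buffer`, `copy` and `gradient` are real Zygote API.

```julia
using Zygote

# Toy stand-in for a perturbative expansion: "order" k is built from order k-1.
# Hypothetical recurrence: orders[k] = orders[k-1] * x[k].
function toy_orders(x::AbstractVector, n::Integer)
    orders = Zygote.Buffer(x, n)   # write-once storage that Zygote can differentiate through
    orders[1] = x[1]
    for k in 2:n
        orders[k] = orders[k-1] * x[k]   # reuses the previously computed order
    end
    return copy(orders)            # freeze the buffer into a regular array
end

# Sum of the first three orders: x1 + x1*x2 + x1*x2*x3
g = gradient(x -> sum(toy_orders(x, 3)), [1.0, 2.0, 3.0])[1]
# Analytical gradient at this point: [1 + x2 + x2*x3, x1 + x1*x3, x1*x2] = [9.0, 4.0, 2.0]
```

With scalars this pattern is the documented use of `Zygote.Buffer`; my real code does the same thing, but with the custom structs described below as the elements.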

The code defines some custom structs, which are basically wrappers around regular Julia Arrays, methods for any basic operation (sums, products, FFTs) and a main function that actually solves the perturbation theory by composing a bunch of the above basic operations (plus some ancillary functions that are necessary to generate the starting quantities).

After some heavy modifications to make the code avoid using mutating operations, I managed to make all the basic operations work nicely with Zygote, and I get gradients that are pretty much spot on with their analytical value. The only thing left is to make the main function work, but I’ve hit a pretty steep wall…

This is what the main function looks like:

function pt(
    n::Integer,
    δₖL::ScalarField{Complex{T}},
    frequencies::FFTFrequencies{T};
    solver::Function=solve_pt
) where T<:Real
    # initialize delta and theta arrays
    δ = Zygote.Buffer(Vector{ScalarField{T}}(), n) #ScalarField{T}[]
    θ = Zygote.Buffer(Vector{ScalarField{T}}(), n) #ScalarField{T}[]
    𝛁δ = Zygote.Buffer(Vector{VectorField{T}}(), n-1) #VectorField{T}[]
    𝐮 = Zygote.Buffer(Vector{VectorField{T}}(), n-1)#VectorField{T}[]
    
    # solve SPT
    @inbounds for nth in 1:n
        if nth==1
            δL = do_ifft(δₖL)
            δ[nth] = δL # δ = vcat(δ, δL)
            θ[nth] = δL # θ = vcat(θ, δL)
            continue
        elseif nth==2
            𝛁δ[nth-1] = get_gradient(δₖL, frequencies) # 𝛁δ = vcat(𝛁δ, get_gradient(δₖL, frequencies))
            𝐮[nth-1] = get_velocity(δₖL, frequencies) # 𝐮 = vcat(𝐮, get_velocity(δₖL, frequencies))
        else
            𝛁δ[nth-1] = get_gradient(δ[nth-1], frequencies) # 𝛁δ = vcat(𝛁δ, get_gradient(δ[nth-1], frequencies))
            𝐮[nth-1] = get_velocity(θ[nth-1], frequencies) # 𝐮 = vcat(𝐮, get_velocity(θ[nth-1], frequencies))
        end
            
        solver(δ, θ, 𝛁δ, 𝐮, nth, frequencies) # updates δ[nth] and θ[nth] with all the quantities at m<nth
        # δ[nth] = δn # δ = vcat(δ, δn)
        # θ[nth] = θn # θ = vcat(θ, θn)
    end

    return copy(δ), copy(θ), copy(𝛁δ), copy(𝐮)
end

I call this like so in my test function:

# [...]
δ, _, _, _ = gspt.pt(3, WδₖL, frequencies)
model = δ[1]+δ[2]+δ[3]
# [...]

However, this fails with the rather unhelpful error:

Need an adjoint for constructor ScalarFieldCore{Float32}. Gradient is of type Tangent{Any, @NamedTuple{L::ZeroTangent, Ng::ZeroTangent, S::Array{Float32, 3}, is_fourier::ZeroTangent}}

The stack trace points to some operation within solve_pt. Commenting out that operation just reproduces the error at another operation above it, all the way up to the do_ifft call in pt, so the error seems to move upstream through the code.

I tried to define an adjoint for that constructor, but Zygote seems to ignore it.

Interestingly, if I try to define model=δ[i] (i being any element of δ), then the gradient of the test function works perfectly.

I created a minimal version of pt to see where the issue stems from:

function pt2(
    δₖL::ScalarField{Complex{T}}
) where T<:Real
    δ = Zygote.Buffer(Vector{ScalarField{T}}(), 1) #2)
    
    δL = do_ifft(δₖL)
    δ[1] = δL
    # δ[2] = δL

    return copy(δ)
end

This works. However, if I uncomment the δ[2] = δL line (and size the buffer for 2 elements), it fails with a similar error. This seems to suggest that storing the same variable more than once somehow messes up Zygote, but I’m really not sure how to overcome this.
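A struct-free variant of the same test might help tell whether the problem lies in `Zygote.Buffer` itself or in my struct constructors. This is only a sketch: `toy_pt2` and `toy_loss` are hypothetical names, `2 .* x` stands in for `do_ifft(δₖL)`, and plain `Vector{Float64}` replaces `ScalarField{T}`. I have not included its output here.

```julia
using Zygote

# Same shape as pt2, but the buffer holds plain arrays instead of ScalarField.
function toy_pt2(x::Vector{Float64})
    buf = Zygote.Buffer(Vector{Vector{Float64}}(), 2)
    y = 2 .* x      # stand-in for do_ifft(δₖL)
    buf[1] = y
    buf[2] = y      # the same value is stored a second time
    return copy(buf)
end

# Mirrors model = δ[1] + δ[2] from the test function.
function toy_loss(x)
    out = toy_pt2(x)
    return sum(out[1] .+ out[2])
end

g = gradient(toy_loss, [1.0, 2.0])[1]
# toy_loss(x) = sum(4 .* x), so analytically every entry of the gradient is 4.
```

If this version differentiates correctly, the double assignment on its own is not the culprit and the issue is more likely in how gradients of the struct are accumulated.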

The last code box at the end of Limitations · Zygote seems to hint at a solution for this, but I can’t seem to figure it out.