BoundsError in Zygote custom adjoint

I’m learning how to define a custom adjoint in Zygote for a callable type, and I get a BoundsError when testing the gradient for my example case. The code is below (note that the adjoint definition is probably mathematically wrong, but the dimensions should be correct):

using LinearAlgebra: I, Tridiagonal
using Zygote: gradient, @adjoint

struct MaternSPDE{T<:Integer, U<:Real, S<:AbstractMatrix}
    d::T # Dimension of SPDE
    h::U # Discretization length
    D::S # Laplacian differential operator
end

function (m::MaternSPDE)(l::Real, σ::Real, w::AbstractVector)
    return (I - l^2 * m.D / m.h^2) \ (σ * sqrt(l^m.d) * w / m.h^m.d)
end

@adjoint function (m::MaternSPDE)(l::Real, σ::Real, w::AbstractVector)
    L = (I - l^2 * m.D / m.h^2)
    ld = l^m.d
    s = sqrt(ld) / m.h^m.d
    u1 = L \ (s * w)
    dudl = (-2 * l * m.D) \ (m.d / 2 * s * w / ld)
    # back function returns one entry each for l, σ and w
    return σ * u1, Δ -> (dudl' * Δ, u1' * Δ, L' \ Δ / s)
end

spde = MaternSPDE(1, 1.0, Tridiagonal(ones(100), -2*ones(101), ones(100)))
f(x,y,z) = sum(spde(x,y,z))
w = ones(101)
f(1,1,w) #works...

However, calling gradient on f with the same arguments throws a BoundsError:

  at index [4]
 [1] getindex at ./tuple.jl:24 [inlined]
 [2] gradindex(::Tuple{Float64,Float64,Array{Float64,1}}, ::Int64) at /Users/jackbmuir/.julia/packages/Zygote/bdE6T/src/compiler/reverse.jl:13
 [3] f at ./REPL[8]:1 [inlined]
 [4] (::typeof(∂(f)))(::Float64) at /Users/jackbmuir/.julia/packages/Zygote/bdE6T/src/compiler/interface2.jl:0
 [5] (::getfield(Zygote, Symbol("##32#33")){typeof(∂(f))})(::Float64) at /Users/jackbmuir/.julia/packages/Zygote/bdE6T/src/compiler/interface.jl:38
 [6] gradient(::Function, ::Int64, ::Vararg{Any,N} where N) at /Users/jackbmuir/.julia/packages/Zygote/bdE6T/src/compiler/interface.jl:47
 [7] top-level scope at REPL[10]:1

Looking at the error and at frame [2] of the stacktrace, it seems that Zygote expects a tuple with at least 4 elements, whereas the back function of my custom adjoint only returns 3. Any ideas on why this is happening and how to fix it?

A simpler equivalent test case shows that, for a callable struct, Zygote seems to expect the first element of the tuple returned by the back function to correspond to the struct itself. Putting e.g. nothing in that position fixes the problem, i.e. the below works:

struct TS
    z # scalar multiplier, accessed as t.z below
end

function (t::TS)(x,y)
    return t.z .* x .* y
end

@adjoint function (t::TS)(x,y)
    # first entry is the gradient for t itself, then x and y
    return t(x,y), Δ -> (nothing, t.z .* Δ .* y, t.z .* x .* Δ)
end

t = TS(5)
gradient((x,y)->sum(t(x,y)), ones(100), ones(100))
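
If the gradient with respect to the struct's own fields is also needed, my understanding is that Zygote accepts a NamedTuple keyed by field name in that first slot instead of nothing. Below is a minimal sketch of that idea. The Scale type and the variable names are hypothetical (they are not from the code above), and I am assuming the field is a plain Float64 scalar:

```julia
using Zygote: gradient, @adjoint

# Hypothetical callable struct with one scalar field (illustration only).
struct Scale
    z::Float64
end

(s::Scale)(x, y) = s.z .* x .* y

@adjoint function (s::Scale)(x, y)
    out = s(x, y)
    function back(Δ)
        # Slot 1: gradient for the struct itself, as a NamedTuple over its fields.
        ds = (z = sum(Δ .* x .* y),)
        # Slots 2 and 3: gradients for the positional arguments x and y.
        return (ds, s.z .* Δ .* y, s.z .* x .* Δ)
    end
    return out, back
end

s = Scale(5.0)
x, y = ones(3), 2 .* ones(3)

# Differentiate with respect to the struct as well as x and y.
gs, gx, gy = gradient((s, x, y) -> sum(s(x, y)), s, x, y)
```

Here gs should be the NamedTuple produced by the back function, so gs.z holds the derivative with respect to the field z, while gx and gy are ordinary arrays. Returning nothing in the first slot, as in the TS example, is the simpler choice when the struct only carries constants.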
