I’m learning how to define a custom adjoint in Zygote for a type and get a BoundsError testing the gradient for my example case. Code is as below (note that the adjoint definition is probably wrong, but the dimensions should be correct)
using LinearAlgebra: I, Tridiagonal
using Zygote: gradient, @adjoint
struct MaternSPDE{T<:Integer, U<:Real, S<:AbstractMatrix}
d::T #Dimension of SPDE
h::U #Discretization length
D::S #Laplacian Differential Operator
end
function (m::MaternSPDE)(l::Real, σ::Real, w::AbstractVector)
return (I - l^2 * m.D / m.h^2) \ (σ * sqrt(l^m.d) * w / m.h^m.d)
end
@adjoint function (m::MaternSPDE)(l::Real, σ::Real, w::AbstractVector)
L = (I - l^2 * m.D / m.h^2)
ld = l^m.d
s = sqrt(ld) / m.h^m.d
u1 = L \ (s * w)
dudl = (-2 * l * m.D) \ (m.d / 2 * s * w / ld)
return σ * u1, Δ -> (dudl' * Δ, u1' * Δ, L' \ Δ / s)
end
spde = MaternSPDE(1, 1.0, Tridiagonal(ones(100), -2*ones(101), ones(100)))
f(x,y,z) = sum(spde(x,y,z))
w = ones(101)
f(1,1,w) #works...
gradient(f,1,1,w)
which gives
at index [4]
Stacktrace:
[1] getindex at ./tuple.jl:24 [inlined]
[2] gradindex(::Tuple{Float64,Float64,Array{Float64,1}}, ::Int64) at /Users/jackbmuir/.julia/packages/Zygote/bdE6T/src/compiler/reverse.jl:13
[3] f at ./REPL[8]:1 [inlined]
[4] (::typeof(∂(f)))(::Float64) at /Users/jackbmuir/.julia/packages/Zygote/bdE6T/src/compiler/interface2.jl:0
[5] (::getfield(Zygote, Symbol("##32#33")){typeof(∂(f))})(::Float64) at /Users/jackbmuir/.julia/packages/Zygote/bdE6T/src/compiler/interface.jl:38
[6] gradient(::Function, ::Int64, ::Vararg{Any,N} where N) at /Users/jackbmuir/.julia/packages/Zygote/bdE6T/src/compiler/interface.jl:47
[7] top-level scope at REPL[10]:1
Looking at the error and stacktrace [2], it seems like it is expecting a tuple with at least 4 fields, when my custom adjoint only defines 3 for the backpropagation function. Any ideas on why this is happening and how to fix it?