ReverseDiff.GradientTape fails for backslash operator when a symmetric sparse matrix is involved?

Hi!
I am using ReverseDiff to calculate the gradient of a scalar function with respect to its vector input.
A simplified piece of code is attached at the end. I observe that ReverseDiff.GradientTape(loss_rd, x) fails when K in the example is symmetric, with the (truncated) error message “ERROR: LoadError: MethodError: no method matching lu!(::SparseMatrixCSC{ReverseDiff.TrackedReal{Float64,Float64,Nothing},Int64}, ::Val{true}; check=true)”, but it works when K is nonsymmetric. I am not sure whether I did something wrong in the code. Could someone help me with this? Thanks.

```julia
using ReverseDiff, SparseArrays

function loss_rd(x::AbstractVector{T}) where {T}
    f = ones(T, 3)
    K = sparse(I, J, x)   # 3x3 sparse matrix assembled from the global index vectors I, J
    u = K \ f
    return sum(abs2, u)
end

# Case that does not work (K is symmetric)
I = [1, 1, 2, 3, 3]
J = [1, 3, 2, 3, 1]
x = [3.0, 1.0, 3.0, 3.0, 1.0]

# Case that works fine (K is nonsymmetric)
# I = [1, 1, 2, 3]
# J = [1, 3, 2, 3]
# x = [3.0, 1.0, 3.0, 3.0]

println("loss_rd=", loss_rd(x))

const f_tape = ReverseDiff.GradientTape(loss_rd, x)   # throws the lu! MethodError in the symmetric case
const compiled_f_tape = ReverseDiff.compile(f_tape)
dldx = similar(x)
ReverseDiff.gradient!(dldx, compiled_f_tape, x)
println("dl/dx=", dldx)
```
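One plausible reading of the symmetric/nonsymmetric difference (an assumption, not something confirmed in this thread): the "working" index set happens to produce an upper-triangular K, and in recent SparseArrays versions (worth verifying for yours) `\` on a square sparse matrix first checks `istril`/`istriu` and solves triangular systems directly, so no `lu!` call is needed. The symmetric K is not triangular, so it is routed to a factorization, which would explain the `lu!` MethodError for a `SparseMatrixCSC` of `TrackedReal`. The stdlib-only snippet below just checks the structure of the two matrices from the question; the names `rows_sym`, `K_tri`, etc. are illustrative.

```julia
using LinearAlgebra, SparseArrays

# Index sets from the question, renamed to avoid clashing with LinearAlgebra.I
rows_sym, cols_sym = [1, 1, 2, 3, 3], [1, 3, 2, 3, 1]   # failing case
rows_tri, cols_tri = [1, 1, 2, 3], [1, 3, 2, 3]         # "working" case

K_sym = sparse(rows_sym, cols_sym, [3.0, 1.0, 3.0, 3.0, 1.0])
K_tri = sparse(rows_tri, cols_tri, [3.0, 1.0, 3.0, 3.0])

# Report the structural properties that the sparse `\` inspects
for (name, K) in (("symmetric case", K_sym), ("nonsymmetric case", K_tri))
    println(name, ": issymmetric = ", issymmetric(K),
            ", istriu = ", istriu(K), ", istril = ", istril(K))
end
```

The failing matrix reports `issymmetric = true` with neither triangular flag set, while the "working" one reports `istriu = true`. If that reading is right, a nonsymmetric but non-triangular K would likely fail under ReverseDiff as well.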

ReverseDiff doesn’t support many array/matrix types, which is the issue here. Zygote might handle this case better.
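Whatever AD tool ends up handling the solve, the reverse rule it needs for u = K \ f is short. A sketch, derived by hand here rather than taken from ReverseDiff or Zygote: given ū = ∂L/∂u, solve the adjoint system Kᵀλ = ū; then ∂L/∂f = λ and ∂L/∂K = -λuᵀ. The helper name `solve_pullback` is hypothetical, and the snippet sticks to Float64 and the standard library so the rule can be checked against central finite differences on the symmetric K from the question.

```julia
using LinearAlgebra, SparseArrays

# Hypothetical helper: reverse rule for u = K \ f.
# Given ubar = dL/du, returns (dL/dK, dL/df).
function solve_pullback(K::AbstractMatrix, f::AbstractVector, ubar::AbstractVector)
    u = K \ f
    lam = copy(transpose(K)) \ ubar   # adjoint solve: K' * lam = ubar
    dK = -lam * transpose(u)          # dense outer product: -lam * u'
    return dK, lam
end

# Loss from the question, with K passed in directly (Float64 only)
loss_K(K, f) = sum(abs2, K \ f)

K = sparse([1, 1, 2, 3, 3], [1, 3, 2, 3, 1], [3.0, 1.0, 3.0, 3.0, 1.0])
f = ones(3)
u = K \ f
dK, df = solve_pullback(K, f, 2 .* u)   # d(sum(abs2, u))/du = 2u

# Central finite differences on every entry of K, as a check
h = 1e-6
dK_fd = zeros(3, 3)
for i in 1:3, j in 1:3
    Kp = copy(K); Kp[i, j] += h
    Km = copy(K); Km[i, j] -= h
    dK_fd[i, j] = (loss_K(Kp, f) - loss_K(Km, f)) / (2h)
end
println("max |dK - dK_fd| = ", maximum(abs.(dK - dK_fd)))
```

Only the adjoint solve touches K again, so with Float64 entries it can reuse the same sparse factorization machinery as the forward solve; the trouble in this thread is getting that path to accept tracked element types.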


Thank you for your help, @ChrisRackauckas. I actually started with Zygote and then saw the mutation issue in the original code.
I tried Zygote on the simplified code here and got the error message “ERROR: LoadError: Need an adjoint for constructor SparseMatrixCSC{Float64,Int64}. Gradient is of type Array{Float64,2}.” I guess I need to define an adjoint for sparse(I,J,x)? After reading the custom adjoint section of the Zygote docs, I get the general idea of how to do this for scalar functions, but not for a sparse matrix. Could someone point me to some references? Thanks.
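As for what an adjoint for sparse(I,J,x) has to compute: K[i, j] is the sum of every `x[k]` with `(I[k], J[k]) == (i, j)`, so the pullback only reads the incoming gradient ΔK at each stored coordinate, and each `x[k]` (duplicates included) receives ΔK[I[k], J[k]]. Below is a plain-Julia sketch of that rule under a hypothetical name, `sparse_pullback`; it is not Zygote's API, only the kind of closure a custom adjoint would return for the `x` argument (the index vectors get no gradient). It is checked on a linear test function whose gradient with respect to K is a known matrix `W`, with one coordinate duplicated on purpose.

```julia
using SparseArrays

# Hypothetical helper: pullback of K = sparse(rows, cols, vals) w.r.t. vals.
# Each vals[k] adds into K[rows[k], cols[k]], so its gradient is the
# incoming dK read at that coordinate.
sparse_pullback(rows, cols, dK::AbstractMatrix) =
    [dK[i, j] for (i, j) in zip(rows, cols)]

rows = [1, 1, 2, 3, 3, 3]          # (3, 3) appears twice on purpose
cols = [1, 3, 2, 3, 1, 3]
vals = [3.0, 1.0, 3.0, 1.5, 1.0, 1.5]

# Linear test function g = sum(W .* K), whose gradient w.r.t. K is W
W = [1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0]
g(v) = sum(W .* Matrix(sparse(rows, cols, v)))

grad = sparse_pullback(rows, cols, W)   # [1.0, 3.0, 5.0, 9.0, 7.0, 9.0]

# Central finite differences w.r.t. each entry of vals, as a check
h = 1e-6
grad_fd = map(eachindex(vals)) do k
    e = zeros(length(vals)); e[k] = h
    (g(vals + e) - g(vals - e)) / (2h)
end
println(grad, " vs ", grad_fd)
```

Both duplicated entries receive the full W[3, 3], which is what summing duplicates in `sparse` implies.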

Open an issue on Zygote. Looks like it just needs a constructor adjoint. You define those like https://fluxml.ai/Zygote.jl/dev/adjoints/#Custom-Types-1

Thanks, Chris! I just opened an issue, https://github.com/FluxML/Zygote.jl/issues/742.