Automatic differentiation, adjoint methods, and linear systems of equations

I would like to differentiate computations along the lines of what is described in section 2 of these notes, where the linear systems are sparse. For several years, I have occasionally checked Julia’s AD ecosystem for a package that can do this sort of differentiation automatically, but I haven’t had success yet. The code below shows a couple of my attempts to get Zygote to do this for a very simple linear system. Am I doing something wrong, or is Zygote not ready for this yet? Is there another package that can automatically differentiate the code below? I am especially interested in the g2 example, which is more representative of the real problems I am working with.

import LinearAlgebra
import SparseArrays
import Zygote

function g1(p)
	n = 3
	A = LinearAlgebra.diagm(-1=>fill(-p, n - 1), 0=>fill(2p, n), 1=>fill(-p, n - 1))
	b = ones(n)
	return (A \ b)[1]
end

@show g1(1.0)
@show g1'(1.0) # fails with the error "Mutating arrays is not supported"

function g2(p::T) where {T}
	n = 3
	I = Int[]
	J = Int[]
	V = T[]
	for i = 1:n
		if i > 1
			push!(I, i)
			push!(J, i - 1)
			push!(V, -p)
		end
		push!(I, i)
		push!(J, i)
		push!(V, 2p)
		if i < n
			push!(I, i)
			push!(J, i + 1)
			push!(V, -p)
		end
	end
	b = ones(n)
	return (SparseArrays.sparse(I, J, V) \ b)[1]
end

@show g2(1.0)
@show g2'(1.0) # segfaults
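
For reference, here is the derivative that an AD package would need to reproduce for these examples, written out in the adjoint form. This is only a sketch in my own notation (the symbols e_1 and λ are choices made here, not taken from the notes), and it assumes b does not depend on p:

```latex
\begin{aligned}
A(p)\,x(p) &= b, \qquad g(p) = e_1^{T} x(p), \\
\frac{dA}{dp}\,x + A\,\frac{dx}{dp} &= 0
\;\;\Longrightarrow\;\;
\frac{dx}{dp} = -A^{-1}\,\frac{dA}{dp}\,x, \\
\frac{dg}{dp} &= -\lambda^{T}\,\frac{dA}{dp}\,x,
\qquad \text{where } A^{T}\lambda = e_1 .
\end{aligned}
```

In this particular example A(p) = pT, where T is the tridiagonal matrix with 2 on the diagonal and -1 on the off-diagonals, so g(p) = g(1)/p and g'(1) = -g(1). For n = 3 and b = ones(3), solving by hand gives x = [1.5, 2, 1.5] at p = 1, so the expected values are g(1.0) = 1.5 and g'(1.0) = -1.5.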

@stevengj , those are your notes :smile:

@MikeInnes any insights on if/when Zygote might be able to support this?

Zygote should be able to handle this on the mutate branch (though it’ll work better if you can rewrite this in a more functional style).

If something crashes, it’s just a bug, and usually one that’s easy to fix; best to open an issue with this test case, simplified as much as you can, and I can take a look.

Thanks for taking a look!

I tried these on the mutate branch. g1 now gives a MethodError: no method matching *(::Zygote.Stack{Any}, ::Int64), and g2 still segfaults. I also created a g3 (see below), which reformulates g2 in a more functional style (I think), but it still segfaults. I’ll open an issue on GitHub.

function g3(p)
	n = 3
	I = [2:n; 1:n; 1:n - 1]
	J = [1:n - 1; 1:n; 2:n]
	V = [fill(-p, n - 1); fill(2p, n); fill(-p, n - 1)]
	b = ones(n)
	return (SparseArrays.sparse(I, J, V) \ b)[1]
end

@show g3(1.0)
@show g3'(1.0) # segfaults on the mutate branch
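
In case it helps as a reference value while the AD route is not working, here is a rough sketch of a hand-written adjoint derivative for the same sparse system. It does not use Zygote at all, and the names assemble and g3_adjoint are hypothetical helpers made up for this post. It assumes b is independent of p and relies on A being linear in p, so that dA/dp can be assembled as A(1).

```julia
import LinearAlgebra
import SparseArrays

# Assemble the sparse tridiagonal matrix A(p) used in g3 (hypothetical helper).
function assemble(p, n)
	I = [2:n; 1:n; 1:n - 1]
	J = [1:n - 1; 1:n; 2:n]
	V = [fill(-p, n - 1); fill(2p, n); fill(-p, n - 1)]
	return SparseArrays.sparse(I, J, V, n, n)
end

# Hand-written adjoint derivative of g(p) = (A(p) \ b)[1] (hypothetical helper).
function g3_adjoint(p; n=3)
	A = assemble(p, n)
	b = ones(n)
	x = A \ b                  # forward solve: A x = b
	e1 = zeros(n)
	e1[1] = 1.0                # dg/dx, since g = x[1]
	λ = copy(A') \ e1          # adjoint solve: A' λ = e1
	dAdp = assemble(1.0, n)    # A is linear in p, so dA/dp = A(1)
	return -LinearAlgebra.dot(λ, dAdp * x)
end

@show g3_adjoint(1.0)
```

For p = 1.0 this should give -1.5, which matches what you get by noting that A(p) = p * A(1), so g3(p) = g3(1)/p with g3(1.0) = 1.5. The cost is one extra sparse solve with the transposed matrix, regardless of how many parameters enter A, which is the main appeal of the adjoint approach for larger problems.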