Automatic differentiation of a function using `LinearProblem`

I am trying to optimise a function f: \mathbb{R}^n \to \mathbb{R} which internally solves a linear problem, informally something like f(w) = g(x) where x solves A(w)x = b(x).

To do this I am using LinearSolve and automatic differentiation of f via ForwardDiff, trying to leverage the fact that A(w) is sparse.

Unfortunately, when I try to solve the LinearProblem with

A <: SparseMatrixCSC{ForwardDiff.Dual{Nothing, Float64}}
b <: Vector{ForwardDiff.Dual{Nothing, Float64}}

I get an `ERROR: BoundsError: attempt to access Tuple{} at index [1]`. Here is a MWE. It is admittedly quite silly, but I hope the problem comes across.

using SparseArrays
using LinearSolve
using ForwardDiff

A = ForwardDiff.Dual.(sprand(4, 4, 0.75))
b = ForwardDiff.Dual.(rand(4))

prob = LinearProblem(A, b)
sol = solve(prob, KLUFactorization()) # ERROR: BoundsError: attempt to access Tuple{} at index [1]

Am I misunderstanding something? Is there a “proper” way of setting up such an optimisation problem?

Looks like a bug to me. Please file an issue on the LinearSolve repo.

But also,
A = ForwardDiff.Dual.(sprand(4, 4, 0.75))
creates a sparse matrix of Dual numbers that don’t carry any derivative (partial) components:

4×4 SparseMatrixCSC{ForwardDiff.Dual{Nothing, Float64, 0}, Int64} with 11 stored entries:
          ⋅                Dual{Nothing}(0.374878)  Dual{Nothing}(0.0699864)  Dual{Nothing}(0.297709)
          ⋅                Dual{Nothing}(0.205259)          ⋅                         ⋅
 Dual{Nothing}(0.164573)  Dual{Nothing}(0.54775)   Dual{Nothing}(0.952552)   Dual{Nothing}(0.434972)
 Dual{Nothing}(0.188322)  Dual{Nothing}(0.493789)  Dual{Nothing}(0.611966)           ⋅

These all have only primal parts (zero-length partials), which is apparently a case I did not consider when writing the overloads… but I’m assuming that’s not what you meant to do?
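To get Duals that actually carry derivative information, you normally let ForwardDiff seed the partials for you by calling `ForwardDiff.gradient` on a function of `w`. A minimal sketch of that pattern (the problem setup here is made up for illustration; the dense `Matrix(A) \ b` solve sidesteps the sparse-Dual code path that errors above):

```julia
using SparseArrays, LinearAlgebra, ForwardDiff

# Illustrative model problem: w[1] scales a fixed sparse matrix, w[2] shifts b.
A0 = sprand(4, 4, 0.75) + 4I   # well-conditioned sparse base matrix
b0 = rand(4)

function f(w)
    A = A0 .* w[1]        # A(w): these Duals have primal *and* partial parts
    b = b0 .+ w[2]        # b(w)
    x = Matrix(A) \ b     # generic dense LU works with a Dual eltype
    return sum(x)         # g(x)
end

g = ForwardDiff.gradient(f, [1.0, 0.5])
```

Inside `f`, the matrix entries are `Dual{T, Float64, 2}` values with nonzero partials, which is the case the overloads are written for.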

I assume you mean b(w), not b(x).

It’s not too hard to implement the derivative yourself, and it gives you a lot more control over the solvers in a sparse problem. AD through sparse solvers is often a bit fragile or non-existent (even Jax only has an “experimental” sparse module).

Using backpropagation (reverse mode, i.e. the adjoint method), the gradient \nabla_w f is given by:

\frac{\partial f}{\partial w_k} = v^T \left(\frac{\partial b}{\partial w_k} - \frac{\partial A}{\partial w_k }x \right)

where v solves the “adjoint” equation

A^T v = \nabla_x g

That is, you solve one extra sparse system with A^T (which is just as easy as Ax=b, and even easier if you already have a sparse factorization of A), and then compute some dot products to get the entries of \nabla_w f . In many practical applications, \frac{\partial A}{\partial w_k } is extremely sparse, often O(1) nonzeros, so the dot products with v^T are cheap. See also the course notes for our matrix-calculus course at MIT or my older notes on adjoint methods.
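The recipe above can be sketched in plain Julia with no AD packages. This assumes a made-up model problem where each w_k scales a single entry of A (so \partial A/\partial w_k has one nonzero) and b is independent of w, with g(x) = \mathrm{sum}(x) so \nabla_x g = \mathbf{1}:

```julia
using SparseArrays, LinearAlgebra

n  = 5
A0 = sprand(n, n, 0.5) + 5I
b  = rand(n)
Dk = [sparse([k], [k], [1.0], n, n) for k in 1:n]   # ∂A/∂w_k, one nonzero each

A(w) = A0 + sum(w[k] * Dk[k] for k in 1:n)
f(w) = sum(A(w) \ b)                                # f(w) = g(x(w)), g = sum

w = 0.1 * rand(n)
F = lu(A(w))          # one sparse factorization of A
x = F \ b             # forward solve:  A x = b
v = F' \ ones(n)      # adjoint solve:  Aᵀ v = ∇ₓg  (reuses the factorization)

# ∂f/∂w_k = vᵀ(∂b/∂w_k − (∂A/∂w_k) x);  here ∂b/∂w_k = 0
grad = [-dot(v, Dk[k] * x) for k in 1:n]
```

Note that both solves reuse the same `lu` factorization, and each gradient entry is just a cheap dot product because `Dk[k] * x` has a single nonzero. You can check `grad` against finite differences of `f`.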

If w has many components (n \gg 1), then forward mode (à la ForwardDiff) is a bad idea: it corresponds to solving n linear systems rather than just two.

Even if you end up using AD, it’s good to know what it’s doing under the hood.
