I am trying to optimise a function f: \mathbb{R}^n \to \mathbb{R} which internally solves a linear problem, informally something like f(w) = g(x) where x solves A(w)x = b(x).
To do this I am using LinearSolve and automatic differentiation of f via ForwardDiff, and I am trying to leverage the fact that A(w) is sparse.
Unfortunately, when I try to solve the LinearProblem with
A <: SparseMatrixCSC{ForwardDiff.Dual{Nothing, Float64}}
b <: Vector{ForwardDiff.Dual{Nothing, Float64}}
I get an ERROR: BoundsError: attempt to access Tuple{} at index [1]. Here is an MWE. It is admittedly quite silly, but I hope the problem comes across.
using SparseArrays
using LinearSolve
using ForwardDiff
A = ForwardDiff.Dual.(sprand(4, 4, 0.75))
b = ForwardDiff.Dual.(rand(4))
prob = LinearProblem(A, b)
sol = solve(prob, KLUFactorization()) # `ERROR: BoundsError: attempt to access Tuple{} at index [1]`
Am I misunderstanding something? Is there a “proper” way of setting up such an optimisation problem?
Looks like a bug to me. Please file an issue on the LinearSolve repo.
But also,
A = ForwardDiff.Dual.(sprand(4, 4, 0.75))
creates a sparse matrix with Dual numbers that don’t have any Dual parts,
4×4 SparseMatrixCSC{ForwardDiff.Dual{Nothing, Float64, 0}, Int64} with 11 stored entries:
  ⋅                       Dual{Nothing}(0.374878)  Dual{Nothing}(0.0699864)  Dual{Nothing}(0.297709)
  ⋅                       Dual{Nothing}(0.205259)   ⋅                         ⋅
 Dual{Nothing}(0.164573)  Dual{Nothing}(0.54775)   Dual{Nothing}(0.952552)   Dual{Nothing}(0.434972)
 Dual{Nothing}(0.188322)  Dual{Nothing}(0.493789)  Dual{Nothing}(0.611966)    ⋅
these all have only primal parts (zero-length partials), which is a case I apparently did not consider when writing the overloads… but I’m assuming that’s not what you meant to do?
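To illustrate the distinction: broadcasting the `Dual` constructor over a plain array, as in the MWE, produces duals with zero partials, whereas a dual that actually carries derivative information needs a seeded partial. A minimal sketch (the specific values here are just illustrative):

```julia
using ForwardDiff

# Calling `Dual` on a bare value gives a dual with *no* partials —
# this is what `ForwardDiff.Dual.(sprand(4, 4, 0.75))` does elementwise:
d0 = ForwardDiff.Dual(0.5)
length(ForwardDiff.partials(d0))  # 0 — no derivative information at all

# To carry a derivative, seed a partial explicitly:
d1 = ForwardDiff.Dual(0.5, 1.0)   # primal 0.5, one partial seeded to 1.0
length(ForwardDiff.partials(d1))  # 1
```

In normal use you never construct these by hand; `ForwardDiff.gradient` (or the AD machinery calling `f`) seeds the partials for you.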
I assume you mean b(w), not b(x).
It’s not too hard to implement the derivative yourself, and it gives you a lot more control over the solvers in a sparse problem. AD through sparse solvers is often a bit fragile or non-existent (even Jax only has an “experimental” sparse module).
Using backpropagation (reverse mode / the adjoint method), the gradient \nabla_w f is given by:
\frac{\partial f}{\partial w_k} = v^T \left(\frac{\partial b}{\partial w_k} - \frac{\partial A}{\partial w_k }x \right)
where v solves the “adjoint” equation
A^T v = \nabla_x g
That is, you solve one extra sparse system with A^T (which is just as easy as Ax=b, and even easier if you already have a sparse factorization of A), and then compute some dot products to get the entries of \nabla_w f . In many practical applications, \frac{\partial A}{\partial w_k } is extremely sparse, often O(1) nonzeros, so the dot products with v^T are cheap. See also the course notes for our matrix-calculus course at MIT or my older notes on adjoint methods.
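Concretely, the two solves and the dot products above can be sketched as follows. This is a minimal illustration, not a definitive implementation: the function name `adjoint_gradient` and the convention that `dA_dw[k]` and `db_dw[k]` hold \partial A/\partial w_k and \partial b/\partial w_k are my own assumptions about how you might organize the data.

```julia
using SparseArrays, LinearAlgebra

# Sketch of the adjoint-method gradient for f(w) = g(x), where A x = b.
# Inputs (all assumed precomputed at the current w):
#   A       — sparse system matrix A(w)
#   b       — right-hand side b(w)
#   dA_dw   — vector of sparse matrices, dA_dw[k] = ∂A/∂w_k
#   db_dw   — vector of vectors,         db_dw[k] = ∂b/∂w_k
#   grad_g  — function returning ∇ₓg
function adjoint_gradient(A, b, dA_dw, db_dw, grad_g)
    F = lu(A)            # one sparse factorization, reused for both solves
    x = F \ b            # primal solve:   A x = b
    v = F' \ grad_g(x)   # adjoint solve:  Aᵀ v = ∇ₓg, reusing the factorization
    # ∂f/∂w_k = vᵀ (∂b/∂w_k − (∂A/∂w_k) x), one cheap dot product per k
    return [dot(v, db_dw[k] - dA_dw[k] * x) for k in eachindex(dA_dw)]
end
```

Note that the adjoint solve reuses the factorization of A (`F' \ …` solves with Aᵀ), so the whole gradient costs one factorization plus two triangular solves, independent of the number of parameters n.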
If w has many components n \gg 1, then forward mode (à la ForwardDiff) will be a bad idea, because it will correspond to solving n linear systems and not just two.
Even if you end up using AD, it’s good to know what it’s doing under the hood.