I am trying to optimise a function f: \mathbb{R}^n \to \mathbb{R} which internally solves a linear problem; informally, something like f(w) = g(x), where x solves A(w) x = b(w).
To do this I am using LinearSolve and automatic differentiation of f via ForwardDiff, and I am trying to leverage the fact that A(w) is sparse.
Unfortunately, when I try to solve the LinearProblem with
A <: SparseMatrixCSC{ForwardDiff.Dual{Nothing, Float64}}
b <: Vector{ForwardDiff.Dual{Nothing, Float64}}
I get an ERROR: BoundsError: attempt to access Tuple{} at index [1]. Here is an MWE. It is admittedly quite silly, but I hope the problem comes across.
using SparseArrays
using LinearSolve
using ForwardDiff
A = ForwardDiff.Dual.(sprand(4, 4, 0.75))
b = ForwardDiff.Dual.(rand(4))
prob = LinearProblem(A, b)
sol = solve(prob, KLUFactorization()) # `ERROR: BoundsError: attempt to access Tuple{} at index [1]`
Am I misunderstanding something? Is there a “proper” way of setting up such an optimisation problem?
These all only have primal parts, which is a case that I apparently did not consider when writing the overloads… but I’m assuming that’s not what you meant to do?
It’s not too hard to implement the derivative yourself, and it gives you a lot more control over the solvers in a sparse problem. AD through sparse solvers is often a bit fragile or non-existent (even Jax only has an “experimental” sparse module).
The gradient \nabla_w f is given by backpropagation/reverse mode/the adjoint method: solve the adjoint equation A^T v = \left(\frac{\partial g}{\partial x}\right)^T for v, and then \frac{\partial f}{\partial w_k} = -v^T \frac{\partial A}{\partial w_k} x (plus a v^T \frac{\partial b}{\partial w_k} term if b also depends on w).
That is, you solve one extra sparse system with A^T (which is just as easy as Ax=b, and even easier if you already have a sparse factorization of A), and then compute some dot products to get the entries of \nabla_w f . In many practical applications, \frac{\partial A}{\partial w_k } is extremely sparse, often O(1) nonzeros, so the dot products with v^T are cheap. See also the course notes for our matrix-calculus course at MIT or my older notes on adjoint methods.
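To make that concrete, here is a minimal sketch of the adjoint recipe (the particular A(w), the E_k matrices, and g = sum are made up for illustration; your system will differ):

```julia
using SparseArrays, LinearAlgebra

# Toy problem: A(w) = A0 + Σ_k w_k * E_k, where each E_k = ∂A/∂w_k has a
# single nonzero, b is fixed, and f(w) = g(x) = sum(x) with A(w) x = b.
n = 4
A0 = sparse(5.0I, n, n)
Es = [sparse([k], [k], [1.0], n, n) for k in 1:n]   # ∂A/∂w_k, O(1) nonzeros each
b = ones(n)

function adjoint_gradient(w)
    A = A0 + sum(w[k] * Es[k] for k in 1:n)
    F = lu(A)          # one sparse factorization
    x = F \ b          # forward solve: A x = b
    v = F' \ ones(n)   # adjoint solve: Aᵀ v = ∇ₓg (here ∇ₓg = ones, since g = sum)
    # ∂f/∂w_k = -vᵀ (∂A/∂w_k) x   (no ∂b/∂w_k term, since b is fixed here)
    return [-dot(v, Es[k] * x) for k in 1:n]
end

w = [1.0, 2.0, 3.0, 4.0]
grad = adjoint_gradient(w)
```

For this diagonal toy problem the result can be checked by hand: x_k = 1/(5 + w_k), so \partial f/\partial w_k = -1/(5+w_k)^2, which is what the two solves plus the cheap dot products produce.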
If w has many components, n \gg 1, then forward mode (à la ForwardDiff) will be a bad idea, because it corresponds to solving n linear systems rather than just two.
Even if you end up using AD, it’s good to know what it’s doing under the hood.
In fact, having no partials is the culprit here.
You should define:
A = ForwardDiff.Dual{Nothing, Float64,1}.(sprand(4, 4, 0.75))
b = ForwardDiff.Dual{Nothing, Float64,1}.(rand(4))
and try again. The algorithm described by @stevengj has recently been implemented in LinearSolve (1) and is dispatched to automatically. It only seems to fail if there are zero dual parts. An alternative is to use SparspakFactorization(), which does all calculations in dual numbers and in this case also works without the dual parts.
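With well-formed duals the whole pipeline composes; for instance, you can let ForwardDiff build the tagged duals itself (a toy objective, assuming the LinearSolve overload handles KLUFactorization with dual inputs as described above):

```julia
using SparseArrays, LinearSolve, ForwardDiff

# Toy objective (made up for illustration): w fills the diagonal of a
# sparse A(w), b is fixed, and f(w) = sum(x) where A(w) x = b.
function f(w)
    A = spdiagm(0 => w)          # any sparsity pattern works here
    b = ones(length(w))
    prob = LinearProblem(A, b)
    sol = solve(prob, KLUFactorization())
    return sum(sol.u)
end

# ForwardDiff creates properly tagged duals with nonzero partials internally:
grad = ForwardDiff.gradient(f, [2.0, 3.0, 4.0])
```

Here x_k = 1/w_k, so the gradient should come out as -1/w_k^2 for each component.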
Are you sure? If you add dual parts, that would normally correspond to forward-mode differentiation, whereas what I described is reverse mode.
Sounds slow compared to the reverse-mode algorithm in this case, which can use the ordinary sparse factorization code on ordinary floating-point numbers (and you only need to factorize once, because you can re-use the factorization for the adjoint problem).
The OP said they are differentiating a scalar-valued function of n inputs, and presumably n > 1, maybe \gg 1. Dual numbers (i.e. forward mode) are usually a poor fit for this: sparse factorization with dual numbers will probably be at least n times more expensive than factorization with Float64 (and require \approx n\times more memory).
Just to be clear here, both forward mode and reverse mode ADs have overloads with LinearSolve.jl so that they do not differentiate the solver and instead directly use the linear solver “the right way”.
The issue here is that the Dual numbers have 0 partials and a Nothing tag. That’s an ill-formed dual number and not something normal autodiff would encounter, so it doesn’t really make sense to test?
If you have a large number of parameters, using forward-mode AD (like ForwardDiff) is going to drag down performance, because you’ll end up having to solve one linear system per parameter.
A much better approach is the adjoint method, which is exactly what reverse-mode AD is doing behind the scenes anyway. All you really need to do is solve one extra “adjoint” equation, A^T v = \left(\frac{\partial g}{\partial x}\right)^T.
Once you solve for it, getting the rest of the gradient just comes down to a few cheap dot products. It’s incredibly efficient because you’re only solving one extra sparse system—which is practically free if you already have a sparse factorization. Plus, in most real-world applications, those parameter sensitivities are so sparse (often just O(1) nonzeros) that the dot products cost almost nothing.
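The factorization-reuse point is one line each way in Julia (a sketch; A, b, and ∇g below are random placeholders for your actual system and for ∂g/∂x at the solution):

```julia
using SparseArrays, LinearAlgebra

A = sprand(100, 100, 0.05) + 10.0I   # example well-conditioned sparse matrix
b = rand(100)
∇g = rand(100)                       # stand-in for ∂g/∂x at the solution

F = lu(A)      # the expensive step: one sparse LU factorization
x = F \ b      # forward solve:  A x = b
v = F' \ ∇g    # adjoint solve:  Aᵀ v = ∇g, reusing the same factorization
```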