What version of LinearSolve? Recent versions of LinearSolve actually use a `DualLinearCache` if `A` or `b` contain Dual numbers. Actually, that's probably not relevant, since you're passing the linsolve directly.
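For context, a minimal sketch of what I mean (the problem setup is illustrative, and assumes a recent LinearSolve with its ForwardDiff extension loaded):

```julia
using LinearSolve
import ForwardDiff
using ForwardDiff: Dual

# b carries Dual numbers, so recent LinearSolve versions should route this
# solve through a DualLinearCache rather than factorizing a Dual-valued matrix.
A = rand(3, 3)
b = Dual.(rand(3), 1.0)
sol = solve(LinearProblem(A, b))
```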
The `GenericLUFactorization` is the `Base.LinearAlgebra.generic_lu!`, which does not have pivoting. It's only for small cases. I'd really recommend just using RecursiveFactorization.jl (i.e. `RFLUFactorization`) if you need a generic algorithm.
Though as Jadon mentioned, the latest versions of LinearSolve have overloads for ForwardDiff, Zygote, and Enzyme reverse mode, so they won't differentiate through the solver. So if it's Dual numbers, you can even use MKL/CUDA just fine.
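For example, something like this should differentiate fine even with a BLAS-backed method (a sketch; the matrix and sizes are illustrative):

```julia
using LinearSolve, LinearAlgebra
import ForwardDiff

# ForwardDiff.jacobian of b -> A \ b. The LinearSolve overloads solve for the
# primal and the partials separately, so an MKL factorization of a Dual-valued
# matrix is never attempted.
A = rand(4, 4) + 4I
solve_b(b) = solve(LinearProblem(A, b), MKLLUFactorization()).u
J = ForwardDiff.jacobian(solve_b, rand(4))  # ≈ inv(A)
```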
But static arrays hit a static-arrays-only dispatch in LinearSolve (LinearSolve.jl/src/common.jl at main · SciML/LinearSolve.jl · GitHub), because StaticArrays has special overloads that are non-allocating and don't allow mutation. I don't think we tested whether the ForwardDiff overloads (LinearSolve.jl/ext/LinearSolveForwardDiffExt.jl at main · SciML/LinearSolve.jl · GitHub) will be hit in the context of static arrays, and my guess is that they might not. We should make sure to test and cover that case.
If they're passing a `LinearCache` to `solve!`, even if `A` or `b` are Dual, it won't go through the overloads. The overloads are only hit if a `DualLinearCache` is passed to `solve!`.
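To illustrate the dispatch (a sketch; loading ForwardDiff activates the extension, and the setup is illustrative):

```julia
using LinearSolve
import ForwardDiff
using ForwardDiff: Dual

A = rand(3, 3)
b_dual = Dual.(rand(3), 1.0)

# init on a Dual-valued problem should build a DualLinearCache, and solve! on
# it hits the extension; solve! on a plain LinearCache never will.
dual_cache  = init(LinearProblem(A, b_dual), LUFactorization())
plain_cache = init(LinearProblem(A, rand(3)), LUFactorization())
solve!(dual_cache)   # routed through LinearSolveForwardDiffExt
solve!(plain_cache)  # ordinary solve path
```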
StaticArrays should work: the primal problem that actually goes through the solver will be a `StaticLinearProblem`. But yes, it's untested, so I'll add some tests.
I'm not an expert on the caching API for DifferentiationInterface, but I think `Cache(::LinearCache)` might not be valid. See the warning in the docs here: API · DifferentiationInterface.jl: "Some backends require any `Cache` context to be an `AbstractArray`, others accept nested (named) tuples of `AbstractArray`s."
Note that I am not currently using StaticArrays (but instead `MVector` and co.), so from what I can see I should not hit those methods.
I have tried `RFLUFactorization`, but like other methods using LU, I have not found a way to pass the `LinearSolve.init` cache to `DifferentiationInterface.jacobian!`. In particular, even if I manually (though in the most naive way) define `recursive_similar` via type piracy for the involved structs (very much going against the warning from the docs that JClugstor mentioned), I still get errors of the form:

```
MethodError: no method matching lu!(::MMatrix{3, 3, ForwardDiff.Dual{…}, 9}, ::Vector{ForwardDiff.Dual{…}}, ::Val{true}, ::Val{true}; check::Bool)
```
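For reference, roughly the pattern I am attempting (names and sizes are illustrative):

```julia
using LinearSolve, RecursiveFactorization, StaticArrays, DifferentiationInterface
import ForwardDiff

function residual!(res, x, cache)
    cache.b = x              # update the preallocated linear cache
    res .= solve!(cache).u
    return nothing
end

A = MMatrix{3,3}(rand(3, 3))
cache = init(LinearProblem(A, MVector{3}(rand(3))), RFLUFactorization())
x = MVector{3}(rand(3))
res = similar(x)
J = MMatrix{3,3}(zeros(3, 3))
# With my pirated recursive_similar, this is what fails inside the Dual solve:
jacobian!(residual!, res, J, AutoForwardDiff(), x, Cache(cache))
```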
The `DualLinearCache` (and in general the machinery defined in `LinearSolveForwardDiffExt.jl`) you both mentioned seems interesting. Is this something I might want to wrangle manually (using something like `Base.get_extension(LinearSolve, :LinearSolveForwardDiffExt).DualLinearCache`)? I assume all of this would require me to abandon using DifferentiationInterface anyway.
The general issue is that `DI.Cache` needs to allow `Dual` number computations when used with ForwardDiff.jl. That is why the internal function `DI.recursive_similar` allocates new storage for the wrapped argument, giving each array the right `eltype`. To avoid having to define a walk through general structs, I only define it for arrays and tuples thereof (here). Doing it more generally would also require a distinction between differentiable and non-differentiable types (e.g. `Int` vs `Float64`), and we quickly get into murky design decisions that I don't feel DI should have to make.
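Conceptually, the walk looks roughly like this (a simplified sketch, not DI's actual implementation):

```julia
# Rebuild cache storage with a new eltype (e.g. a Dual type), walking only
# arrays and (named) tuples, never arbitrary structs.
recursive_similar(x::AbstractArray, ::Type{T}) where {T} = similar(x, T)
recursive_similar(x::Union{Tuple,NamedTuple}, ::Type{T}) where {T} =
    map(xi -> recursive_similar(xi, T), x)
```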
This explains why you can't create a `DI.Cache` from a more sophisticated structure like those in LinearSolve.jl or FastLapackInterface.jl. In those advanced cases, I guess PreallocationTools.jl is the right fit for now, and it should make `DI.Cache` work seamlessly because the underlying storage will be adaptive. Otherwise, you could also commit type piracy on `DI.recursive_similar(cache::LinearSolve.LinearCache, t)`, but it must return an actual object (and not a type, as your implementation above does). Or you could destructure these fancy structs until they are nested tuples, and recombine them inside your solver.
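One way the PreallocationTools.jl route could look — here I pass the `DiffCache` as a `Constant` context, since it manages its own Dual storage internally (my own sketch, not a tested recipe; the function is illustrative):

```julia
using PreallocationTools, DifferentiationInterface
import ForwardDiff

function f!(y, x, dc)
    tmp = get_tmp(dc, x)  # buffer whose eltype matches x (Dual-aware)
    tmp .= x .^ 2
    y .= tmp .+ 1
    return nothing
end

x = rand(3)
y = similar(x)
dc = DiffCache(similar(x))
J = jacobian(f!, y, AutoForwardDiff(), x, Constant(dc))
```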
As for the choice between LinearSolve.jl and FastLapackInterface.jl, Chris is right: in general you don’t want to differentiate through the linear solve, but use a custom rule which LinearSolve.jl has already written for you.
ImplicitDifferentiation.jl does something similar but it is not at all focused on avoiding allocations and it is much more generic, so it doesn’t seem like a good fit for your linear problem.
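For intuition about the custom-rule approach: the rule only ever factorizes the primal matrix (a sketch of the math, not LinearSolve's actual code):

```julia
using LinearAlgebra

# Differentiating A(θ) x(θ) = b(θ) gives A x' = b' - A' x, so the derivative
# of the solution reuses the factorization of the plain Float64 primal A:
A  = rand(3, 3) + 3I
dA = rand(3, 3)          # dA/dθ
b  = rand(3)
db = rand(3)             # db/dθ

F  = lu(A)               # one factorization of the primal matrix
x  = F \ b               # primal solution
dx = F \ (db - dA * x)   # dx/dθ, with no Dual-valued factorization needed
```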