Using Zygote.jl with IterativeSolvers.jl

I have code that solves an initial value problem for a system of ODEs parameterized by a control vector, then uses the solution to evaluate an objective function whose gradient I compute by my own problem-specific methods. The timestepping in this code uses gmres! from the IterativeSolvers.jl package.

I was interested in using ForwardDiff.jl or Zygote.jl to compute the gradient by automatic differentiation, but it seems this may not be possible. As I understand it, automatic differentiation requires that all functions be “pure”, that is, that they do not mutate their arguments. I am worried that because gmres allocates a solution vector and updates it in place, it calls mutating functions that are incompatible with automatic differentiation.

Indeed, I tried the following minimal working example trying to use automatic differentiation on a function which uses gmres:

julia> using Zygote, IterativeSolvers, LinearAlgebra

julia> function f(x)
         B = [1.0 2.0;3.0 4.0]
         c = IterativeSolvers.gmres(B, x)
         return LinearAlgebra.norm(c)
       end
f (generic function with 1 method)

julia> Zygote.gradient(f, [1.0,2.0])
ERROR: Can't differentiate gc_preserve_end expression $(Expr(:gc_preserve_end, %83)).
You might want to check the Zygote limitations documentation.
https://fluxml.ai/Zygote.jl/latest/limitations

Zygote.gradient throws an error, but it is not obvious to me from the message that the error has to do with array mutation, as I would have expected (the message is not like the one given on the Zygote Limitations page). I don’t know where the gc_preserve_end expression comes from. Perhaps gc stands for garbage collection?

Has anyone had any success using Zygote.jl or any other automatic differentiation package with IterativeSolvers.jl (specifically with gmres)? If so, how did you resolve this issue?
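One generic way around this kind of error is to hide the solver from the AD tool behind a custom reverse rule. The following is only a minimal sketch, not the method of any package mentioned in this thread: `gmres_solve` is a hypothetical wrapper name, the matrix is assumed real and dense, and the rule relies on the standard adjoint of a linear solve (solve A'λ = x̄, then b̄ = λ and Ā = -λx'). The gradient is only as accurate as the GMRES tolerance.

```julia
using ChainRulesCore, IterativeSolvers, LinearAlgebra, Zygote

# Hypothetical wrapper (not part of IterativeSolvers.jl): solve A*x = b with GMRES.
gmres_solve(A, b) = gmres(A, b)

# Custom reverse rule, so Zygote never traces through the mutating solver internals.
function ChainRulesCore.rrule(::typeof(gmres_solve), A, b)
    x = gmres_solve(A, b)
    function gmres_solve_pullback(x̄)
        # Adjoint system: A' * λ = x̄, solved with GMRES as well
        λ = gmres(A', collect(unthunk(x̄)))
        # For x = A \ b: cotangent of b is λ, cotangent of A is -λ * x'
        # (forming -λ * x' is dense; this assumes A is a small dense matrix)
        return NoTangent(), -λ * x', λ
    end
    return x, gmres_solve_pullback
end

function f(x)
    B = [1.0 2.0; 3.0 4.0]
    c = gmres_solve(B, x)
    return norm(c)
end

g = Zygote.gradient(f, [1.0, 2.0])[1]
```

The result can be checked against the direct formula B' \ (c / norm(c)) with c = B \ x. For a matrix-free operator the dense cotangent for A would have to be replaced by something problem-specific.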

Hi @leespen1!
We worked on DiffKrylov.jl recently to combine Krylov.jl with the AD packages ForwardDiff.jl and Enzyme.jl.
It should be quite easy to perform sensitivity analysis with it.


That’s fantastic! I notice the GitHub README says it does not support “linear operators.” Does that mean I can’t use a (nonmutating) linear map from LinearMaps.jl in place of a matrix? That is also essential for me (although I didn’t include it in the minimal working example).

Try using ImplicitDifferentiation.jl (gdalle/ImplicitDifferentiation.jl on GitHub: automatic differentiation of implicit functions). Under the hood, ID uses gmres and is compatible with most common autodiff libraries.

@amontoison [and Michel, but I can’t find his Discourse handle], is there a reason this is in a separate package rather than a package extension of Krylov.jl?

I ask because it would be nice for those who autodiff through Krylov.jl to be able to use AD without having to know that they should also load a separate package with the rules.

If you use LinearSolve.jl, then you get pretty much all linear solvers in Julia (see the Linear System Solvers page of the LinearSolve.jl documentation) set up with most AD backends. Right now that’s Enzyme and ForwardDiff, but I can probably finish ChainRules (and thus Zygote) by the end of today.


DiffKrylov.jl has not yet been released. Current open issues are listed in the README. We were just thinking about the general issues you can encounter when differentiating Krylov solvers. Even with this API, you can still be using bicgstab in the original code and then have to switch to gmres for the tangents and adjoints. Also, the interaction with preconditioners (tangent/adjoint preconditioning) is interesting.
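For reference, these are the linear systems that appear when differentiating a solve, written for a real system A(p) x = b(p). This is the standard derivation and not a description of DiffKrylov.jl internals; the symbols λ and M are introduced here only for illustration.

```latex
% Tangent (forward mode): differentiate A x = b with respect to p
\dot{A}\,x + A\,\dot{x} = \dot{b}
\quad\Longrightarrow\quad
A\,\dot{x} = \dot{b} - \dot{A}\,x

% Adjoint (reverse mode): given the cotangent \bar{x} of the solution
A^{\top}\lambda = \bar{x},
\qquad
\bar{b} = \lambda,
\qquad
\bar{A} = -\lambda\,x^{\top}

% Preconditioning: if M \approx A^{-1} is used for the primal solve,
% the adjoint solve with A^{\top} calls for M^{\top} \approx A^{-\top}
```

The tangent system reuses A with a new right-hand side, while the adjoint system involves the transpose of A, so a solver or preconditioner that works well for the primal solve is not guaranteed to suit the adjoint solve.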

We will add operator support sometime soon. And it’s a separate package because I couldn’t figure out how to add dependencies via extensions (I guess it’s doable). Also, you can’t add new functions to a module via an extension. Krylov.jl is very clean and simple. We didn’t want to add any disturbance just yet.

Our focus for now was on portability to all GPU architectures. So let me also introduce Krylov.jl’s sidekick KrylovPreconditioners.jl. It has a block-Jacobi preconditioner that works on Intel, AMD, and NVIDIA via KernelAbstractions.jl.


That solves the problem of the MWE I provided, thanks! Unfortunately, it doesn’t seem to work when the left-hand side is a linear map implemented via LinearMaps.jl instead of a plain matrix (which I didn’t have in the MWE, but which is an additional requirement of mine).

julia> using LinearSolve, LinearAlgebra, LinearMaps, ForwardDiff

julia> linmap = LinearMap(x -> 2*x, 2, 2)
2×2 FunctionMap{Float64,false}(#5; issymmetric=false, ishermitian=false, isposdef=false)

julia> function f(x)
         prob = LinearProblem(linmap, x)
         linsolve = init(prob)
         sol = solve(linsolve)
         return norm(sol)
       end
f (generic function with 1 method)

julia> ForwardDiff.gradient(f, [1.0, 2.0])
ERROR: MethodError: no method matching solve!(::Krylov.GmresSolver{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float64}, Float64, 2}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float64}, Float64, 2}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float64}, Float64, 2}}}, ::FunctionMap{Float64, typeof(lmap_wrapper), Nothing, false}, ::Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float64}, Float64, 2}}; atol::Float64, rtol::Float64, itmax::Int64, verbose::Int64, ldiv::Bool, history::Bool)

We deprecated LinearMaps in favor of SciMLOperators quite a while ago, since it is missing operations that would be required in order to support many things. In general I’d highly recommend not using LinearMaps.

But that’s somewhat unrelated. What you’re hitting there is almost ready to merge; follow the pull request “Add ForwardDiff rules” by sharanry (Pull Request #434, SciML/LinearSolve.jl).

Hi, has this been updated with the Zygote backend? I couldn’t find an example in the documentation. When I try to differentiate a function containing a linear solve from the package, Zygote (or maybe LinearSolve) complains “type LinearSolution has no field prob”.

I just realized that we forgot to merge “Adjoints for Linear Solve” by avik-pal (Pull Request #449, SciML/LinearSolve.jl). We’ll get it rebased and merged ASAP, and that should solve this.

“Adjoints for Linear Solve” by avik-pal (Pull Request #449, SciML/LinearSolve.jl) has been merged, so as of LinearSolve v2.25.0 the linear solve interface is supported by the Zygote (and Enzyme) AD frontends.
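A minimal sketch of what this might look like for the original MWE, assuming LinearSolve v2.25.0 or later with Zygote loaded. The explicit choice of KrylovJL_GMRES() and the small dense test matrix are assumptions made for illustration and are not taken from the pull request.

```julia
using LinearSolve, LinearAlgebra, Zygote

A = [1.0 2.0; 3.0 4.0]

function f(b)
    prob = LinearProblem(A, b)
    # Pick the algorithm explicitly; GMRES here comes from Krylov.jl via LinearSolve
    sol = solve(prob, KrylovJL_GMRES())
    return norm(sol.u)
end

# Gradient of norm(A \ b) with respect to the right-hand side b
grad = Zygote.gradient(f, [1.0, 2.0])[1]
```

The result can be compared with the direct formula A' \ (u / norm(u)), where u = A \ b, up to the tolerance of the iterative solve.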
