[ANN] Announcing ImplicitDifferentiation.jl

@mohamed82008 and I are happy to announce that our new package ImplicitDifferentiation.jl is ready, just in time for JuliaCon 2022! 🥳

What is it for?

The goal is to differentiate through complex procedures like:

  • Nonlinear equation systems
  • Optimization, with or without constraints
  • Fixed point algorithms
  • Differential equations

Tackling such problems often requires a dedicated solver, which may not play nice with autodiff.
Indeed, some solvers are black boxes that will make Zygote.jl and friends crash, while others involve iterative routines that are very costly to backpropagate through.

How does it work?

If we specify a set of conditions satisfied by a solution, we can differentiate through the solver regardless of its actual implementation, and without unrolling any loop.
For instance, the solution to an unconstrained optimization problem satisfies gradient stationarity. The solution to a fixed-point equation satisfies… well, the fixed point equation!
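
As a minimal illustration, here are two hypothetical solver/condition pairs in plain Julia, one for each case above. The toy problems and all names (`argmin_solver`, `fixed_point_solver`, and so on) are made up for this sketch. The only point is that each condition function vanishes at the solver's output, whatever the solver does internally.

```julia
# Toy example 1 (hypothetical): unconstrained optimization
#   y(x) = argmin_y f(x, y),  f(x, y) = ||y - x||^2 / 2 + ||y||^2 / 2,
# solved by plain gradient descent.
grad_f(x, y) = (y - x) + y                        # gradient of f with respect to y
function argmin_solver(x; steps=100, lr=0.25)
    y = zero(x)
    for _ in 1:steps
        y = y - lr * grad_f(x, y)
    end
    return y
end
optimization_conditions(x, y) = grad_f(x, y)      # gradient stationarity

# Toy example 2 (hypothetical): fixed point y = heron_step(x, y),
# i.e. Heron's iteration for the square root of x > 0.
heron_step(x, y) = (y + x / y) / 2
function fixed_point_solver(x; steps=50)
    y = one(x)
    for _ in 1:steps
        y = heron_step(x, y)
    end
    return y
end
fixed_point_conditions(x, y) = y - heron_step(x, y)   # the fixed point equation itself

optimization_conditions([1.0, -2.0], argmin_solver([1.0, -2.0]))   # ≈ [0.0, 0.0]
fixed_point_conditions(2.0, fixed_point_solver(2.0))             # ≈ 0.0
```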

Our package provides a simple wrapper ImplicitFunction(solver, conditions) that is compatible with the ChainRules.jl ecosystem. Its forward and reverse chain rules are computed thanks to the implicit function theorem (see this paper for theoretical details).
While we drew inspiration from @mohamed82008’s initial code in NonconvexUtils.jl, we strove to make ImplicitDifferentiation.jl as simple and lightweight as possible, with very few dependencies and lines of code.
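
To make the implicit function theorem concrete: if the solver output y(x) satisfies c(x, y(x)) = 0, differentiating both sides gives ∂c/∂x + ∂c/∂y · J = 0, so J = -(∂c/∂y)⁻¹ ∂c/∂x, which involves only the conditions and never the solver's loop. Below is a minimal, dependency-free sketch of that formula, not the package's implementation. The helpers `fd_jacobian` and `implicit_jacobian` are hypothetical, and central finite differences stand in for the autodiff and iterative linear solver that the package relies on.

```julia
using LinearAlgebra

# Hypothetical helper: central finite-difference Jacobian of a vector function f at v.
# (Finite differences only keep this sketch dependency-free; the package uses autodiff.)
function fd_jacobian(f, v::AbstractVector; h=1e-6)
    m = length(f(v))
    J = zeros(m, length(v))
    for k in eachindex(v)
        e = zeros(length(v))
        e[k] = h
        J[:, k] = (f(v + e) - f(v - e)) / (2h)
    end
    return J
end

# Hypothetical helper: Jacobian of y(x) obtained from the conditions c(x, y) = 0 alone,
# using the implicit function theorem J = -(∂c/∂y) \ (∂c/∂x).
function implicit_jacobian(c, x, y)
    dc_dx = fd_jacobian(u -> c(u, y), x)
    dc_dy = fd_jacobian(v -> c(x, v), y)
    return -(dc_dy \ dc_dx)
end

# Toy solver: the fixed point y = T(x, y), reached by plain iteration.
T(x, y) = 0.5 .* cos.(y) .+ x
function solver(x; steps=200)
    y = zero(x)
    for _ in 1:steps
        y = T(x, y)
    end
    return y
end
conditions(x, y) = y - T(x, y)

x = [0.3, -1.2]
y = solver(x)
J_implicit = implicit_jacobian(conditions, x, y)  # never looks inside `solver`
J_unrolled = fd_jacobian(solver, x)               # brute-force reference through all 200 iterations
```

Both Jacobians should agree up to finite-difference error. For this toy map they also match the closed form diag(1 / (1 + sin(yᵢ)/2)), since ∂c/∂y = I + diag(sin(y))/2 and ∂c/∂x = -I.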

Do you need it?

It depends! Some solvers, mainly from the SciML ecosystem, already come bundled with their own chain rules, but many others don’t.
We present several examples of applications in the package documentation, and we plan to add more in the future. If you have another use case in mind, feel free to reach out and we’ll help you make it work!

This is great, thanks! I’m looking forward to trying it out in the tensor network automatic differentiation code we are developing, I think there will be many applications.

I’m trying to test it out by differentiating an eigenvector equation. Here is my code:

using FiniteDifferences
using ImplicitDifferentiation
using LinearAlgebra
using Optim
using Random
using Zygote

Random.seed!(1234)

# Minimal eigenvector of `A`.
# Minimize the Rayleigh quotient:
# <x, A, x> / <x, x>
function fixed_point(A::AbstractMatrix)
  f(x) = x'A * x / x'x
  function g!(G, x)
    G .= gradient(f, x)[1]
  end
  y0 = randn(eltype(A), size(A, 1))
  res = optimize(f, g!, y0, LBFGS())
  y = Optim.minimizer(res)
  return y / norm(y)
end

# Optimality condition:
# <Ax, Ax> - |<x, A, x>|^2
function variance(A::AbstractMatrix, x::AbstractVector)
  Ax = A * x
  return Ax'Ax - abs2(x'Ax)
end

n = 2
A = randn(n, n) |> A -> A'A
y = fixed_point(A)
@show variance(A, y)

implicit = ImplicitFunction(fixed_point, variance)

J_ref = FiniteDifferences.jacobian(central_fdm(5, 1), fixed_point, A)[1]
@show J_ref
J = Zygote.jacobian(implicit, A)[1]

This outputs:

variance(A, y) = 5.421010862427522e-20
J_ref = [228164.23220199018 -0.06483760689429115 -28513.091586444993 -179568.6835259812; -316053.0191408435 -0.046807388514359186 39496.52462581584 248738.58850201144]
ERROR: LoadError: MethodError: no method matching (::ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}})(::Vector{Float64})
Closest candidates are:
  (::ChainRulesCore.ProjectTo)(::ChainRulesCore.InplaceableThunk) at ~/.julia/packages/ChainRulesCore/16PWJ/src/projection.jl:125
  (::ChainRulesCore.ProjectTo{<:Real})(::Complex) at ~/.julia/packages/ChainRulesCore/16PWJ/src/projection.jl:187
  (::ChainRulesCore.ProjectTo{T})(::ChainRulesCore.AbstractZero) where T at ~/.julia/packages/ChainRulesCore/16PWJ/src/projection.jl:120
  ...
Stacktrace:
  [1] -_pullback
    @ ~/.julia/packages/ChainRules/nu2G0/src/rulesets/Base/fastmath_able.jl:214 [inlined]
  [2] ZBack
    @ ~/.julia/packages/Zygote/DkIUK/src/compiler/chainrules.jl:205 [inlined]
  [3] Pullback
    @ ~/.julia/packages/ImplicitDifferentiation/90Hcu/src/implicit_function.jl:94 [inlined]
  [4] (::typeof(∂(λ)))(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
  [5] (::Zygote.var"#ad_pullback#42"{Tuple{ImplicitDifferentiation.var"#conditions_y#8"{Matrix{Float64}, typeof(variance)}, Vector{Float64}}, typeof(∂(λ))})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/chainrules.jl:257
  [6] (::ComposedFunction{typeof(last), Zygote.var"#ad_pullback#42"{Tuple{ImplicitDifferentiation.var"#conditions_y#8"{Matrix{Float64}, typeof(variance)}, Vector{Float64}}, typeof(∂(λ))}})(x::Vector{Float64}; kw::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Base ./operators.jl:1085
  [7] (::ComposedFunction{typeof(last), Zygote.var"#ad_pullback#42"{Tuple{ImplicitDifferentiation.var"#conditions_y#8"{Matrix{Float64}, typeof(variance)}, Vector{Float64}}, typeof(∂(λ))}})(x::Vector{Float64})
    @ Base ./operators.jl:1085
  [8] (::ImplicitDifferentiation.var"#mul_Aᵀ!#9"{ComposedFunction{typeof(last), Zygote.var"#ad_pullback#42"{Tuple{ImplicitDifferentiation.var"#conditions_y#8"{Matrix{Float64}, typeof(variance)}, Vector{Float64}}, typeof(∂(λ))}}, Vector{Float64}})(res::Vector{Float64}, u::Vector{Float64})
    @ ImplicitDifferentiation ~/.julia/packages/ImplicitDifferentiation/90Hcu/src/implicit_function.jl:99
  [9] prod3!(res::Vector{Float64}, prod!::ImplicitDifferentiation.var"#mul_Aᵀ!#9"{ComposedFunction{typeof(last), Zygote.var"#ad_pullback#42"{Tuple{ImplicitDifferentiation.var"#conditions_y#8"{Matrix{Float64}, typeof(variance)}, Vector{Float64}}, typeof(∂(λ))}}, Vector{Float64}}, v::Vector{Float64}, α::Float64, β::Float64, Mv5::Vector{Float64})
    @ LinearOperators ~/.julia/packages/LinearOperators/58FwN/src/operations.jl:12
 [10] mul!(res::Vector{Float64}, op::LinearOperators.LinearOperator{Float64, Int64, ImplicitDifferentiation.var"#mul_Aᵀ!#9"{ComposedFunction{typeof(last), Zygote.var"#ad_pullback#42"{Tuple{ImplicitDifferentiation.var"#conditions_y#8"{Matrix{Float64}, typeof(variance)}, Vector{Float64}}, typeof(∂(λ))}}, Vector{Float64}}, Nothing, Nothing, Vector{Float64}}, v::Vector{Float64}, α::Float64, β::Float64)
    @ LinearOperators ~/.julia/packages/LinearOperators/58FwN/src/operations.jl:31
 [11] mul!(res::Vector{Float64}, op::LinearOperators.LinearOperator{Float64, Int64, ImplicitDifferentiation.var"#mul_Aᵀ!#9"{ComposedFunction{typeof(last), Zygote.var"#ad_pullback#42"{Tuple{ImplicitDifferentiation.var"#conditions_y#8"{Matrix{Float64}, typeof(variance)}, Vector{Float64}}, typeof(∂(λ))}}, Vector{Float64}}, Nothing, Nothing, Vector{Float64}}, v::Vector{Float64})
    @ LinearOperators ~/.julia/packages/LinearOperators/58FwN/src/operations.jl:36
 [12] gmres!(solver::Krylov.GmresSolver{Float64, Float64, Vector{Float64}}, A::LinearOperators.LinearOperator{Float64, Int64, ImplicitDifferentiation.var"#mul_Aᵀ!#9"{ComposedFunction{typeof(last), Zygote.var"#ad_pullback#42"{Tuple{ImplicitDifferentiation.var"#conditions_y#8"{Matrix{Float64}, typeof(variance)}, Vector{Float64}}, typeof(∂(λ))}}, Vector{Float64}}, Nothing, Nothing, Vector{Float64}}, b::Vector{Float64}; M::UniformScaling{Bool}, N::UniformScaling{Bool}, atol::Float64, rtol::Float64, reorthogonalization::Bool, itmax::Int64, restart::Bool, verbose::Int64, history::Bool, callback::Krylov.var"#164#166")
    @ Krylov ~/.julia/packages/Krylov/PApE9/src/gmres.jl:213
 [13] gmres!
    @ ~/.julia/packages/Krylov/PApE9/src/gmres.jl:90 [inlined]
 [14] #gmres#161
    @ ~/.julia/packages/Krylov/PApE9/src/gmres.jl:61 [inlined]
 [15] gmres
    @ ~/.julia/packages/Krylov/PApE9/src/gmres.jl:60 [inlined]
 [16] (::ImplicitDifferentiation.var"#implicit_pullback#11"{Float64, Matrix{Float64}, LinearOperators.LinearOperator{Float64, Int64, ImplicitDifferentiation.var"#mul_Bᵀ!#10"{ComposedFunction{typeof(last), Zygote.var"#ad_pullback#42"{Tuple{ImplicitDifferentiation.var"#conditions_x#7"{Vector{Float64}, typeof(variance)}, Matrix{Float64}}, typeof(∂(λ))}}, Vector{Float64}}, Nothing, Nothing, Vector{Float64}}, LinearOperators.LinearOperator{Float64, Int64, ImplicitDifferentiation.var"#mul_Aᵀ!#9"{ComposedFunction{typeof(last), Zygote.var"#ad_pullback#42"{Tuple{ImplicitDifferentiation.var"#conditions_y#8"{Matrix{Float64}, typeof(variance)}, Vector{Float64}}, typeof(∂(λ))}}, Vector{Float64}}, Nothing, Nothing, Vector{Float64}}, typeof(Krylov.gmres)})(dy::Vector{Float64})
    @ ImplicitDifferentiation ~/.julia/packages/ImplicitDifferentiation/90Hcu/src/implicit_function.jl:108
 [17] ZBack
    @ ~/.julia/packages/Zygote/DkIUK/src/compiler/chainrules.jl:205 [inlined]
 [18] (::Zygote.var"#208#209"{Tuple{Tuple{Nothing}}, Zygote.ZBack{ImplicitDifferentiation.var"#implicit_pullback#11"{Float64, Matrix{Float64}, LinearOperators.LinearOperator{Float64, Int64, ImplicitDifferentiation.var"#mul_Bᵀ!#10"{ComposedFunction{typeof(last), Zygote.var"#ad_pullback#42"{Tuple{ImplicitDifferentiation.var"#conditions_x#7"{Vector{Float64}, typeof(variance)}, Matrix{Float64}}, typeof(∂(λ))}}, Vector{Float64}}, Nothing, Nothing, Vector{Float64}}, LinearOperators.LinearOperator{Float64, Int64, ImplicitDifferentiation.var"#mul_Aᵀ!#9"{ComposedFunction{typeof(last), Zygote.var"#ad_pullback#42"{Tuple{ImplicitDifferentiation.var"#conditions_y#8"{Matrix{Float64}, typeof(variance)}, Vector{Float64}}, typeof(∂(λ))}}, Vector{Float64}}, Nothing, Nothing, Vector{Float64}}, typeof(Krylov.gmres)}}})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/lib.jl:207
 [19] (::Zygote.var"#1750#back#210"{Zygote.var"#208#209"{Tuple{Tuple{Nothing}}, Zygote.ZBack{ImplicitDifferentiation.var"#implicit_pullback#11"{Float64, Matrix{Float64}, LinearOperators.LinearOperator{Float64, Int64, ImplicitDifferentiation.var"#mul_Bᵀ!#10"{ComposedFunction{typeof(last), Zygote.var"#ad_pullback#42"{Tuple{ImplicitDifferentiation.var"#conditions_x#7"{Vector{Float64}, typeof(variance)}, Matrix{Float64}}, typeof(∂(λ))}}, Vector{Float64}}, Nothing, Nothing, Vector{Float64}}, LinearOperators.LinearOperator{Float64, Int64, ImplicitDifferentiation.var"#mul_Aᵀ!#9"{ComposedFunction{typeof(last), Zygote.var"#ad_pullback#42"{Tuple{ImplicitDifferentiation.var"#conditions_y#8"{Matrix{Float64}, typeof(variance)}, Vector{Float64}}, typeof(∂(λ))}}, Vector{Float64}}, Nothing, Nothing, Vector{Float64}}, typeof(Krylov.gmres)}}}})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [20] Pullback
    @ ./operators.jl:1085 [inlined]
 [21] (::typeof(∂(#_#83)))(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [22] (::Zygote.var"#208#209"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(#_#83))})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/lib.jl:207
 [23] #1750#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [24] Pullback
    @ ./operators.jl:1085 [inlined]
 [25] (::typeof(∂(ComposedFunction{typeof(Zygote._jvec), ImplicitFunction{typeof(fixed_point), typeof(variance), typeof(Krylov.gmres)}}(Zygote._jvec, ImplicitFunction{typeof(fixed_point), typeof(variance), typeof(Krylov.gmres)}(fixed_point, variance, Krylov.gmres)))))(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
 [26] (::Zygote.var"#52#53"{typeof(∂(ComposedFunction{typeof(Zygote._jvec), ImplicitFunction{typeof(fixed_point), typeof(variance), typeof(Krylov.gmres)}}(Zygote._jvec, ImplicitFunction{typeof(fixed_point), typeof(variance), typeof(Krylov.gmres)}(fixed_point, variance, Krylov.gmres))))})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:41
 [27] withjacobian(f::ImplicitFunction{typeof(fixed_point), typeof(variance), typeof(Krylov.gmres)}, args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/grad.jl:162
 [28] jacobian(f::ImplicitFunction{typeof(fixed_point), typeof(variance), typeof(Krylov.gmres)}, args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/grad.jl:140
 [29] top-level scope
    @ ~/Dropbox (Simons Foundation)/workdir/ImplictDifferentiation.jl/fixed_point_optim.jl:40
 [30] include(fname::String)
    @ Base.MainInclude ./client.jl:451
 [31] top-level scope
    @ REPL[4]:1
in expression starting at /home/mfishman/Dropbox (Simons Foundation)/workdir/ImplictDifferentiation.jl/fixed_point_optim.jl:40

Am I doing something wrong or hitting a bug?

Let me take a look and get back to you!

Okay, so I think I know what went wrong.
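Your optimality condition F(A,y)=0 should actually return a vector with the same dimension n as y. In your implementation, the optimality condition is a single number, which does not carry sufficient information: the implicit function theorem solves a linear system involving ∂F/∂y, which must be a square, invertible n×n matrix for the derivatives of all n components of y to be determined.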
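
A quick way to see the dimension problem, sketched with only LinearAlgebra: the gradient of the scalar variance condition with respect to the eigenvector is a single row, so it can never be the invertible n×n matrix ∂F/∂y. The hand-written derivative `dvariance_dx` is a hypothetical helper that assumes A is real symmetric; here A is a random symmetric positive definite matrix.

```julia
using LinearAlgebra
using Random

Random.seed!(1234)

# Scalar condition from the question: F(A, x) = <Ax, Ax> - <x, Ax>^2
variance(A, x) = (A * x)' * (A * x) - (x' * A * x)^2

# Its gradient with respect to x, written by hand for real symmetric A:
# ∂F/∂x = 2 A² x - 4 <x, Ax> A x, returned as a 1×n row.
dvariance_dx(A, x) = (2 * (A * (A * x)) - 4 * (x' * A * x) * (A * x))'

n = 3
B = randn(n, n)
A = Symmetric(B' * B + I)          # random symmetric positive definite matrix
x = eigen(A).vectors[:, 1]         # unit eigenvector of the smallest eigenvalue

Dy = dvariance_dx(A, x)
size(Dy)              # (1, 3): one row, not the n×n matrix the theorem needs to invert
rank(Matrix(Dy))      # 1
```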

Intuitively, we want to use Ay=λy (n equalities) as the condition, but we don’t know λ. Luckily, there is an easy fix: return λ alongside y in the solver. This yields an output of dimension n+1, therefore we need an additional equality to compensate: we can use <y,y>=1.
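
To see why the stacked system works, here is a minimal sketch, assuming a real symmetric A with a simple smallest eigenvalue, that applies the implicit function theorem by hand to z = (λ, y). The partial derivative ∂F/∂z = [-y  A - λI; 0  2yᵀ] is now square, of size (n+1)×(n+1), and perturbing A along a direction dA contributes ∂F/∂A · dA = [dA·y; 0]. The function names are hypothetical, and `eigen` from LinearAlgebra stands in for the Optim-based solver.

```julia
using LinearAlgebra
using Random

Random.seed!(1234)

# Conditions on the stacked unknown z = (λ, y): F(A, z) = [A*y - λ*y; y'y - 1]
function conditions(A, z)
    λ, y = z[1], z[2:end]
    return vcat(A * y - λ * y, y' * y - 1)
end

# Hand-written ∂F/∂z = [ -y  A - λI ; 0  2yᵀ ], a square (n+1)×(n+1) matrix
function jacobian_z(A, z)
    λ, y = z[1], z[2:end]
    top = hcat(-y, Matrix(A) - λ * I)
    bottom = reshape(vcat(0.0, 2 .* y), 1, :)
    return vcat(top, bottom)
end

# Implicit function theorem: dz = -(∂F/∂z) \ (∂F/∂A · dA), with ∂F/∂A · dA = [dA*y; 0]
function implicit_tangent(A, z, dA)
    y = z[2:end]
    return -(jacobian_z(A, z) \ vcat(dA * y, 0.0))
end

n = 4
B = randn(n, n)
A = Symmetric(B' * B)
E = eigen(A)
z = vcat(E.values[1], E.vectors[:, 1])   # smallest eigenpair, standing in for the solver output
norm(conditions(A, z))                   # ≈ 0: z satisfies the conditions

C = randn(n, n)
dA = Symmetric(C + C')                   # symmetric perturbation direction
dz = implicit_tangent(A, z, dA)
y = z[2:end]
dz[1] ≈ y' * dA * y                      # first-order perturbation theory: dλ = yᵀ dA y
```

As a sanity check, the eigenvalue component reproduces the classical first-order formula dλ = yᵀ dA y, and the last row of the system forces yᵀ dy = 0, so the normalization is preserved to first order.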

The following code seems to work; at least, it doesn’t throw an error.

using ImplicitDifferentiation
using LinearAlgebra
using Optim
using Random
using Zygote

Random.seed!(1234)

# Minimal eigenvector of A:
# Minimize the Rayleigh quotient <y,Ay> / <y,y>
function minimum_eigenelements(A::AbstractMatrix)
    f(y) = (y' * A * y) ./ (y'y)
    g!(G, y) = G .= gradient(f, y)[1]
    y0 = randn(eltype(A), size(A, 1))
    res = optimize(f, g!, y0, LBFGS())
    λ = Optim.minimum(res)
    y = Optim.minimizer(res)
    y ./= norm(y)
    λ_and_y = vcat(λ, y)
    return λ_and_y
end

# Optimality condition:
# Ay = λy & ||y|| = 1
function optimality_conditions(A::AbstractMatrix, λ_and_y::AbstractVector)
    λ, y = λ_and_y[1], λ_and_y[2:end]
    return vcat(A * y - λ * y, y'y - 1)
end

n = 2
A = Symmetric(randn(n, n))
λ_and_y = minimum_eigenelements(A)
optimality_conditions(A, λ_and_y)

implicit = ImplicitFunction(minimum_eigenelements, optimality_conditions)
J = Zygote.jacobian(implicit, A)[1]

I removed the comparison with the Jacobian from FiniteDifferences.jl because symmetric matrices form a lower-dimensional subspace of M_n(ℝ). Hence, taking small steps in all directions will throw you out of that subspace, where the Rayleigh quotient minimization no longer computes an eigenvector of the perturbed matrix. To perform a proper comparison, one would need to use only the coefficients of the upper or lower triangle as arguments.
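
The upper-triangle comparison can be sketched without FiniteDifferences.jl: parametrize the matrix by its n(n+1)/2 upper-triangle coefficients, so every finite-difference step stays symmetric. To keep the check unambiguous, this hypothetical sketch differentiates only the smallest eigenvalue (eigenvectors are only defined up to sign) and compares with the known formula dλ = yᵀ dA y. `sym_from_upper` and `fd_gradient` are made-up helper names.

```julia
using LinearAlgebra
using Random

Random.seed!(1234)

# Hypothetical helper: build a symmetric n×n matrix from its upper-triangle coefficients
function sym_from_upper(θ::AbstractVector, n::Int)
    M = zeros(eltype(θ), n, n)
    k = 0
    for j in 1:n, i in 1:j
        k += 1
        M[i, j] = θ[k]
    end
    return Symmetric(M, :U)
end

# Smallest eigenvalue as a function of the n(n+1)/2 free coefficients
λmin(θ, n) = eigvals(sym_from_upper(θ, n))[1]

# Central finite differences, one free coefficient at a time, so every step stays symmetric
function fd_gradient(f, θ; h=1e-6)
    g = similar(θ)
    for k in eachindex(θ)
        e = zeros(length(θ))
        e[k] = h
        g[k] = (f(θ + e) - f(θ - e)) / (2h)
    end
    return g
end

n = 3
θ = randn(n * (n + 1) ÷ 2)
g_fd = fd_gradient(t -> λmin(t, n), θ)

# Reference from dλ = yᵀ dA y: an off-diagonal coefficient moves both A[i, j] and A[j, i]
y = eigen(sym_from_upper(θ, n)).vectors[:, 1]
g_ref = [i == j ? y[i]^2 : 2 * y[i] * y[j] for j in 1:n for i in 1:j]
maximum(abs.(g_fd - g_ref))   # small: finite-difference error only
```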

This is actually a pretty great application example, would you mind if I included it in the docs, with your name of course?

For this particular problem, you can also use the LOBPCG algorithm in the forward function (see the LOBPCG page of the IterativeSolvers.jl documentation).

That’s true, thanks. Partly, I was trying to follow the example in the ImplicitDifferentiation.jl documentation as closely as possible. In general I will be using a more optimized eigensolver.

Great, thank you! Good to know that you need the optimality condition and the solver to return vectors of the same size. Please feel free to use it as an example, I was going to suggest that once I got it working.

Indeed, my first thought was to use Ay=λy as the optimality condition, but I thought using the variance would be easier, since I wouldn’t have to pass the eigenvalue around.

Here is a version using an eigensolver (from KrylovKit):

using ComponentArrays
using ImplicitDifferentiation
using LinearAlgebra
using KrylovKit
using Random
using Zygote

Random.seed!(1234)

# Minimal eigenvalue and eigenvector of A
function minimum_eigenelements(A::AbstractMatrix)
  y0 = randn(eltype(A), size(A, 1))
  vals, vecs = eigsolve(A, y0, 1, :SR; eager=true)
  λ = vals[1]
  y = vecs[1]
  y ./= norm(y)
  return ComponentVector(; λ, y)
end

# Optimality condition:
# Ay = λy & ||y|| = 1
function optimality_conditions(A::AbstractMatrix, λ_and_y::ComponentVector)
  λ, y = λ_and_y.λ, λ_and_y.y
  return [y'y - 1; A * y - λ * y]
end

n = 2
A = randn(n, n) |> A -> Symmetric(A'A)
λ_and_y = minimum_eigenelements(A)
optimality_conditions(A, λ_and_y)

implicit = ImplicitFunction(minimum_eigenelements, optimality_conditions)
J = Zygote.jacobian(implicit, A)[1]

I just updated the documentation in order to clarify this point.