Taking gradients of a matrix exponential

I’m having trouble differentiating simple matrix exponentials. `ForwardDiff` and `ReverseDiff` can’t really do it, while Zygote’s result is type-unstable.

```julia
using Zygote, ForwardDiff, ReverseDiff

M = rand(2,2)

@code_warntype Zygote.gradient(x -> sum(exp(x)), M)
```

``````
MethodInstance for Zygote.gradient(::var"#13#14", ::Matrix{Float64})
from gradient(f, args...) @ Zygote C:\Users\55619\.julia\packages\Zygote\YYT6v\src\compiler\interface.jl:95
Arguments
f::Core.Const(var"#13#14"())
args::Tuple{Matrix{Float64}}
Locals
@_4::Int64
back::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#13#14", Matrix{Float64}}, var"#s178"}} where var"#s178"<:Tuple{Union{Zygote.ZBack{ChainRules.var"#exp_pullback#2042"{Tuple{Int64, Int64, Vector{Float64}, Vector{Float64}, Int64, Vector{Matrix{Float64}}, Matrix{Float64}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Vector{Matrix{Float64}}}, Matrix{Float64}, Matrix{Float64}}}, Zygote.ZBack{ChainRules.var"#exp_pullback_hermitian#2041"{Tuple{Vector{Float64}, Matrix{Float64}, Vector{Float64}, Vector{Float64}}, LinearAlgebra.Symmetric{Float64, Matrix{Float64}}, LinearAlgebra.Hermitian{Float64, Matrix{Float64}}}}}, Zygote.var"#2991#back#766"{Zygote.var"#760#764"{Matrix{Float64}}}}
y::Float64
Body::Union{Nothing, Tuple}
1 ─ %1  = Zygote.pullback::Core.Const(ZygoteRules.pullback)
│   %2  = Core.tuple(f)::Core.Const((var"#13#14"(),))
│   %3  = Core._apply_iterate(Base.iterate, %1, %2, args)::Tuple{Float64, Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#13#14", Matrix{Float64}}, var"#s178"}} where var"#s178"<:Tuple{Union{Zygote.ZBack{ChainRules.var"#exp_pullback#2042"{Tuple{Int64, Int64, Vector{Float64}, Vector{Float64}, Int64, Vector{Matrix{Float64}}, Matrix{Float64}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Vector{Matrix{Float64}}}, Matrix{Float64}, Matrix{Float64}}}, Zygote.ZBack{ChainRules.var"#exp_pullback_hermitian#2041"{Tuple{Vector{Float64}, Matrix{Float64}, Vector{Float64}, Vector{Float64}}, LinearAlgebra.Symmetric{Float64, Matrix{Float64}}, LinearAlgebra.Hermitian{Float64, Matrix{Float64}}}}}, Zygote.var"#2991#back#766"{Zygote.var"#760#764"{Matrix{Float64}}}}}
│   %4  = Base.indexed_iterate(%3, 1)::Core.PartialStruct(Tuple{Float64, Int64}, Any[Float64, Core.Const(2)])
│         (y = Core.getfield(%4, 1))
│         (@_4 = Core.getfield(%4, 2))
│   %7  = Base.indexed_iterate(%3, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#13#14", Matrix{Float64}}, var"#s178"}} where var"#s178"<:Tuple{Union{Zygote.ZBack{ChainRules.var"#exp_pullback#2042"{Tuple{Int64, Int64, Vector{Float64}, Vector{Float64}, Int64, Vector{Matrix{Float64}}, Matrix{Float64}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Vector{Matrix{Float64}}}, Matrix{Float64}, Matrix{Float64}}}, Zygote.ZBack{ChainRules.var"#exp_pullback_hermitian#2041"{Tuple{Vector{Float64}, Matrix{Float64}, Vector{Float64}, Vector{Float64}}, LinearAlgebra.Symmetric{Float64, Matrix{Float64}}, LinearAlgebra.Hermitian{Float64, Matrix{Float64}}}}}, Zygote.var"#2991#back#766"{Zygote.var"#760#764"{Matrix{Float64}}}}, Int64}, Any[Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#13#14", Matrix{Float64}}, var"#s178"}} where var"#s178"<:Tuple{Union{Zygote.ZBack{ChainRules.var"#exp_pullback#2042"{Tuple{Int64, Int64, Vector{Float64}, Vector{Float64}, Int64, Vector{Matrix{Float64}}, Matrix{Float64}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Vector{Matrix{Float64}}}, Matrix{Float64}, Matrix{Float64}}}, Zygote.ZBack{ChainRules.var"#exp_pullback_hermitian#2041"{Tuple{Vector{Float64}, Matrix{Float64}, Vector{Float64}, Vector{Float64}}, LinearAlgebra.Symmetric{Float64, Matrix{Float64}}, LinearAlgebra.Hermitian{Float64, Matrix{Float64}}}}}, Zygote.var"#2991#back#766"{Zygote.var"#760#764"{Matrix{Float64}}}}, Core.Const(3)])
│         (back = Core.getfield(%7, 1))
│   %9  = Zygote.sensitivity(y)::Core.Const(1.0)
└──       goto #3 if not %11
2 ─       return Zygote.nothing
3 ─ %14 = Zygote.map(Zygote._project, args, grad::Tuple)::Union{Tuple{}, Tuple{Any}}
└──       return %14
``````

Can anyone share any advice? If possible, I would like to avoid doing this with Zygote. I suppose using a linear ODE solver and differentiating it with `SciMLSensitivity` could work, but there must be a simpler solution to this, right?

The adjoint for the matrix exponential is fairly straightforward, and indeed it is implemented in ChainRules.jl.

ReverseDiff and ForwardDiff, however, won’t use ChainRules; they both need to be told how to differentiate that function. I didn’t check, but I’m guessing there is some non-Julia code under the hood that prevents them from handling it out of the box. ReverseDiff.jl can be told to use ChainRules, though. Maybe that’s the best option.
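For ReverseDiff specifically, the hook is the `ReverseDiff.@grad_from_chainrules` macro. Something along these lines should do it (untested sketch; it assumes the `exp` rule from ChainRules.jl is loaded):

```julia
using ReverseDiff, ChainRules, LinearAlgebra
using ReverseDiff: TrackedArray

# Reuse the ChainRules rrule for the matrix exponential on tracked arrays,
# instead of letting ReverseDiff try to trace through exp itself.
ReverseDiff.@grad_from_chainrules Base.exp(x::TrackedArray)

M = rand(2, 2)
ReverseDiff.gradient(x -> sum(exp(x)), M)
```

After that, an `exp` call on a tracked matrix should dispatch through the ChainRules `rrule` rather than being traced element-by-element.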

I’m more concerned about the type instability when you use Zygote. I can reproduce this. Maybe somebody should comment whether this is a Zygote or a ChainRules issue. I’m guessing it is Zygote, since calling `rrule` directly is type stable:

```julia
using ChainRules, LinearAlgebra
A = randn(3,3)
expA, pb = ChainRules.rrule(exp, A)
@code_warntype pb(A)
```

P.S.: I’m assuming your example is just a MWE. Otherwise I would say just implement it yourself, possibly using `ChainRules.rrule`.


I did write this blog post comparing a few different matrix exponential implementations using `ForwardDiff.jl`, but odds are that the ChainRules approach is better:


I see, thank you very much. I gave this a shot, but still couldn’t get it to work:

```julia
using ReverseDiff
M = rand(2,2)

ReverseDiff.gradient(x -> sum(exp(x)), M) # MethodError: no method matching iterate(::Nothing)
```

Or did I misunderstand how the macro is supposed to be used?

I would guess it’s Zygote: I’ve had plenty of similar issues while building my application, which is why I refactored everything to use ReverseDiff (and would prefer not to use Zygote).

Yeah, it’s supposed to be part of a custom layer in a Neural Network.

First off: great blog post, learned a lot from it.
Still, given that my actual application will be in the context of neural networks, I think reverse-mode differentiation would be a better fit. Thank you, in any case.

My guess is that there is no `exp` primitive defined for matrix arguments. If you instead use a hand-written version of the exponential function (which approximates the actual exponential quite well)

```julia
function exp2(m)
    # Truncated Taylor series: sum of m^i / i! for i = 0:10
    res = zeros(size(m))
    for i in 0:10
        res += m^i / factorial(i)
    end
    res
end
```

this works when doing the reverse pass

```julia
ReverseDiff.gradient(x -> sum(exp2(x)), M)
```

You can define custom pullback rules with ChainRulesCore and add them to your set of rules. Something like this should do the job, although I didn’t manage to make it work:

```julia
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(exp), M::AbstractMatrix)
    # Caveat: exp(M) * ΔM matches the true derivative only when M and ΔM
    # commute; the general pullback is the adjoint of the Fréchet
    # derivative of exp, which ChainRules.jl implements properly.
    exp_pullback(ΔM) = (NoTangent(), exp(M) * ΔM)
    return exp(M), exp_pullback
end
```

Hope this helps. Please post once you find a solution to this! I would like to know how to make this work, too.

It actually is defined, see the posts above.

This Taylor-series algorithm is discussed in section 3 (“method 1”) of the classic paper Nineteen dubious ways to compute the exponential of a matrix (1978). Basically, the naive series is quite unreliable, even if you sum enough terms, because it is susceptible to catastrophic cancellation.

(Contrary to popular misconception from first-year calculus, Taylor series are not typically how special functions are computed.)
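To make the cancellation concrete, here is a quick illustrative sketch (my own toy example, not from the thread) comparing a truncated Taylor sum against the built-in `exp` on a matrix with large-magnitude entries:

```julia
using LinearAlgebra

# Naive truncated Taylor series for exp(M); fine for small norm(M),
# but suffers catastrophic cancellation when norm(M) is large.
function exp_taylor(M, nterms = 30)
    term = Matrix{float(eltype(M))}(I, size(M)...)  # M^0 / 0! = I
    acc = copy(term)
    for k in 1:nterms
        term = term * M / k   # running term M^k / k!
        acc += term
    end
    acc
end

A = [-20.0 1.0; 0.0 -20.0]
exp_taylor(A)  # badly contaminated by cancellation
exp(A)         # scaling-and-squaring handles this fine
```

The intermediate terms peak around 20^20/20! ≈ 4e7, while the true answer has entries of order e^(-20) ≈ 2e-9, so essentially all significant digits cancel.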

It’s really much safer to use the built-in `exp` function here, which means that you need to teach your AD system to use a custom rule (e.g. the one from ChainRules.jl).


The following works fine for me:

```julia
using ReverseDiff, LinearAlgebra, ChainRules, Zygote
A = randn(3,3)
F = x -> sum(exp(x))
G1 = ReverseDiff.gradient(F, A)
G2 = Zygote.gradient(F, A)[1]
@show G1 ≈ G2
```

P.S.: @Bizzi – I’m not sure why your script didn’t work; probably some combination of imports/`using` statements is required to run this. This is where I sometimes get confused and just try things until it runs.

P.P.S.: neither Zygote nor ReverseDiff has great performance here…

```julia
using BenchmarkTools
```

You’re right, importing `ChainRules` was also necessary. Kind of a misleading error message, but anyway.

Related issue #51008. `exp` just doesn’t work with many input types: it only works for matrices whose `eltype` promotes to `LinearAlgebra.BlasFloat`, which does not include things like dual numbers. Fixing it doesn’t require any significant algorithmic adjustments, just that someone go through and add a code path for more generic types.
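Until that lands, one way to get directional derivatives of `exp` without dual numbers (a sketch of the standard block-triangular trick, not something from this thread) is the identity exp([A E; 0 A]) = [exp(A) L; 0 exp(A)], where L is the Fréchet derivative of exp at A in direction E:

```julia
using LinearAlgebra

# Fréchet derivative of exp at A in direction E, read off the top-right
# block of exp([A E; 0 A]); uses only plain Float64 matrices.
function frechet_exp(A, E)
    n = size(A, 1)
    B = [A E; zero(A) A]
    exp(B)[1:n, n+1:2n]
end

A = randn(3, 3)
E = randn(3, 3)
frechet_exp(A, E)  # ≈ d/dt exp(A + t*E) at t = 0
```

This stays entirely within the `BlasFloat` code path, so it works today at the cost of exponentiating a 2n×2n matrix per direction.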