Taking gradients of a matrix exponential

I’m having trouble differentiating a simple matrix exponential. ForwardDiff and ReverseDiff both throw a MethodError, while Zygote works but is type-unstable.

using Zygote, ForwardDiff, ReverseDiff

M = rand(2, 2)
ReverseDiff.gradient(x -> sum(exp(x)), M)  # MethodError
ForwardDiff.gradient(x -> sum(exp(x)), M)  # MethodError

Zygote.gradient(x -> sum(exp(x)), M)  # works, but:
@code_warntype Zygote.gradient(x -> sum(exp(x)), M)  # type-unstable
MethodInstance for Zygote.gradient(::var"#13#14", ::Matrix{Float64})
  from gradient(f, args...) @ Zygote C:\Users\55619\.julia\packages\Zygote\YYT6v\src\compiler\interface.jl:95
Arguments
  #self#::Core.Const(Zygote.gradient)
  f::Core.Const(var"#13#14"())
  args::Tuple{Matrix{Float64}}
Locals
  @_4::Int64
  grad::Union{Nothing, Tuple}
  back::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#13#14", Matrix{Float64}}, var"#s178"}} where var"#s178"<:Tuple{Union{Zygote.ZBack{ChainRules.var"#exp_pullback#2042"{…}}, Zygote.ZBack{ChainRules.var"#exp_pullback_hermitian#2041"{…}}}, Zygote.var"#2991#back#766"{…}}
  y::Float64
Body::Union{Nothing, Tuple}
1 ─ %1  = Zygote.pullback::Core.Const(ZygoteRules.pullback)
│   %2  = Core.tuple(f)::Core.Const((var"#13#14"(),))
│   %3  = Core._apply_iterate(Base.iterate, %1, %2, args)::Tuple{Float64, Zygote.var"#75#76"{…} where var"#s178"}
│   %4  = Base.indexed_iterate(%3, 1)::Core.PartialStruct(Tuple{Float64, Int64}, Any[Float64, Core.Const(2)])
│         (y = Core.getfield(%4, 1))
│         (@_4 = Core.getfield(%4, 2))
│   %7  = Base.indexed_iterate(%3, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Zygote.var"#75#76"{…} where var"#s178", Int64}, Any[…])
│         (back = Core.getfield(%7, 1))
│   %9  = Zygote.sensitivity(y)::Core.Const(1.0)
│         (grad = (back)(%9))
│   %11 = Zygote.isnothing(grad)::Bool
└──       goto #3 if not %11
2 ─       return Zygote.nothing
3 ─ %14 = Zygote.map(Zygote._project, args, grad::Tuple)::Union{Tuple{}, Tuple{Any}}
└──       return %14

Can anyone share any advice? If possible, I would like to avoid Zygote. I suppose solving a linear ODE and differentiating that with SciMLSensitivity could work (rough sketch below), but there must be a simpler solution to this, right?
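
For reference, here’s roughly what that detour would look like (untested sketch; assumes OrdinaryDiffEq is installed, and uses the fact that exp(M) is the time-1 solution of U′ = M·U with U(0) = I):

using OrdinaryDiffEq, LinearAlgebra

M = rand(2, 2)
# exp(M) is U(1), where U solves the linear matrix ODE U' = M*U, U(0) = I
prob = ODEProblem((u, p, t) -> p * u, Matrix{Float64}(I, 2, 2), (0.0, 1.0), M)
sol = solve(prob, Tsit5(); reltol = 1e-10, abstol = 1e-10)
sol.u[end] ≈ exp(M)  # true, up to solver tolerance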

The adjoint for the matrix exponential is fairly straightforward, and indeed it is implemented in ChainRules.jl.

ReverseDiff and ForwardDiff, however, won’t use ChainRules; they both need to be told how to differentiate that function. I didn’t check this, but I’m guessing there is some non-Julia code in the backend that prevents them from handling it out of the box. ReverseDiff.jl can be told to use ChainRules, though. Maybe that’s the best option.

I’m more concerned about the type instability when you use Zygote. I can reproduce it. Maybe somebody can comment on whether this is a Zygote or a ChainRules issue. My guess is that it is Zygote, since calling rrule directly is type-stable:

using ChainRules, LinearAlgebra
A = randn(3, 3)
expA, pb = ChainRules.rrule(exp, A)  # primal value and pullback
@code_warntype pb(A)  # A is just a seed of the right shape; the output type is concrete
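
For what it’s worth, seeding the pullback with a matrix of ones (the cotangent of sum) reproduces Zygote’s gradient, so the rule itself seems fine:

using Zygote
_, ΔA = pb(ones(3, 3))  # pb returns (NoTangent(), ΔA)
ΔA ≈ Zygote.gradient(x -> sum(exp(x)), A)[1]  # true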

P.S.: I’m assuming your example is just a MWE. Otherwise I would say just implement it yourself, possibly using ChainRules.rrule.

1 Like

I did write this blog post comparing a few different matrix exponential implementations using ForwardDiff.jl, but odds are that the ChainRules approach is better:
https://spmd.org/posts/multithreadedallocations/

2 Likes

I see, thank you very much. I gave this a shot, but still couldn’t get it to work:

using ReverseDiff
M = rand(2, 2)

ReverseDiff.@grad_from_chainrules exp(x::ReverseDiff.TrackedArray)
ReverseDiff.@grad_from_chainrules exp(x::ReverseDiff.TrackedReal)
ReverseDiff.gradient(x -> sum(exp(x)), M)  # MethodError: no method matching iterate(::Nothing)

Or did I misunderstand how the macro is supposed to be used?

I would guess it’s Zygote: I’ve had plenty of similar issues while building my application, which is why I refactored everything to use ReverseDiff (and would prefer not to use Zygote).

Yeah, it’s supposed to be part of a custom layer in a neural network.

First off: great blog post, I learned a lot from it.
Still, given that my actual application is in the context of neural networks, I think reverse-mode differentiation is a better fit. Thank you in any case.

My guess is that the exp primitive is simply not defined for matrix arguments? If you instead use a hand-written version of the exponential function (a truncated Taylor series, which approximates the actual exponential quite well for small matrices):

function exp2(m)  # NB: shadows Base.exp2 (which is 2^x)
    # truncated Taylor series: sum of m^i / i! for i = 0:10
    res = zeros(size(m))
    for i in 0:10
        res += m^i / factorial(i)
    end
    res
end

then the reverse pass works:

ReverseDiff.gradient(x -> sum(exp2(x)), M)
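
For this small, well-scaled M, the truncated series is accurate enough that its gradient matches Zygote’s gradient of the built-in exp up to the truncation error (rough check; tolerance deliberately loose):

using Zygote
g1 = ReverseDiff.gradient(x -> sum(exp2(x)), M)
g2 = Zygote.gradient(x -> sum(exp(x)), M)[1]
isapprox(g1, g2; rtol = 1e-3)  # true: the difference is the series truncation error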

You can also write custom pullback rules with ChainRulesCore and add them to your set of rules. Something like this should do the job, although I didn’t manage to make it work:

using ChainRulesCore

function ChainRulesCore.rrule(::typeof(exp), M::AbstractMatrix)
    # NB: exp(M) * ΔM is only the right adjoint when things commute (e.g. scalars);
    # the general pullback is the adjoint of the Fréchet derivative of exp.
    exp_pullback(ΔM) = (NoTangent(), exp(M) * ΔM)
    return exp(M), exp_pullback
end
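
One likely reason it doesn’t work: exp(M) * ΔM is only the right pullback when everything commutes (the scalar case). For reference, here is a sketch of a mathematically general rule via the block-matrix identity for the Fréchet derivative of exp; myexp is a hypothetical wrapper so we don’t overwrite the rule ChainRules.jl already ships:

using ChainRulesCore, LinearAlgebra

myexp(A) = exp(A)  # hypothetical wrapper, to avoid piracy on the Base.exp rule

function ChainRulesCore.rrule(::typeof(myexp), A::AbstractMatrix{<:Real})
    n = size(A, 1)
    function myexp_pullback(ΔY)
        # Pullback = adjoint of the Fréchet derivative of exp at A, applied to ΔY.
        # exp([X E; 0 X]) = [exp(X) L(X,E); 0 exp(X)], and for a real power series
        # the adjoint of E ↦ L(A,E) is W ↦ L(A',W) (Higham, “Functions of Matrices”).
        block = exp([A' Matrix(ΔY); zeros(n, n) A'])
        return NoTangent(), block[1:n, n+1:2n]
    end
    return exp(A), myexp_pullback
end

(This costs one exp of a 2n×2n matrix per pullback, so it’s for reference rather than speed.)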

Hope this helps. Please post once you find a solution to this! I would like to know how to make this work too :wink:

It actually is defined, see the posts above.

This Taylor-series algorithm is discussed in section 3 (“Method 1”) of the classic paper Nineteen dubious ways to compute the exponential of a matrix (Moler & Van Loan, 1978). Basically, the naive series is quite unreliable even if you sum enough terms, because it is susceptible to catastrophic cancellation: intermediate terms can be enormous compared to the final answer, so their rounding errors swamp it.

(Contrary to popular misconception from first-year calculus, Taylor series are not typically how special functions are computed.)
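
To make the cancellation concrete, here is a quick illustration (my own sketch, not from the paper): for a matrix with a large negative spectrum, the intermediate terms of the series are enormous compared to the answer, and their rounding errors completely swamp it:

using LinearAlgebra

# Sum the Taylor series, building up M^i / i! iteratively to avoid factorial overflow
function taylor_exp(M, n)
    term = Matrix{Float64}(I, size(M)...)  # M^0 / 0!
    acc = copy(term)
    for i in 1:n
        term = term * M / i  # M^i / i!
        acc += term
    end
    return acc
end

M = Matrix(-30.0I, 2, 2)
taylor_exp(M, 200)[1, 1]  # garbage: off by many orders of magnitude
exp(M)[1, 1]              # ≈ exp(-30) ≈ 9.36e-14, correct

The largest term here is about 30^30/30! ≈ 8e11, so even double-precision rounding errors of relative size 1e-16 leave an absolute error around 1e-4, roughly nine orders of magnitude larger than the answer.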

It’s really much safer to use the built-in exp function here, which means that you need to teach your AD system to use a custom rule (e.g. the one from ChainRules.jl).

2 Likes

The following works fine for me:

using ReverseDiff, LinearAlgebra, ChainRules, Zygote
ReverseDiff.@grad_from_chainrules LinearAlgebra.exp(x::ReverseDiff.TrackedArray)
A = randn(3, 3)
F = x -> sum(exp(x))
G1 = ReverseDiff.gradient(F, A)
G2 = Zygote.gradient(F, A)[1]
@show G1 ≈ G2

P.S.: @Bizzi – I’m not sure why your script didn’t work. Probably some combination of import/using statements is required to make it run. This is where I sometimes get confused and just try things until it runs :frowning:

P.P.S.: neither Zygote nor ReverseDiff has great performance here…

using BenchmarkTools
@btime ReverseDiff.gradient($F, $A) 
#   5.611 μs (81 allocations: 8.34 KiB)
@btime Zygote.gradient($F, $A)
#  3.970 μs (74 allocations: 8.03 KiB)
2 Likes

You’re right, importing ChainRules was also necessary :melting_face:. Kind of a misleading error message, but anyway.
Thank you all very much!

1 Like

Related issue: #51008. exp just doesn’t work for many input types (it only has methods for matrices whose eltypes promote to LinearAlgebra.BlasFloat, which does not include things like dual numbers). Fixing it doesn’t require any significant algorithmic changes, just for someone to go through and add a code path for more generic types.
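
For illustration, this is the failure mode that bites ForwardDiff (sketch; the dual numbers stand in for what ForwardDiff.gradient passes to the target function):

using ForwardDiff, LinearAlgebra

M = rand(2, 2)
D = ForwardDiff.Dual.(M, 1.0)  # a Matrix of dual numbers, as ForwardDiff would build
exp(D)  # MethodError: the dense exp methods require a BlasFloat-like eltype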