I’m having trouble differentiating simple matrix exponentials. ForwardDiff and ReverseDiff both throw a MethodError, while Zygote works but its result is type-unstable.
using Zygote, ForwardDiff, ReverseDiff
M = rand(2, 2)
ReverseDiff.gradient(x -> sum(exp(x)), M)  # MethodError
ForwardDiff.gradient(x -> sum(exp(x)), M)  # MethodError
Zygote.gradient(x -> sum(exp(x)), M)       # works, but:
@code_warntype Zygote.gradient(x -> sum(exp(x)), M)  # type-unstable
@code_warntype output:
MethodInstance for Zygote.gradient(::var"#13#14", ::Matrix{Float64})
from gradient(f, args...) @ Zygote C:\Users\55619\.julia\packages\Zygote\YYT6v\src\compiler\interface.jl:95
Arguments
#self#::Core.Const(Zygote.gradient)
f::Core.Const(var"#13#14"())
args::Tuple{Matrix{Float64}}
Locals
@_4::Int64
grad::Union{Nothing, Tuple}
back::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#13#14", Matrix{Float64}}, var"#s178"}} where var"#s178"<:Tuple{Union{Zygote.ZBack{ChainRules.var"#exp_pullback#2042"{Tuple{Int64, Int64, Vector{Float64}, Vector{Float64}, Int64, Vector{Matrix{Float64}}, Matrix{Float64}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Vector{Matrix{Float64}}}, Matrix{Float64}, Matrix{Float64}}}, Zygote.ZBack{ChainRules.var"#exp_pullback_hermitian#2041"{Tuple{Vector{Float64}, Matrix{Float64}, Vector{Float64}, Vector{Float64}}, LinearAlgebra.Symmetric{Float64, Matrix{Float64}}, LinearAlgebra.Hermitian{Float64, Matrix{Float64}}}}}, Zygote.var"#2991#back#766"{Zygote.var"#760#764"{Matrix{Float64}}}}
y::Float64
Body::Union{Nothing, Tuple}
1 ─ %1 = Zygote.pullback::Core.Const(ZygoteRules.pullback)
│ %2 = Core.tuple(f)::Core.Const((var"#13#14"(),))
│ %3 = Core._apply_iterate(Base.iterate, %1, %2, args)::Tuple{Float64, Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#13#14", Matrix{Float64}}, var"#s178"}} where var"#s178"<:Tuple{Union{Zygote.ZBack{ChainRules.var"#exp_pullback#2042"{Tuple{Int64, Int64, Vector{Float64}, Vector{Float64}, Int64, Vector{Matrix{Float64}}, Matrix{Float64}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Vector{Matrix{Float64}}}, Matrix{Float64}, Matrix{Float64}}}, Zygote.ZBack{ChainRules.var"#exp_pullback_hermitian#2041"{Tuple{Vector{Float64}, Matrix{Float64}, Vector{Float64}, Vector{Float64}}, LinearAlgebra.Symmetric{Float64, Matrix{Float64}}, LinearAlgebra.Hermitian{Float64, Matrix{Float64}}}}}, Zygote.var"#2991#back#766"{Zygote.var"#760#764"{Matrix{Float64}}}}}
│ %4 = Base.indexed_iterate(%3, 1)::Core.PartialStruct(Tuple{Float64, Int64}, Any[Float64, Core.Const(2)])
│ (y = Core.getfield(%4, 1))
│ (@_4 = Core.getfield(%4, 2))
│ %7 = Base.indexed_iterate(%3, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#13#14", Matrix{Float64}}, var"#s178"}} where var"#s178"<:Tuple{Union{Zygote.ZBack{ChainRules.var"#exp_pullback#2042"{Tuple{Int64, Int64, Vector{Float64}, Vector{Float64}, Int64, Vector{Matrix{Float64}}, Matrix{Float64}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Vector{Matrix{Float64}}}, Matrix{Float64}, Matrix{Float64}}}, Zygote.ZBack{ChainRules.var"#exp_pullback_hermitian#2041"{Tuple{Vector{Float64}, Matrix{Float64}, Vector{Float64}, Vector{Float64}}, LinearAlgebra.Symmetric{Float64, Matrix{Float64}}, LinearAlgebra.Hermitian{Float64, Matrix{Float64}}}}}, Zygote.var"#2991#back#766"{Zygote.var"#760#764"{Matrix{Float64}}}}, Int64}, Any[Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#13#14", Matrix{Float64}}, var"#s178"}} where var"#s178"<:Tuple{Union{Zygote.ZBack{ChainRules.var"#exp_pullback#2042"{Tuple{Int64, Int64, Vector{Float64}, Vector{Float64}, Int64, Vector{Matrix{Float64}}, Matrix{Float64}, LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}, Vector{Matrix{Float64}}}, Matrix{Float64}, Matrix{Float64}}}, Zygote.ZBack{ChainRules.var"#exp_pullback_hermitian#2041"{Tuple{Vector{Float64}, Matrix{Float64}, Vector{Float64}, Vector{Float64}}, LinearAlgebra.Symmetric{Float64, Matrix{Float64}}, LinearAlgebra.Hermitian{Float64, Matrix{Float64}}}}}, Zygote.var"#2991#back#766"{Zygote.var"#760#764"{Matrix{Float64}}}}, Core.Const(3)])
│ (back = Core.getfield(%7, 1))
│ %9 = Zygote.sensitivity(y)::Core.Const(1.0)
│ (grad = (back)(%9))
│ %11 = Zygote.isnothing(grad)::Bool
└── goto #3 if not %11
2 ─ return Zygote.nothing
3 ─ %14 = Zygote.map(Zygote._project, args, grad::Tuple)::Union{Tuple{}, Tuple{Any}}
└── return %14
Does anyone have advice? If possible, I would like to avoid Zygote for this. I suppose solving the linear ODE with an ODE solver and differentiating through it with SciMLSensitivity could work, but there must be a simpler solution, right?
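For context, the kind of fallback I have in mind is a matrix exponential written in plain generic Julia, so that ForwardDiff's dual numbers can pass straight through it. Below is a rough sketch using a truncated Taylor series with scaling and squaring. The name expm_taylor and the defaults for order and squarings are my own arbitrary choices, not anything from a package, and I have not tuned them for accuracy or speed on larger or badly scaled matrices.

```julia
using LinearAlgebra, ForwardDiff

# Hypothetical helper (not part of any package): generic matrix exponential
# via a truncated Taylor series plus scaling and squaring.
# `order` and `squarings` are arbitrary illustrative defaults.
function expm_taylor(A::AbstractMatrix; order::Int = 12, squarings::Int = 8)
    B = A / 2^squarings                        # scale so the series converges quickly
    term = Matrix{eltype(B)}(I, size(B)...)    # current Taylor term, starts at B^0 / 0!
    result = copy(term)                        # running sum of the series
    for k in 1:order
        term = term * B / k                    # B^k / k!
        result += term
    end
    for _ in 1:squarings
        result = result * result               # undo the scaling: exp(A) = exp(B)^(2^squarings)
    end
    return result
end

M = rand(2, 2)
g = ForwardDiff.gradient(x -> sum(expm_taylor(x)), M)
```

Because it only uses generic arithmetic, this avoids the missing exp method for dual-number matrices, at the cost of hand-rolling something that LinearAlgebra.exp already does more carefully for Float64 matrices.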