I am trying to use ReverseDiff.jl and ForwardDiff.jl together as a mixed-mode AD approach for computing a Hessian-vector product. I am using the following function:
import ForwardDiff as fdiff
import ReverseDiff as rdiff
function _hvp(f::F, x::S, v::S) where {F, S<:AbstractVector{<:AbstractFloat}}
    dual = fdiff.Dual.(x, v)  # seed each x[i] with the direction v[i]
    return fdiff.partials.(rdiff.gradient(f, dual), 1)  # extract H*v from the dual-valued gradient
end
but this is not giving me the expected output. See below for a simple example.
A = randn(2,2)
f(x) = x'*A*x
x = randn(2)
v = randn(2)
_hvp(f,x,v) ≈ (A+A')*v # returns false
What is the correct way to compose these two packages?
The code you posted works for me (assuming Dual, gradient, and partials all come from ForwardDiff.jl):
julia> using ForwardDiff: Dual, gradient, partials
julia> function _hvp(f::F, x::S, v::S) where {F, S<:AbstractVector{<:AbstractFloat}}
           dual = Dual.(x, v)
           return partials.(gradient(f, dual), 1)
end
_hvp (generic function with 1 method)
julia> A = randn(2, 2); x = randn(2); v = randn(2);
julia> f(x) = x' * A * x
f (generic function with 1 method)
julia> _hvp(f, x, v) ≈ (A + A') * v
true
It's not clear from your post where the interplay between ForwardDiff.jl and ReverseDiff.jl comes in. Could you elaborate?
Thanks. In my reading of that thread, it is only related in that both ForwardDiff.jl and ReverseDiff.jl are being used; there is nothing dealing with composing them. I think there is perhaps an ambiguity in the term "mixed mode". In the referenced thread, the poster just wants to obtain the gradient by combining two different modes, whereas I want higher-order derivatives using mixed-mode AD. If I am missing something, feel free to point it out.
Ah, yes, apologies for the ambiguity. I updated the question to clarify: I want to use ReverseDiff.gradient. The goal is a forward-over-reverse approach, where ReverseDiff.jl handles the backward (reverse) mode. If I were to use Zygote.gradient instead, the posted code would work.
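For concreteness, here is a minimal sketch of that Zygote variant (the name _hvp_zygote is just for illustration): the same forward-over-reverse composition, but with Zygote.gradient supplying the reverse pass.
import ForwardDiff as fdiff
import Zygote

function _hvp_zygote(f::F, x::S, v::S) where {F, S<:AbstractVector{<:AbstractFloat}}
    dual = fdiff.Dual.(x, v)              # seed x with the direction v
    g = Zygote.gradient(f, dual)[1]       # reverse-mode gradient at the dual-valued input
    return fdiff.partials.(g, 1)          # directional derivative of the gradient, i.e. H*v
end

A = randn(2, 2); x = randn(2); v = randn(2)
f(x) = x' * A * x
_hvp_zygote(f, x, v) ≈ (A + A') * v  # expected: true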
Looks like a bug in ReverseDiff, even for the first derivative of this f:
julia> ForwardDiff.gradient(f, x)
2-element Vector{Float64}:
0.7876859520316463
1.5891432962932512
julia> Zygote.gradient(f, x)
([0.7876859520316463, 1.589143296293251],)
julia> ReverseDiff.gradient(f, x) # wrong
2-element Vector{Float64}:
1.1826512915394356
1.154013797470248
julia> ReverseDiff.gradient(v -> v' * A' * v, x)
2-element Vector{Float64}:
0.7876859520316463
1.589143296293251
julia> ReverseDiff.gradient(x -> dot(x, A, x), x)
2-element Vector{Float64}:
0.7876859520316463
1.5891432962932512
julia> @which x'*A*x
*(tu::Union{Adjoint{T, var"#s967"}, Transpose{T, var"#s967"}} where {T, var"#s967"<:(AbstractVector)}, B::AbstractMatrix, v::AbstractVector)
@ LinearAlgebra ~/.julia/dev/julia/usr/share/julia/stdlib/v1.9/LinearAlgebra/src/matmul.jl:1152
I see this on Julia 1.7 and later, but not on 1.6. Thus I presume it's related to PR 37898 for 3-arg *, but I don't see why.
I am on 1.6.4 and seeing that same issue.
There is an old issue on ReverseDiff.jl that may be related, but I am uncertain.
I also see it on the official Linux builds for v1.7.1. But I get the correct answers for f2(x) = dot(x, A*x) and f3(x) = dot(x, A, x). Spooky.
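If the bug really is specific to the x'*A*x method, one possible workaround is to express the quadratic form with dot and reuse the original forward-over-ReverseDiff _hvp. This is only a sketch; beyond the first-order checks above, the full HVP composition is untested here.
using LinearAlgebra: dot
import ForwardDiff as fdiff
import ReverseDiff as rdiff

function _hvp(f::F, x::S, v::S) where {F, S<:AbstractVector{<:AbstractFloat}}
    dual = fdiff.Dual.(x, v)
    return fdiff.partials.(rdiff.gradient(f, dual), 1)
end

A = randn(2, 2); x = randn(2); v = randn(2)
f3(x) = dot(x, A, x)             # same quadratic form, avoids the x'*A*x code path
_hvp(f3, x, v) ≈ (A + A') * v    # should hold if the composition itself is sound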