FiniteDiff, ForwardDiff and Enzyme disagree: who is right?

As mentioned in the title, FiniteDiff, ForwardDiff and Enzyme return different results for the following code. Any ideas what is going wrong and which partial derivatives are correct?

using Pkg
Pkg.activate(temp = true)
Pkg.add(["LinearAlgebra", "ComponentArrays", "Enzyme", "ForwardDiff", "FiniteDiff"])

struct QLearner{T}
    q::T
end
QLearner() = QLearner(zeros(10, 3))

function logp(data, m::QLearner, parameters)
    q = m.q .= parameters.q₀            # reset the learner's Q-table to the initial values
    logp = 0.
    for (; s, a, s′, r, done) in data
        logp += logsoftmax(1.6, q, s, a)                          # log-likelihood of the observed action
        td_error = r + .99 * maximum(view(q, s′, :)) - q[s, a]    # temporal-difference error
        q[s, a] += .1 * td_error                                  # Q-learning update
    end
    logp
end

logsoftmax(β, q, s, a) = β * q[s, a] - logsumexp(β, view(q, s, :))  # log probability of action a in state s under a softmax policy
function logsumexp(β, v)
    m = β * maximum(v)   # shift by the maximum of β .* v for numerical stability (β ≥ 0 here)
    sumexp = zero(eltype(v))
    for vᵢ in v
        sumexp += exp(β * vᵢ - m)
    end
    m + log(sumexp)
end

using ComponentArrays

model = QLearner()
p = ComponentArray((; q₀ = zeros(10, 3)))

data = [(s = 1, a = 1, s′ = 1, r = 0.0, done = 0), (s = 1, a = 2, s′ = 8, r = 1.3292143757496304, done = 0), (s = 8, a = 1, s′ = 8, r = 0.4959145376981169, done = 0), (s = 8, a = 3, s′ = 9, r = 0.5496578578701407, done = 0), (s = 9, a = 3, s′ = 8, r = -0.8020690510161607, done = 0), (s = 8, a = 3, s′ = 2, r = 0.5496578578701407, done = 0), (s = 2, a = 2, s′ = 7, r = -1.160974404564671, done = 0), (s = 7, a = 1, s′ = 5, r = 0.44575583592069273, done = 0), (s = 5, a = 1, s′ = 3, r = -0.11691331382831936, done = 0), (s = 3, a = 3, s′ = 3, r = -1.6057824171796733, done = 0), (s = 3, a = 2, s′ = 5, r = -1.41550921655304, done = 0)]

using FiniteDiff, Enzyme, ForwardDiff, LinearAlgebra
julia> dp_fd = FiniteDiff.finite_difference_gradient(x -> logp(data, model, x), p)
ComponentVector{Float64}(q₀ = [0.5602665047905974 0.5069334100771912 -1.0930665899018253; -0.5333333333752303 1.0666666667504605 -0.5333333333752303; … ; -0.4813755911262806 -0.4813755911262806 1.118624408852736; 0.0 0.0 0.0])

julia> dp_fw = ForwardDiff.gradient(x -> logp(data, QLearner(zeros(eltype(x), 10, 3)), x), p)
ComponentVector{Float64}(q₀ = [0.5866666666666669 0.5333333333333335 -1.119466666666667; -0.5333333333333333 1.0666666666666669 -0.5333333333333334; … ; -0.5333333333333333 -0.5333333333333333 1.170582168402159; 0.0 0.0 0.0])

julia> dp = zero(p);

julia> autodiff(Reverse, logp, Const(data), Duplicated(model, Enzyme.make_zero(model)), Duplicated(p, dp))
((nothing, nothing, nothing),)

julia> dp
ComponentVector{Float64}(q₀ = [0.5338666666666663 0.5333333333333335 -1.0666666666666667; -0.5333333333333335 1.0666666666666669 -0.5333333333333333; … ; -0.4294178315978412 -0.5333333333333333 1.0666666666666669; 0.0 0.0 0.0])

julia> norm(dp_fd - dp_fw)/norm(dp_fw) # FiniteDiff - ForwardDiff
0.03695183098661434

julia> norm(dp - dp_fw)/norm(dp_fw)    # Enzyme - ForwardDiff
0.06034204206694352

julia> norm(dp_fd - dp)/norm(dp)    # FiniteDiff - Enzyme
0.036998623660816486

It looks like you are differentiating a stochastic program, so you have to think carefully about what you are actually computing (e.g. a finite-difference estimate of a stochastic program can have diverging variance as the step size goes to zero). So it’s not so surprising to me that different methods give different results here; in fact, there are many different choices for how to define a “derivative” of a stochastic program.
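
To make the finite-difference point concrete, here is a tiny made-up example (noisy is just an illustrative function, not something from your code): the difference quotient divides fresh noise by the step size h, so the estimate gets worse as h shrinks, even though the true slope is exactly 1.

using Random

noisy(x) = x + 0.01 * randn()   # hypothetical objective that calls randn internally

for h in (1e-1, 1e-3, 1e-6)
    est = (noisy(1.0 + h) - noisy(1.0)) / h   # noise of size ~0.01 gets divided by h
    println("h = ", h, "   slope estimate = ", est)
end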

There are tools in Julia that can help with stochastic derivatives (see e.g. StochasticAD.jl), but you should understand what you are doing first.

See e.g. this lecture from our Matrix Calculus class, the accompanying lecture notes, and the further-reading links.

No, there is no stochasticity in the function I want to differentiate; rand is only used to generate some artificial data. Given this data, I want to compute the derivative of the log-likelihood function logp(data, model, parameters) with respect to parameters.

It would be easier to help you if you stripped out all of the extraneous code and just showed the function you are differentiating: e.g., generate 50 data points and simply paste data = [...] into your post.

I would also show the relative norm of the difference, e.g. instead of showing abbreviated data for dp - dp_fw, show the L2 relative error norm(dp - dp_fw) / norm(dp).

Thanks for helping. I changed the original post.

I don’t see an obvious reason why one would be right and the others not; very puzzling indeed.
I just want to suggest using DifferentiationInterface.jl to quickly check with more backends. It’s a typical use case where you don’t want to spend time learning the syntax specific to each autodiff library.
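
Something along these lines should work (a sketch from memory, assuming logp, QLearner, data and p from the first post are in scope; the AutoForwardDiff / AutoFiniteDiff / AutoEnzyme backend types come from ADTypes.jl, which DifferentiationInterface re-exports as far as I remember):

using DifferentiationInterface
using LinearAlgebra: norm
import ForwardDiff, FiniteDiff, Enzyme

f = x -> logp(data, QLearner(zeros(eltype(x), 10, 3)), x)

backends = [AutoForwardDiff(), AutoFiniteDiff(), AutoEnzyme()]   # AutoEnzyme() may need an explicit mode depending on versions
grads = [gradient(f, b, p) for b in backends]

# pairwise relative differences between the three gradients
for i in 1:length(backends), j in (i + 1):length(backends)
    println(nameof(typeof(backends[i])), " vs ", nameof(typeof(backends[j])), ": ",
            norm(grads[i] - grads[j]) / norm(grads[j]))
end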

Could it be because the logp function modifies the model?

Why should this be problematic?

maximum in logp is the culprit. Interestingly, ForwardDiff and Enzyme give the same result when I use findmax instead of maximum, whereas ForwardDiff gives a different result for maximum than for findmax. Should this be considered a bug in ForwardDiff?
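
For concreteness, the variant I tried looks like this (identical to logp above, except that maximum(view(q, s′, :)) is replaced by findmax; logp_findmax is just a name I use here for the comparison):

function logp_findmax(data, m::QLearner, parameters)
    q = m.q .= parameters.q₀
    logp = 0.
    for (; s, a, s′, r, done) in data
        logp += logsoftmax(1.6, q, s, a)
        vmax, _ = findmax(view(q, s′, :))     # findmax returns (value, index)
        td_error = r + .99 * vmax - q[s, a]
        q[s, a] += .1 * td_error
    end
    logp
end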

Is more than one value in the array equal to the maximum? If so, the derivative isn’t well defined there.
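
That does seem to be the situation in the posted code: q₀ is all zeros, so the early maximum(view(q, s′, :)) calls see rows of identical values. A minimal illustration, using a toy function of my own rather than your model:

import ForwardDiff, FiniteDiff

g(x) = maximum(x)    # not differentiable at points where several entries tie
x0 = zeros(3)        # three-way tie, like the untouched rows of q

ForwardDiff.gradient(g, x0)                     # typically puts all the weight on one of the tied entries
FiniteDiff.finite_difference_gradient(g, x0)    # finite differences give yet another answer at the kink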
