As mentioned in the title, FiniteDiff, ForwardDiff and Enzyme return different results for the following code. Any ideas what is going wrong and which partial derivatives are correct?
using Pkg
Pkg.activate(temp = true)
Pkg.add(["LinearAlgebra", "ComponentArrays", "Enzyme", "ForwardDiff", "FiniteDiff"])
struct QLearner{T}
q::T
end
QLearner() = QLearner(zeros(10, 3))
function logp(data, m::QLearner, parameters)
q = m.q .= parameters.q₀
logp = 0.
for (; s, a, s′, r, done) in data
logp += logsoftmax(1.6, q, s, a)
td_error = r + .99 * maximum(view(q, s′, :)) - q[s, a]
q[s, a] += .1 * td_error
end
logp
end
logsoftmax(β, q, s, a) = β * q[s, a] - logsumexp(β, view(q, s, :))
function logsumexp(β, v)
m = β * maximum(v)
sumexp = zero(eltype(v))
for vᵢ in v
sumexp += exp(β * vᵢ - m)
end
m + log(sumexp)
end
using ComponentArrays
model = QLearner()
p = ComponentArray((; q₀ = zeros(10, 3)))
data = [(s = 1, a = 1, s′ = 1, r = 0.0, done = 0), (s = 1, a = 2, s′ = 8, r = 1.3292143757496304, done = 0), (s = 8, a = 1, s′ = 8, r = 0.4959145376981169, done = 0), (s = 8, a = 3, s′ = 9, r = 0.5496578578701407, done = 0), (s = 9, a = 3, s′ = 8, r = -0.8020690510161607, done = 0), (s = 8, a = 3, s′ = 2, r = 0.5496578578701407, done = 0), (s = 2, a = 2, s′ = 7, r = -1.160974404564671, done = 0), (s = 7, a = 1, s′ = 5, r = 0.44575583592069273, done = 0), (s = 5, a = 1, s′ = 3, r = -0.11691331382831936, done = 0), (s = 3, a = 3, s′ = 3, r = -1.6057824171796733, done = 0), (s = 3, a = 2, s′ = 5, r = -1.41550921655304, done = 0)]
using FiniteDiff, Enzyme, ForwardDiff, LinearAlgebra
julia> dp_fd = FiniteDiff.finite_difference_gradient(x -> logp(data, model, x), p)
ComponentVector{Float64}(q₀ = [0.5602665047905974 0.5069334100771912 -1.0930665899018253; -0.5333333333752303 1.0666666667504605 -0.5333333333752303; … ; -0.4813755911262806 -0.4813755911262806 1.118624408852736; 0.0 0.0 0.0])
julia> dp_fw = ForwardDiff.gradient(x -> logp(data, QLearner(zeros(eltype(x), 10, 3)), x), p)
ComponentVector{Float64}(q₀ = [0.5866666666666669 0.5333333333333335 -1.119466666666667; -0.5333333333333333 1.0666666666666669 -0.5333333333333334; … ; -0.5333333333333333 -0.5333333333333333 1.170582168402159; 0.0 0.0 0.0])
julia> dp = zero(p);
julia> autodiff(Reverse, logp, Const(data), Duplicated(model, Enzyme.make_zero(model)), Duplicated(p, dp))
((nothing, nothing, nothing),)
julia> dp
ComponentVector{Float64}(q₀ = [0.5338666666666663 0.5333333333333335 -1.0666666666666667; -0.5333333333333335 1.0666666666666669 -0.5333333333333333; … ; -0.4294178315978412 -0.5333333333333333 1.0666666666666669; 0.0 0.0 0.0])
julia> norm(dp_fd - dp_fw)/norm(dp_fw) # FiniteDiff - ForwardDiff
0.03695183098661434
julia> norm(dp - dp_fw)/norm(dp_fw) # Enzyme - ForwardDiff
0.06034204206694352
julia> norm(dp_fd - dp)/norm(dp) # FiniteDiff - Enzyme
0.036998623660816486