I am trying out automatic differentiation on part of my code in place of the analytic derivative, but `ForwardDiff.jacobian` of softmax seems to take twice as long as the analytic Jacobian. I'm not sure whether this is an expected result or whether I'm doing something that slows ForwardDiff down. Any feedback is appreciated.
Below is an MWE with data representative of my problem. The loop within the functions is also representative (i.e. it computes a list of Jacobians for subsets of a larger array). I tried passing an anonymous function to `jacobian`, as well as a function that doesn't compute softmax in-place (both commented out). I know that pushing to an `Array{Any}` is not great, but that's just for the MWE.
using ForwardDiff
using BenchmarkTools
# softmax copied from StatsFuns without convert
"""
    softmax2(x::AbstractArray{<:Real})

Return a new array with the same shape as `x` holding the softmax of `x`,
i.e. `exp.(x .- maximum(x)) ./ sum(exp.(x .- maximum(x)))`.

The maximum of `x` is subtracted before exponentiating so that the largest
exponent is zero. The result's element type is the type produced by `exp` on
an element of `x`, so integer input yields a floating-point result and the
accumulator never changes type inside the loop.
"""
function softmax2(x::AbstractArray{T}) where {T<:Real}
    # Element type of the result: whatever `exp` produces for a value of type T.
    S = typeof(exp(zero(T)))
    r = similar(x, S)
    u = maximum(x)
    # Accumulate in S (not a hard-coded Float64) so `s` keeps a single type.
    s = zero(S)
    @inbounds for i in eachindex(x, r)
        s += (r[i] = exp(x[i] - u))
    end
    invs = inv(s)
    @inbounds for i in eachindex(r)
        r[i] *= invs
    end
    return r
end
"""
    softmax2!(r::AbstractArray{T}, x::AbstractArray{T}) where {T<:Real}

Overwrite `r` with the softmax of `x` and return `r`.

The maximum of `x` is subtracted before exponentiating so that the largest
exponent is zero. `r` and `x` must have the same number of elements;
otherwise a `DimensionMismatch` is thrown.
"""
function softmax2!(r::AbstractArray{T}, x::AbstractArray{T}) where {T<:Real}
    n = length(x)
    length(r) == n || throw(DimensionMismatch("Inconsistent array lengths."))
    u = maximum(x)
    # Accumulate in the type `exp` produces for T rather than a hard-coded
    # Float64, so `s` does not change type when T is not Float64.
    s = zero(typeof(exp(zero(T))))
    # Pair up native indices of both arrays; lengths were checked above.
    for (i, j) in zip(eachindex(r), eachindex(x))
        @inbounds s += (r[i] = exp(x[j] - u))
    end
    invs = inv(s)
    @inbounds for i in eachindex(r)
        r[i] *= invs
    end
    return r
end
"""
    jac1(x, beta; blocksize=10)

Compute, with `ForwardDiff.jacobian`, the Jacobian of
`b -> softmax(x_block * b)` at `beta` for each consecutive block of
`blocksize` rows of `x`, and return the Jacobians as a `Vector{Matrix{Float64}}`.

Trailing rows that do not fill a complete block are ignored.

# Keywords
- `blocksize`: number of rows of `x` per block (default `10`).

# Throws
- `ArgumentError` if `blocksize` is not positive.
"""
function jac1(x, beta; blocksize::Integer=10)
    blocksize > 0 || throw(ArgumentError("blocksize must be positive, got $blocksize"))
    # Concretely typed container instead of Array{Any}.
    result = Matrix{Float64}[]
    # Output buffer for the in-place target; it is only scratch space for the
    # function value, so one allocation can be reused for every block.
    p = zeros(Float64, blocksize)
    for i in 1:blocksize:(size(x, 1) - blocksize + 1)
        # View instead of a copy of the row block.
        thisx = @view x[i:(i + blocksize - 1), :]
        pp! = (y, b) -> softmax2!(y, thisx * b)
        lambda = ForwardDiff.jacobian(pp!, p, beta)::Matrix{Float64}
        push!(result, lambda)
    end
    return result
end
"""
    jac2(x, beta; blocksize=10)

Compute the analytic Jacobian of `b -> softmax(x_block * b)` at `beta` for each
consecutive block of `blocksize` rows of `x`, and return the Jacobians as a
vector of matrices.

For a block with probabilities `p = softmax2(x_block * beta)`, entry `(t, k)` is
`p[t] * (x_block[t, k] - sum_s x_block[s, k] * p[s])`.

Trailing rows that do not fill a complete block are ignored.

# Keywords
- `blocksize`: number of rows of `x` per block (default `10`).

# Throws
- `ArgumentError` if `blocksize` is not positive.
"""
function jac2(x, beta; blocksize::Integer=10)
    blocksize > 0 || throw(ArgumentError("blocksize must be positive, got $blocksize"))
    ncols = size(x, 2)
    # Concretely typed container instead of Array{Any}; the element type is the
    # array type `similar` produces for `x`, which is what each block result is.
    result = typeof(similar(x, 0, 0))[]
    for i in 1:blocksize:(size(x, 1) - blocksize + 1)
        # View instead of a copy of the row block.
        thisx = @view x[i:(i + blocksize - 1), :]
        p = softmax2(thisx * beta)
        lambda = similar(x, blocksize, ncols)
        for k in 1:ncols
            # Probability-weighted mean of column k within this block.
            last_term = zero(eltype(p))
            for s in 1:blocksize
                last_term += thisx[s, k] * p[s]
            end
            for t in 1:blocksize
                lambda[t, k] = p[t] * (thisx[t, k] - last_term)
            end
        end
        push!(result, lambda)
    end
    return result
end
# Example data: a 20×2 design matrix and a 2-element coefficient vector.
X = [0.398106 0.496961; -0.612026 -0.224875; 0.34112 -1.11714;
-1.12936 -0.394995; 1.43302 1.54983; 1.9804 -0.743514; -0.367221 -2.33171;
-1.04413 0.812245; 0.56972 -0.501311; -0.135055 -0.510887;
2.40162 -1.21536; -0.03924 -0.0225586; 0.689739 0.701239;
0.0280022 -0.587482; -0.743273 -0.606728; 0.188792 1.09664;
-1.80496 -0.24751; 1.46555 -0.159902; 0.153253 -0.625778;
2.17261 0.900435]
β = [0.924499; 0.869371]
# Tile the data: X becomes 100×4 and β gets 4 entries.
X = repeat(X, outer = (5, 2))
β = repeat(β, 2)
# Check that the automatic and analytic Jacobians agree.
display(isapprox(jac1(X, β), jac2(X, β)))
# Interpolate the globals with `$` so the benchmark measures the functions
# themselves rather than access to untyped global variables.
@btime jac1($X, $β)
@btime jac2($X, $β)
Output:
true
8.137 μs (94 allocations: 25.50 KiB)   # jac1 (ForwardDiff)
4.086 μs (44 allocations: 12.22 KiB)   # jac2 (analytic)