I am trying to write a function that handles arrays of ForwardDiff.Dual, since in theory I should be able to compute the derivative of a certain function much more efficiently than by just propagating Dual numbers through it. I have been able to write most of this function quite efficiently: I unpack the values and partials from the array of Dual and handle each part separately. However, filling the array back up at the end is about 1000x slower than all the other parts of the function, and slower than the function I am differentiating!
julia> using ForwardDiff: Dual, partials, value
julia> using BenchmarkTools
julia> @btime du = [Dual{nothing}(0., 1., 2., 3.) for x ∈ 1:64, y ∈ 1:64];
5.568 μs (2 allocations: 128.05 KiB)
julia> # extract values, partials
julia> @btime vs = value.(du);
2.953 μs (5 allocations: 32.11 KiB)
julia> function extractPartials(du)
           nps = length(partials(du[1,1]))
           ps = zeros(size(du)..., nps)
           @inbounds for I ∈ CartesianIndices(du), j ∈ 1:nps
               ps[I, j] = partials(du[I], j)
           end
           return ps
       end
extractPartials (generic function with 1 method)
julia> @btime ps = extractPartials(du);
9.420 μs (2 allocations: 96.06 KiB)
julia> function fillArrDual!(du, vs, ps)
           @inbounds for I ∈ CartesianIndices(du)
               du[I] = Dual{nothing}(vs[I], ps[I, :]...)
           end
       end
fillArrDual! (generic function with 1 method)
julia> @btime fillArrDual!(du, vs, ps);
2.726 ms (57344 allocations: 1.44 MiB)
I’m not sure why it is allocating so much memory! Since Duals are immutable, it can’t just overwrite the entries of du in place, so I get that there will be some allocation for each Dual, but then why is du = [Dual{nothing}(0., 1., 2., 3.) for x ∈ 1:64, y ∈ 1:64] so much faster with far fewer allocations? Getting rid of the splat helps quite a bit:
julia> function fillArrDual!(du, vs, ps)
           @inbounds for I ∈ CartesianIndices(du)
               du[I] = Dual{nothing}(vs[I], ntuple(j -> ps[I, j], size(ps)[end]))
           end
       end
fillArrDual! (generic function with 1 method)
julia> @btime fillArrDual!(du, vs, ps);
530.914 μs (12288 allocations: 384.00 KiB)
but this is still way too slow for my application. I have tried looking through the ForwardDiff source, but I am a little out of my depth identifying exactly what is going wrong here.
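One direction that might be worth trying, as a sketch only and not benchmarked here. In both versions above the number of partials is a runtime value: `ps[I, :]...` splats a freshly allocated vector, and `size(ps)[end]` is an ordinary `Int`, so the compiler cannot know the length of the tuple or the concrete type of the resulting Dual. The sketch below instead reads the partial count `N` from the element type of `du` and passes it to `ntuple` as `Val(N)`. The name `fill_duals!` is hypothetical, and the code assumes `du` has a concrete `Dual{T,V,N}` element type and that `ps` stores the partials along its last dimension, as in `extractPartials`.

```julia
using ForwardDiff: Dual, partials, value

# Hypothetical helper (not part of ForwardDiff). The tag T, value type V and
# partial count N are taken from the element type of `du`, so they are
# compile-time constants inside the loop.
function fill_duals!(du::AbstractArray{Dual{T,V,N}}, vs, ps) where {T,V,N}
    @inbounds for I in CartesianIndices(du)
        # Val(N) makes the tuple length part of the type, so no temporary
        # array is needed to hold the partials for this entry.
        p = ntuple(j -> ps[I, j], Val(N))
        du[I] = Dual{T}(vs[I], p)
    end
    return du
end
```

When timing this, interpolating the arguments, as in `@btime fill_duals!($du, $vs, $ps)`, keeps the cost of looking up non-constant globals out of the measurement.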