# Improve speed of vector-Jacobian product

Hi,

I would like to improve part of a code that relies on vector-Jacobian products (vjp), i.e. computing the action of the adjoint of an operator. On my machine, computing the vjp is ~100x slower than evaluating the linear operator itself.

I would be grateful if some of you could come up with better ideas. Thanks a lot!

Here is a very simplified example of my problem:

```julia
using Pkg, LinearAlgebra
cd(@__DIR__)
pkg"activate 2d"
using Revise, Parameters, ForwardDiff, ReverseDiff, BenchmarkTools

function init(vmin, vmax, Nv)
    v = LinRange(vmin, vmax, Nv) |> collect
    Dv = v[2] - v[1]  # mesh spacing
    fv = @. exp(v)
    V = @. v^2
    return (; vmin, vmax, Nv, v, Dv, fv, V)
end

function inner(i, divergence, Nv, invDv, g, V)
    # Values
    gi  = g[i]
    gim = i > 1       ? g[i-1] : zero(gi)
    gip = (i+1) <= Nv ? g[i+1] : zero(gi)
    # Upwind fluxes + null flux boundary conditions
    Fi  = i > 1  ? (V[i] > 0   ? V[i] * gim  : V[i] * gi)    : zero(gi)
    Fip = i < Nv ? (V[i+1] > 0 ? V[i+1] * gi : V[i+1] * gip) : zero(gi)
    # Divergence
    divergence[i] = invDv * (Fip - Fi)
end

function rhs1d!(result, g, p)
    @unpack vmin, vmax, Dv, fv, Nv, V = p
    invDv = 1 / Dv

    Vtmp = V .+ (dot(g, fv) * Dv)
    for i = 1:Nv
        inner(i, result, Nv, invDv, g, Vtmp)
    end
    result
end
rhs1d(g, p) = rhs1d!(similar(g), g, p)

# vjp(v, p, g) = (∂rhs1d/∂g)' * v via ReverseDiff
# (definition assumed; it was not shown in the original post)
vjp(v, p, g) = ReverseDiff.gradient(x -> dot(v, rhs1d(x, p)), g)

mesh = init(-1, 3, 1000)

g0 = @. exp(-mesh.v^2 / (2 * 0.1^2))
g0 ./= sum(g0) * mesh.Dv
# plot(mesh.v, g0)

@time rhs1d(g0, mesh);

@btime vjp($g0, $mesh, $g0); # 530.292 μs (33040 allocations: 1.46 MiB)
@btime rhs1d($g0, $mesh);    # 2.338 μs (2 allocations: 15.88 KiB)
```

Reverse-mode differentiation may not understand that each iteration of this loop updates a different element of `result` independently. My guess is that it makes a copy of the `result` array on every iteration of the loop, so that it can run the loop backwards when backpropagating the vjp.
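If that per-iteration copying is indeed the cost, one workaround is to restructure the right-hand side so it never mutates an array. Here is a hedged sketch (`rhs1d_nomut` is a hypothetical name, not from the thread; it assumes Julia 1.7+ named-tuple destructuring) that rebuilds the same upwind divergence with a comprehension:

```julia
using LinearAlgebra

# Hypothetical non-mutating variant of rhs1d (names mirror the original post):
# the comprehension builds a fresh array instead of writing into `result`,
# avoiding the array copies reverse-mode AD may make for mutating loops.
function rhs1d_nomut(g, p)
    (; Dv, fv, Nv, V) = p
    invDv = 1 / Dv
    Vtmp = V .+ dot(g, fv) * Dv
    # Upwind flux with null-flux boundary conditions, as in `inner`
    flux(j) = (j == 1 || j == Nv + 1) ? zero(eltype(g)) :
              (Vtmp[j] > 0 ? Vtmp[j] * g[j-1] : Vtmp[j] * g[j])
    return [invDv * (flux(i + 1) - flux(i)) for i in 1:Nv]
end
```

Because the fluxes telescope, the sum of the result is zero up to round-off, which gives a quick sanity check against the original `rhs1d`.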

Moreover, you are effectively doing a sparse matrix–vector multiplication in which the elements of the matrix depend on your parameters, and AD tools often struggle with backpropagating through a sparse-matrix construction for reasons I explained in this thread: Zygote.jl: How to get the gradient of sparse matrix - #6 by stevengj

I’ve typically found that you need to write a manual vJp for some step(s) in most reasonable-scale scientific problems, unless you are using a package like DiffEqFlux.jl that has done that for you. AD tools are more reliable for cookie-cutter ML-style problems where you are plugging together large components that it already knows about, and are only fiddling around the edges with small scalar functions or code written in a functional/non-mutating style (e.g. composing vectorized primitives).
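When you do write a manual vJp, a dense-Jacobian reference is handy for testing it at this problem size (Nv ≈ 1000). A minimal sketch, where `f_demo` is a hypothetical toy operator standing in for `rhs1d`:

```julia
using ForwardDiff, LinearAlgebra

# Reference vjp via the full Jacobian: vjp(f, x, v) = (∂f/∂x)' * v.
# Too slow for production, but cheap enough at moderate sizes to
# validate a hand-written adjoint against.
vjp_dense(f, x, v) = transpose(ForwardDiff.jacobian(f, x)) * v

# Toy operator (hypothetical, not from the thread):
f_demo(x) = [x[1]^2 + x[2], 3 * x[2]]
x = [1.0, 2.0]
v = [1.0, 1.0]
vjp_dense(f_demo, x, v)  # J = [2 1; 0 3], so J' * v == [2.0, 4.0]
```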


Thanks for this; it is worrying indeed. In my case I use `ReverseDiff`, but I imagine your remark still applies.

> I’ve typically found that you need to write a manual vJp for some step(s) in most reasonable-scale scientific problems,

Ah, that is a nice suggestion. I will try to decompose the problem and use AD only where necessary.
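For this particular `rhs1d`, the manual vJp can actually be derived in a few lines. Writing `F[j] = Vtmp[j] * g[k]` with `k` the upwind index, `⟨v, rhs⟩ = Σ_{j=2}^{Nv} F[j] * (v[j-1] - v[j]) / Dv`, and differentiating with respect to `g` (treating the upwind branch as locally constant, which is valid almost everywhere) gives a single O(Nv) loop. This is a sketch under those assumptions, not a drop-in tested replacement:

```julia
using LinearAlgebra

# Hypothetical hand-written vJp for rhs1d: w = (∂rhs1d/∂g)' * v.
# The first term in the loop is the adjoint of the frozen-coefficient
# operator; `s` collects the sensitivity flowing through Vtmp = V + dot(g,fv)*Dv.
function vjp_manual(v, g, p)
    (; Dv, fv, Nv, V) = p
    invDv = 1 / Dv
    Vtmp = V .+ dot(g, fv) * Dv
    w = zero(g)
    s = zero(eltype(g))
    for j in 2:Nv
        k  = Vtmp[j] > 0 ? j - 1 : j      # upwind stencil index
        dv = invDv * (v[j-1] - v[j])      # adjoint weight of flux F[j]
        w[k] += Vtmp[j] * dv              # linear-in-g part
        s    += g[k] * dv                 # part flowing through Vtmp
    end
    w .+= (Dv * s) .* fv                  # chain rule through dot(g, fv)
    return w
end
```

A finite-difference check against the forward model is a good way to confirm the derivation before trusting it in the full problem.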