I would like to speed up part of a code that relies on vector-Jacobian products (vjp), i.e. computing the action of the adjoint (transpose) of an operator's Jacobian. On my computer, computing a vjp is more than 200x slower than evaluating the operator itself (about 530 μs versus 2.3 μs in the benchmark below).
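As a reminder of what a vjp computes, here is a minimal sketch on a hypothetical three-dimensional toy map `f` (not the operator from this post). The vjp at `x` applied to `dx` equals `J(x)' * dx`. Reverse mode obtains the same vector as the gradient of the scalar function `z -> dot(dx, f(z))`. The dense-Jacobian route is included only as a reference check, because it does not scale to large problems.

```julia
using ForwardDiff, ReverseDiff, LinearAlgebra

# Hypothetical toy nonlinear map standing in for rhs1d (not the original operator)
f(z) = [z[1]^2 + z[2], sin(z[2]) * z[3], z[1] * z[3]]

x  = [1.0, 2.0, 3.0]   # evaluation point
dx = [0.5, -1.0, 2.0]  # vector the adjoint Jacobian is applied to

# 1) Dense reference: build the full Jacobian with forward mode, then apply its transpose
J = ForwardDiff.jacobian(f, x)
vjp_dense = J' * dx

# 2) Reverse mode: the gradient of z -> <dx, f(z)> is J(x)' * dx
vjp_reverse = ReverseDiff.gradient(z -> dot(dx, f(z)), x)

println(vjp_dense ≈ vjp_reverse)  # true
```

Only the reverse-mode route is practical for large operators, since it never forms `J` explicitly.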
I would be grateful for any ideas on how to make the vjp faster.
Thanks a lot.
Here is a very simplified example of my problem:
using Pkg
pkg"activate 2d"   # project environment used for this example
using Revise, LinearAlgebra, Parameters, ForwardDiff, BenchmarkTools
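function init(vmin, vmax, Nv)
    v = LinRange(vmin, vmax, Nv) |> collect   # uniform grid in v
    Dv = v[2] - v[1]                          # grid spacing
    fv = @. exp(v)                            # weight used in the nonlocal term
    V = @. v^2                                # velocity field on the grid
    return (; vmin, vmax, Nv, v, Dv, fv, V)
end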
function inner(i, divergence, Nv, invDv, g, V)
    # Values
    gi = g[i]
    gim = i > 1 ? g[i-1] : zero(gi)
    gip = (i + 1) <= Nv ? g[i+1] : zero(gi)
    # Upwind fluxes + null flux boundary conditions
    Fi = i > 1 ? (V[i] > 0 ? V[i] * gim : V[i] * gi) : zero(gi)
    Fip = i < Nv ? (V[i+1] > 0 ? V[i+1] * gi : V[i+1] * gip) : zero(gi)
    # Divergence
    divergence[i] = invDv * (Fip - Fi)
end
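function rhs1d!(result, g, p)
    @unpack vmin, vmax, Dv, fv, Nv, V = p
    invDv = 1 / Dv
    Vtmp = V .+ (dot(g, fv) * Dv)   # velocity shifted by a nonlocal term that depends on g
    for i = 1:Nv
        inner(i, result, Nv, invDv, g, Vtmp)
    end
    return result                   # without this, rhs1d! would return `nothing`
end
rhs1d(g, p) = rhs1d!(similar(g), g, p)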
mesh = init(-1, 3, 1000)
g0 = @. exp(-mesh.v^2/(2*0.1^2))
g0 ./= sum(g0)*mesh.Dv
# plot(mesh.v,g0)
@time rhs1d(g0, mesh);
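import AbstractDifferentiation as AD
using ReverseDiff
# vjp(x, p, dx) = J(x)' * dx, with J the Jacobian of z -> rhs1d(z, p), computed by a ReverseDiff pullback
vjp(x, p, dx) = AD.pullback_function(AD.ReverseDiffBackend(), z -> rhs1d(z, p), x)(dx)[1]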
@btime vjp($g0,$mesh,$g0); # 530.292 μs (33040 allocations: 1.46 MiB)
@btime rhs1d($g0, $mesh); # 2.338 μs (2 allocations: 15.88 KiB)
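In equations, reading it off `inner` and `rhs1d!`, the code above computes the following, with $f_j$ = `fv[j]`, $V_i$ = `V[i]` and $\Delta v$ = `Dv`. Because $\tilde V$ depends on $g$ through the inner product, the map is nonlinear in $g$, so its Jacobian (and hence the vjp) depends on the point $g$:

$$
\tilde V_i = V_i + \Delta v \sum_{j=1}^{N_v} f_j\, g_j, \qquad i = 1,\dots,N_v,
$$

$$
F_i = \begin{cases} \tilde V_i\, g_{i-1}, & \tilde V_i > 0,\\ \tilde V_i\, g_i, & \tilde V_i \le 0, \end{cases} \quad 2 \le i \le N_v, \qquad F_1 = F_{N_v+1} = 0,
$$

$$
\mathrm{rhs}(g)_i = \frac{F_{i+1} - F_i}{\Delta v}, \qquad
\mathrm{vjp}(g, \delta g) = J(g)^{\top} \delta g, \quad J(g) = \frac{\partial\, \mathrm{rhs}(g)}{\partial g}.
$$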