Improve speed of vector-Jacobian product

Hi,

I would like to improve part of a code that relies on vector-Jacobian products (vjp), i.e. computing the action of the adjoint of an operator: for a map F with Jacobian J(x), the vjp is dx ↦ J(x)ᵀ dx. On my computer, the computation of the vjp is ~100x slower than the evaluation of the linear operator itself.

I would be grateful if any of you could come up with better ideas :smiley:

Thanks a lot.

Here is a very simplified example of my problem:

using Pkg, LinearAlgebra
cd(@__DIR__)
pkg"activate 2d"
using Revise, Parameters, BenchmarkTools

function init(vmin, vmax, Nv)
    v = LinRange(vmin, vmax, Nv) |> collect
    Dv = v[2]-v[1]
    fv = @. exp(v)
    V = @. v^2
    return (;vmin, vmax, Nv, v, Dv, fv, V)
end

function inner!(i, divergence, Nv, invDv, g, V)
    # Values
    gi  = g[i]
    gim = i > 1         ? g[i-1] : zero(gi)
    gip = (i+1) <= Nv   ? g[i+1] : zero(gi)
    # Upwind fluxes + null flux boundary conditions
    Fi  = i > 1         ? (V[i] > 0   ? V[i] * gim : V[i] * gi) : zero(gi)
    Fip = i < Nv        ? (V[i+1] > 0 ? V[i+1]* gi : V[i+1] * gip) : zero(gi)
    # Divergence
    divergence[i] = invDv * (Fip - Fi)
end

function rhs1d!(result, g, p)
    @unpack vmin, vmax, Dv, fv, Nv, V = p
    invDv = 1 / Dv

    Vtmp = V .+ (dot(g, fv) * Dv)
    for i=1:Nv
        inner!(i, result, Nv, invDv, g, Vtmp)
    end
    result
end
rhs1d(g, p) = rhs1d!(similar(g), g, p)

mesh = init(-1, 3, 1000)

g0 = @. exp(-mesh.v^2/(2*0.1^2))
g0 ./= sum(g0)*mesh.Dv
# plot(mesh.v,g0)

@time rhs1d(g0, mesh);

import AbstractDifferentiation as AD
using ReverseDiff

vjp(x,p,dx) = AD.pullback_function(AD.ReverseDiffBackend(), z->rhs1d(z, p), x)(dx)[1]
@btime vjp($g0,$mesh,$g0); # 530.292 μs (33040 allocations: 1.46 MiB)
@btime rhs1d($g0, $mesh); # 2.338 μs (2 allocations: 15.88 KiB)

Reverse-mode differentiation may not understand that each iteration of this loop updates a different element of result independently; I'm guessing that it makes a copy of the result array for every iteration of the loop before running the loop backwards in order to backpropagate the vJp.
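
If the mutation is indeed the culprit, one thing worth trying is a non-mutating variant that builds the result with a comprehension instead of writing into a preallocated array. The names inner_val and rhs1d_nomut below are ones I'm making up, and whether this actually shortens ReverseDiff's tape is something you'd have to benchmark:

function inner_val(i, Nv, invDv, g, V)
    gi  = g[i]
    gim = i > 1       ? g[i-1] : zero(gi)
    gip = (i+1) <= Nv ? g[i+1] : zero(gi)
    Fi  = i > 1  ? (V[i]   > 0 ? V[i]   * gim : V[i]   * gi)  : zero(gi)
    Fip = i < Nv ? (V[i+1] > 0 ? V[i+1] * gi  : V[i+1] * gip) : zero(gi)
    return invDv * (Fip - Fi)   # return the value instead of mutating an array
end

function rhs1d_nomut(g, p)
    @unpack Dv, fv, Nv, V = p
    invDv = 1 / Dv
    Vtmp = V .+ (dot(g, fv) * Dv)
    return [inner_val(i, Nv, invDv, g, Vtmp) for i in 1:Nv]   # no in-place writes
end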

Moreover, you are effectively doing a sparse matrix–vector multiplication in which the elements of the matrix depend on your parameters, and AD tools often struggle with backpropagating through a sparse-matrix construction for reasons I explained in this thread: Zygote.jl: How to get the gradient of sparse matrix - #6 by stevengj
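
To make the analogy concrete: for a fixed Vtmp, rhs1d is exactly a sparse matrix times g, and the pullback with respect to g at fixed Vtmp is just the transpose applied to dx. Here is a rough sketch of that structure (upwind_matrix is a name I'm inventing, and it deliberately ignores the dot(g, fv) coupling, which also depends on g):

using SparseArrays

# Sketch: build A such that, for *fixed* Vtmp, rhs1d(g, p) == A * g.
function upwind_matrix(Vtmp, Nv, invDv)
    Is, Js, Xs = Int[], Int[], Float64[]
    for i in 1:Nv
        if i < Nv                              # flux through the i+1/2 face
            j = Vtmp[i+1] > 0 ? i : i + 1
            push!(Is, i); push!(Js, j); push!(Xs,  invDv * Vtmp[i+1])
        end
        if i > 1                               # flux through the i-1/2 face
            j = Vtmp[i] > 0 ? i - 1 : i
            push!(Is, i); push!(Js, j); push!(Xs, -invDv * Vtmp[i])
        end
    end
    return sparse(Is, Js, Xs, Nv, Nv)          # duplicate entries are summed
end

# With Vtmp held fixed, the vJp is just the transpose:
# A = upwind_matrix(Vtmp, mesh.Nv, 1 / mesh.Dv); A' * dx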

I’ve typically found that you need to write a manual vJp for some step(s) in most reasonable-scale scientific problems, unless you are using a package like DiffEqFlux.jl that has done that for you. AD tools are more reliable for cookie-cutter ML-style problems where you are plugging together large components that it already knows about, and are only fiddling around the edges with small scalar functions or code written in a functional/non-mutating style (e.g. composing vectorized primitives).
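
For the toy operator above, the full adjoint is still small enough to write by hand. Below is a rough sketch of what that could look like (vjp_manual is my own name for it): it treats the upwind branch selection as locally constant and handles the dot(g, fv) coupling with a rank-one correction. I haven't verified it against your exact setup, so please check it against the ReverseDiff pullback before relying on it:

function vjp_manual(g, p, dx)
    @unpack Dv, fv, Nv, V = p
    invDv = 1 / Dv
    W = V .+ (dot(g, fv) * Dv)      # same "Vtmp" as in rhs1d!
    out = zeros(eltype(g), Nv)      # Jᵀ * dx from the explicit dependence on g
    c   = zeros(eltype(g), Nv)      # sensitivity of rhs[i] to W at its two faces
    for i in 1:Nv
        w = dx[i] * invDv
        if i < Nv                   # flux through the i+1/2 face
            if W[i+1] > 0
                out[i] += w * W[i+1]
                c[i]   += g[i]
            else
                out[i+1] += w * W[i+1]
                c[i]     += g[i+1]
            end
        end
        if i > 1                    # flux through the i-1/2 face
            if W[i] > 0
                out[i-1] -= w * W[i]
                c[i]     -= g[i-1]
            else
                out[i] -= w * W[i]
                c[i]   -= g[i]
            end
        end
    end
    # Rank-one term from W = V .+ dot(g, fv) * Dv: the invDv of the divergence
    # and the Dv of the coupling cancel, leaving fv * (c ⋅ dx).
    out .+= fv .* dot(c, dx)
    return out
end

# Sanity check against the AD pullback:
# vjp_manual(g0, mesh, g0) ≈ vjp(g0, mesh, g0)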


Thanks for this, it is worrying indeed. In my case I use ReverseDiff, but I imagine your remark still applies.

I’ve typically found that you need to write a manual vJp for some step(s) in most reasonable-scale scientific problems,

Ah, that is a nice suggestion! I will try to decompose the problem and use AD only where necessary.