Discrepancy between ODE sensitivity analysis paper results and my benchmarks

I have tried to reproduce the results from the paper “A Comparison of Automatic Differentiation and Continuous Sensitivity Analysis for Derivatives of Differential Equation Solutions” (https://arxiv.org/pdf/1812.01892, Fig. 2, plot A) for the Brusselator model. In my code DSAAD (i.e. differentiating through the whole solver with ForwardDiff) is faster than what is shown in the paper's plot. At the same time CASA, in particular QuadratureAdjoint with EnzymeVJP, is much slower than in the paper, where it is the fastest method (about 100x faster than ForwardDiff at 500 parameters). In my benchmark QuadratureAdjoint with EnzymeVJP is at most 2x faster than ForwardDiff at 500 parameters.

I tried my best to choose the same benchmark parameters as in the paper (or similar ones where they are not explicitly stated), and I am aware of the hardware difference. Nevertheless, the relative comparison between algorithms should hold regardless of those factors.

  1. Where does the difference in performance come from?
    1a. Why does QuadratureAdjoint not perform as well as expected?
    1b. Did I implement something wrong?
  2. Does anyone know if the code used to generate plots in the paper can be found somewhere (for comparison)?

Plot from the paper:

Plot generated by my benchmark:

Below is the code:

using Revise
using OrdinaryDiffEq, DifferentialEquations, DiffEqParamEstim
using Optimization, OptimizationOptimJL, OptimizationBBO
using Symbolics, ADTypes, SparseConnectivityTracer, SciMLSensitivity
using LinearSolve, Sparspak

using DifferentiationInterface
import ForwardDiff, ReverseDiff, Zygote, Enzyme
import Enzyme.Forward as EnzymeForward
import Enzyme.Reverse as EnzymeReverse

using Plots, BenchmarkTools, Serialization, Measures, Statistics # Statistics provides mean/std for BenchmarkTools trials

"""
Brusselator model for testing sensitivities with various autodiff methods.

size(du) = (2, N, N)
size(u) = (2, N, N)
size(p) = (4, N, N)

No-flux boundary conditions are applied at the edges of the grid.
"""

# Julia uses column-major order, so the first index is the fastest changing index.
bruss_f(x, y, t) = (((x - 0.3)^2 + (y - 0.6)^2) <= 0.1^2) * (t >= 1.1) * 5.0
boundary(idx, N) = idx == N+1 ? N-1 : idx == 0 ? 2 : idx

# ::Val{N} is needed to avoid EnzymeVJP crash
function f(dstate, state, parr, t, ::Val{N}) where {N}
    xyd = range(0.0, stop=1.0, length=N)
    dxy = step(xyd)

    arr = reshape(state, (2, N, N)) # does not copy
    u = view(arr, 1, :, :)
    v = view(arr, 2, :, :)

    darr = reshape(dstate, (2, N, N)) # does not copy
    du = view(darr, 1, :, :)
    dv = view(darr, 2, :, :)

    p = reshape(parr, (4, N, N)) # does not copy

    @inbounds for I in CartesianIndices((N, N))
        i, j = Tuple(I)
        x, y = xyd[I[1]], xyd[I[2]]
        p1, p2, p3, p4 = p[:, i, j]
        uv = u[i, j]^2 * v[i, j]

        # ∂²u/∂x² + ∂²u/∂y²
        ux1 = u[boundary(i-1, N), j]
        ux2 = u[boundary(i+1, N), j]
        uy1 = u[i, boundary(j-1, N)]
        uy2 = u[i, boundary(j+1, N)]
        Δu = (ux1 + ux2 + uy1 + uy2 - 4u[i, j]) / dxy^2

        # ∂²v/∂x² + ∂²v/∂y²
        vx1 = v[boundary(i-1, N), j]
        vx2 = v[boundary(i+1, N), j]
        vy1 = v[i, boundary(j-1, N)]
        vy2 = v[i, boundary(j+1, N)]
        Δv = (vx1 + vx2 + vy1 + vy2 - 4v[i, j]) / dxy^2

        # ∂u/∂t
        du[i, j] = p2 + uv - (p1 + 1) * u[i, j] + p3 * Δu + bruss_f(x, y, t)

        # ∂v/∂t
        dv[i, j] = p1 * u[i, j] - uv + p4 * Δv
    end
end

function get_u0(N)
    xyd = range(0.0, stop=1.0, length=N)
    state = zeros(2, N, N)
    @inbounds for I in CartesianIndices((N, N))
        x = xyd[I[1]]
        y = xyd[I[2]]
        state[1, I] = 22 * (y * (1 - y))^(3 / 2)
        state[2, I] = 27 * (x * (1 - x))^(3 / 2)
    end
    return state[:] # flatten the array
end

function get_p(params, N)
    @assert length(params) == 4 "Expected 4 parameters, got $(length(params))"
    p = zeros(4, N, N)
    @inbounds for I in CartesianIndices((N, N))
        p[:, I] = params
    end
    return p[:] # flatten the array
end

function jac_sparsity_adtypes(u0, func, p)
    du0 = similar(u0)
    return ADTypes.jacobian_sparsity((du, u) -> func(du, u, p, 0.0), du0, u0, TracerSparsityDetector())
end

function jac_sparsity_symbolics(u0, func, p)
    du0 = similar(u0)
    return float.(Symbolics.jacobian_sparsity((du, u) -> func(du, u, p, 0.0), du0, u0))
end

# Problem definition
abstol=1e-5
reltol=1e-5
tspan = (0.0, 10.0)

algo, algo_name = Rodas5(), "Rodas5"

# Sensitivity algorithms list
sensealg_ga = GaussAdjoint()
sensealg_gaf = GaussAdjoint(; autodiff=true, autojacvec=false) # FiniteDifferences when autojacvec=false
sensealg_gad = GaussAdjoint(; autodiff=true, autojacvec=true) # ForwardDiff when autojacvec=true
sensealg_gae = GaussAdjoint(; autodiff=true, autojacvec=EnzymeVJP())
sensealg_ia = InterpolatingAdjoint()
sensealg_iaf = InterpolatingAdjoint(; autodiff=true, autojacvec=false)
sensealg_iad = InterpolatingAdjoint(; autodiff=true, autojacvec=true)
sensealg_iae = InterpolatingAdjoint(; autodiff=true, autojacvec=EnzymeVJP())
sensealg_qa = QuadratureAdjoint()
sensealg_qaf = QuadratureAdjoint(; autodiff=true, autojacvec=false)
sensealg_qad = QuadratureAdjoint(; autodiff=true, autojacvec=true) 
sensealg_qae = QuadratureAdjoint(; autodiff=true, autojacvec=EnzymeVJP())

# Define the sensitivity algorithms to be benchmarked
config = [
    (framework = AutoZygote(), sensealg = sensealg_ga, name = "GaussAdjoint"),
    (framework = AutoZygote(), sensealg = sensealg_gaf, name = "GaussAdjoint FiniteDifferences"),
    #(framework = AutoZygote(), sensealg = sensealg_gad, name = "GaussAdjoint ForwardDiff"), # ERROR: autojacvec choice true is not supported by GaussAdjoint
    (framework = AutoZygote(), sensealg = sensealg_gae, name = "GaussAdjoint EnzymeVJP"),
    (framework = AutoZygote(), sensealg = sensealg_ia, name = "InterpolatingAdjoint"),
    #(framework = AutoZygote(), sensealg = sensealg_iaf, name = "InterpolatingAdjoint FiniteDifferences"), # Does not work with Rodas5, for some reason throws ForwardDiff error
    #(framework = AutoZygote(), sensealg = sensealg_iad, name = "InterpolatingAdjoint ForwardDiff"), # Does not work with Rodas5
    (framework = AutoZygote(), sensealg = sensealg_iae, name = "InterpolatingAdjoint EnzymeVJP"),
    (framework = AutoZygote(), sensealg = sensealg_qa, name = "QuadratureAdjoint"),
    (framework = AutoZygote(), sensealg = sensealg_qaf, name = "QuadratureAdjoint FiniteDifferences"),
    (framework = AutoZygote(), sensealg = sensealg_qad, name = "QuadratureAdjoint ForwardDiff"),
    (framework = AutoZygote(), sensealg = sensealg_qae, name = "QuadratureAdjoint EnzymeVJP"),
    (framework = AutoForwardDiff(), sensealg = nothing, name = "ForwardDiff"),
]

sens_test(framework, sensealg, p, loss) = DifferentiationInterface.gradient((p) -> loss(p, sensealg), framework, p) # has to be in global scope so that the @benchmark macro can capture it

# config, results, algo = deserialize("data/sensitivities_brusselator_benchmark_$(algo_name)_post.jls") # uncomment only to load the results from a previous run

results = Dict{Int, Any}()
@time for N in 2:12
    println("### Running benchmark for algo = $(algo_name), N = $N ###")
    p = get_p([3.4, 1.0, 10.0, 10.0], N)
    u0 = get_u0(N)

    f_ode(du, u, p, t) = f(du, u, p, t, Val(N))
    jac_sparsity = jac_sparsity_adtypes(u0, f_ode, p)
    fun = ODEFunction(f_ode, jac_prototype=jac_sparsity)
    prob = ODEProblem(fun, u0, tspan, p)

    function loss(p, sensealg)
        prob_remake = remake(prob, u0 = eltype(p).(u0), p = p)
        sum(solve(prob_remake, algo, saveat=0.1, sensealg=sensealg, abstol=abstol, reltol=reltol, maxiters=1000_000_000))
    end

    function bench_sens(framework, sensealg)
        bench = @benchmark sens_test($framework, $sensealg, $p, $loss)
        m = mean(bench)
        s = std(bench)
        return (mean = m.time * 1e-9, std = s.time * 1e-9)
    end

    # warm-up: compile all of these once
    println("Warming up...")
    @time for cfg in config
        println("  Warming up $(cfg.name)")
        sens_test(cfg.framework, cfg.sensealg, p, loss)
    end
    println("Warming up complete!")

    res = Array{Any}(undef, length(config))
    println("Benchmarking ...")
    @time for (i, cfg) in enumerate(config)
        println("  Benchmarking $(cfg.name) with N = $N")
        res[i] = bench_sens(cfg.framework, cfg.sensealg)
    end

    results[N] = res
    serialize("data/sensitivities_brusselator_benchmark_$(algo_name)_post.jls", (config, results, algo)) # somehow pwd() is project root
    println("Results saved for N = $N")
end

serialize("data/sensitivities_brusselator_benchmark_$(algo_name)_post", (config, results, algo))

results

function plot_results(config, results, algo_name)
    markers = [:circle, :square, :utriangle, :star5]

    fig = plot(title="Brusselator (algo = $algo_name)", 
        xlabel="params", 
        ylabel="Time (s)", 
        legend=:bottomright,
        palette=:twelvebitrainbow,
        size=(1200, 900),
        left_margin=5mm, 
        yscale=:log10,
        # make ticks every 10^n
        xticks = collect(0:20:1000),
        yticks = collect(10.0 .^ (-2:2)),
    )

    Ns = sort(collect(keys(results)))
    params = @. Ns^2 * 4
    for (i, cfg) in enumerate(config)
        times = [results[N][i].mean for N in Ns]
        stds = [results[N][i].std for N in Ns]
        label = cfg.name
        plot!(fig, params, times;
            label=label, lw=2, ribbon=stds, fillalpha=0.3, # show ±1 std as a shaded band
            marker=markers[(i-1) % length(markers) + 1], markersize=3, markerstrokewidth=0
        )
    end

    fig
end

plt = plot_results(config, results, algo_name)

Can you show what your plot looks like in log-log scale, for easier comparison?
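E.g. something along these lines, reusing the plot_results function from your script (untested sketch; the explicit xticks need to be reset since 0 cannot appear on a log axis):

plt = plot_results(config, results, algo_name)
plot!(plt; xscale = :log10, yscale = :log10, xticks = :auto) # log-log version of the same figure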

For reference, the complete code is maintained here:

That said, it probably needs an update: Enzyme now allocates a lot more than it used to, due to some safety measures that had to be added in a very recent update.

This is due to a bug in Enzyme that means the core loop cannot be zero-allocation on many problems:

The potential issue does not exist in this case (it only shows up in specific cases with closures), but I don't know of a function that can prove that (i.e. that the function has no mutable elements; maybe checking isbits(f) is enough to fix a few regressions?). However, the workaround affects Enzyme performance in all cases, so for the last month things have been a bit slower. If I had to point to something that could be the cause, I would point here, and I know @wsmoses said there's a quick fix for this, but I don't think it has been merged yet.

The last SciMLSensitivity benchmark was run with [1ed8b502] SciMLSensitivity v7.64.0, while v7.81.0 would be the one with the safety measure / performance regression.

So one thing you can try is just pinning SciMLSensitivity on your system to v7.80.0 and seeing if that reverts the regression.
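For example (a sketch; adjust to however you manage your environment):

using Pkg
Pkg.add(name = "SciMLSensitivity", version = "7.80.0") # install a pre-regression release
Pkg.pin("SciMLSensitivity")                            # keep the resolver from upgrading it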

I set up a run for the version bump here:

The build artifacts will say whether it's a general regression or a difference between the codes. But my money is on needing to get this make_zero! stuff fixed in Enzyme, since I know that regression will be non-trivial for many cases.


Here is the plot with log scale on both axes.

Just to confirm, @ChrisRackauckas: there was no change in Enzyme that caused a performance loss; it was how you were using it in SciMLSensitivity that changed?

Thanks for the insight regarding Enzyme.

I still have an issue with ForwardDiff.

I checked out the code from the SciMLBenchmarks, where the results for ForwardDiff for N=5 and N=12 are respectively:

(n, t) = (5, 0.520653741)
...
(n, t) = (12, 343.941457935)

My benchmark results are:
N=5, time 0.32s
N=12, time 15.01s

Given that the results for N=5 are comparable while for N=12 mine are ~23 times faster, it seems as if ForwardDiff is somehow scaling better in my benchmark.

Do you have any idea why this might be?

I am not familiar with ODE sensitivity analysis in general, but it seems that your benchmark uses sparse Jacobians. OrdinaryDiffEq recently switched to a new DI-based sparse autodiff pipeline; could that be a factor in some of the changes you observed? This might especially be true if sparsity detection is counted into the measurements at some point.
Note that DI (DifferentiationInterface) has nothing to do with SciMLSensitivity; it is not used there.
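For context, the new pipeline is built around sparse AD backends of roughly this shape (a sketch assuming SparseMatrixColorings is available; I have not checked exactly what OrdinaryDiffEq constructs by default):

using ADTypes, SparseConnectivityTracer, SparseMatrixColorings

# dense backend + sparsity detector + coloring algorithm
sparse_backend = AutoSparse(
    AutoForwardDiff();
    sparsity_detector  = TracerSparsityDetector(),
    coloring_algorithm = GreedyColoringAlgorithm(),
)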

Here I have prepared a plot to better illustrate the issue.

Is the SciMLBenchmark curve in your plot obtained with the up-to-date package versions, or with the package versions from the SciMLBenchmarks website?

The sparsity detection is not part of the time measurement, as it is not evaluated inside the loss function (in fact it cannot be, because autodiff throws errors when trying to differentiate it).

You are right, though, that the use of a sparse Jacobian might explain the different scaling. I will try to test my code again without the sparse Jacobian to see if anything changes.
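Concretely, the dense-Jacobian variant I plan to test is just the same setup as in the benchmark loop above, without the jac_prototype:

# same f_ode, u0, tspan and p as before, but no sparsity pattern attached
fun_dense  = ODEFunction(f_ode)
prob_dense = ODEProblem(fun_dense, u0, tspan, p)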

I grabbed the values straight from the website. I copied the exact numbers, so you should be able to do Ctrl+F and find the correct section easily, if that would help you.

I only now noticed that the SciMLBenchmark for the CSA algorithms uses a custom Jacobian (brusselator_jac, implemented by hand) rather than relying on autodiff to compute it, as I am doing.

This could explain why CSA methods are so fast in SciMLBenchmark.
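For anyone following along, this is the mechanism I mean; the toy ODE below is just a placeholder to show the jac keyword, not the benchmark's brusselator_jac:

using OrdinaryDiffEq

# toy linear-decay ODE with a hand-written Jacobian supplied via the jac keyword
decay!(du, u, p, t) = (du .= -p[1] .* u)
function decay_jac!(J, u, p, t)
    J .= 0.0
    for i in axes(J, 1)
        J[i, i] = -p[1]
    end
end

fun_handjac  = ODEFunction(decay!; jac = decay_jac!)
prob_handjac = ODEProblem(fun_handjac, ones(3), (0.0, 1.0), [0.5])
solve(prob_handjac, Rodas5()) # the implicit solver uses decay_jac! instead of autodiff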

Try changing to SciMLBase.FullSpecialize in the ODEProblem. That would change the chunk sizes in autodiff.
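E.g., in the problem construction:

import SciMLBase
# fully specialize the problem on the rhs function (the default is AutoSpecialize)
prob = ODEProblem{true, SciMLBase.FullSpecialize}(fun, u0, tspan, p)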

@ChrisRackauckas your feature PR is here: feat: Add remake_zero by wsmoses · Pull Request #2436 · EnzymeAD/Enzyme.jl · GitHub

For those looking in from the outside, this is not a bug in Enzyme. SciMLSensitivity recently added some extra allocations when calling Enzyme.

@ChrisRackauckas wants to use an Enzyme function make_zero! to zero the allocation. As it says in the docstring, “Only applicable for mutable types T.” As a safety measure it throws an error if it is given an immutable type it cannot zero.

For example:

using Enzyme

x = (3.0, [5.0])
Enzyme.make_zero!(x)
# Throws: cannot set 3.0 to zero in-place, as it contains differentiable values in immutable positions

Tuples are immutable so we cannot zero the 3.0 element at the start.

The linked PR adds a new remake_zero! function which will not throw an error and is only correct if the original allocation was created by make_zero to begin with.


We cannot safely call autodiff on the same f unless the derivatives within the shadow for f are cleared; if they aren't, there are cases where we calculate the wrong value. However, it's not possible to universally apply make_zero! to reset the shadow back to the state you get from make_zero, so the only way to be universally correct is to call make_zero before each autodiff, and that's what allocates. Removing that would break some cases: either make_zero! errors because of mutating immutables (which it shouldn't need to, since of course the values are still zero from a previous autodiff), or we simply compute the wrong gradient, which is not something we will do just to get better performance. So until that make_zero! issue is handled, we have to take the hit.
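To make that concrete, here is a minimal sketch of the accumulation behavior on a toy function (not the actual SciMLSensitivity internals):

using Enzyme

rosen(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2

x  = [0.5, 0.5]
dx = Enzyme.make_zero(x)              # fresh zero shadow to accumulate the gradient into

Enzyme.autodiff(Enzyme.Reverse, rosen, Enzyme.Active, Enzyme.Duplicated(x, dx))
g1 = copy(dx)                         # correct gradient

# Reusing the shadow without clearing it accumulates, so a second call silently
# gives twice the gradient:
Enzyme.autodiff(Enzyme.Reverse, rosen, Enzyme.Active, Enzyme.Duplicated(x, dx))
@assert dx ≈ 2 .* g1

# make_zero! clears in place, but errors on shadows with differentiable values in
# immutable positions; rebuilding with make_zero is always correct but allocates
# on every call, which is the hit described above.
Enzyme.make_zero!(dx)                 # fine here, dx is a plain Vector
dx = Enzyme.make_zero(x)              # the allocating fallback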

If you want to call it not a bug in Enzyme, okay. But it’s an issue such that repeated calls to Enzyme either have to allocate or be incorrect for the reason explained above, so we choose to allocate until we have a fix.

In any case, a version of Enzyme with the new function (remake_zero!) was released earlier today.


The regression was found on the benchmarking server, but it’s a lot smaller than what you are showing.

Regressed:

Before:

But with Billy's fix there, we should have it fixed in no time.

So @mra, regarding your code running faster with ForwardDiff: did you test whether that was a difference caused by FullSpecialize?