I have tried to reproduce the results from the paper “A Comparison of Automatic Differentiation and Continuous Sensitivity Analysis for Derivatives of Differential Equation Solutions” (https://arxiv.org/pdf/1812.01892, Fig. 2, plot A) for the Brusselator model. In my code, DSAAD (discrete sensitivity analysis via AD, i.e. differentiating through the whole solver with ForwardDiff) is faster than what the plot in the paper shows. At the same time, CASA (continuous adjoint sensitivity analysis), in particular QuadratureAdjoint with EnzymeVJP, is much slower than in the paper, where it is the fastest method (about 100x faster than ForwardDiff at 500 parameters). In my benchmark, QuadratureAdjoint with EnzymeVJP is at most 2x faster than ForwardDiff at 500 parameters.
I did my best to choose the same benchmark parameters as the paper (or similar ones where they are not explicitly stated), and I am aware of the hardware difference. Nevertheless, the relative comparison between the algorithms should hold regardless of those factors.
- Where does the difference in performance come from?
  - 1a. Why does QuadratureAdjoint not perform as well as expected?
  - 1b. Did I implement something wrong?
- Does anyone know if the code used to generate the plots in the paper is available somewhere (for comparison)?
Plot from the paper:
Plot generated by my benchmark:
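For reference, here is a minimal sketch of the two call patterns being compared, on a toy scalar ODE rather than the Brusselator (my own illustration; names like f_toy! and loss_toy are made up):

using OrdinaryDiffEq, SciMLSensitivity
import ForwardDiff, Zygote
f_toy!(du, u, p, t) = (du[1] = -p[1] * u[1])
prob_toy = ODEProblem(f_toy!, [1.0], (0.0, 1.0), [2.0])
loss_toy(p, sensealg) = sum(solve(remake(prob_toy, u0 = eltype(p).(prob_toy.u0), p = p),
    Tsit5(); saveat=0.1, sensealg=sensealg))
# DSAAD: ForwardDiff pushes dual numbers through every step of the solver
g_dsaad = ForwardDiff.gradient(p -> loss_toy(p, nothing), [2.0])
# CASA: Zygote's pullback dispatches to the continuous adjoint selected via sensealg
g_casa = Zygote.gradient(p -> loss_toy(p, QuadratureAdjoint(; autodiff=true, autojacvec=EnzymeVJP())), [2.0])[1]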
Below is the code:
using Revise
using OrdinaryDiffEq, DifferentialEquations, DiffEqParamEstim
using Optimization, OptimizationOptimJL, OptimizationBBO
using Symbolics, ADTypes, SparseConnectivityTracer, SciMLSensitivity
using LinearSolve, Sparspak
using DifferentiationInterface
import ForwardDiff, ReverseDiff, Zygote, Enzyme
import Enzyme.Forward as EnzymeForward
import Enzyme.Reverse as EnzymeReverse
using Plots, BenchmarkTools, Serialization, Measures
using Statistics # needed for mean/std on BenchmarkTools trials
"""
Brusselator model for testing sensitivities with various autodiff methods.
size(du) = (2, N, N)
size(u) = (2, N, N)
size(p) = (4, N, N)
No-flux boundary conditions are applied at the edges of the grid.
"""
# Julia uses column-major order, so the first index is the fastest changing index.
bruss_f(x, y, t) = (((x - 0.3)^2 + (y - 0.6)^2) <= 0.1^2) * (t >= 1.1) * 5.0
boundary(idx, N) = idx == N+1 ? N-1 : idx == 0 ? 2 : idx
# ::Val{N} is needed to avoid EnzymeVJP crash
function f(dstate, state, parr, t, ::Val{N}) where {N}
xyd = range(0.0, stop=1.0, length=N)
dxy = step(xyd)
arr = reshape(state, (2, N, N)) # does not copy
u = view(arr, 1, :, :)
v = view(arr, 2, :, :)
darr = reshape(dstate, (2, N, N)) # does not copy
du = view(darr, 1, :, :)
dv = view(darr, 2, :, :)
p = reshape(parr, (4, N, N)) # does not copy
@inbounds for I in CartesianIndices((N, N))
i, j = Tuple(I)
x, y = xyd[I[1]], xyd[I[2]]
p1, p2, p3, p4 = p[:, i, j]
uv = u[i, j]^2 * v[i, j]
# ∂²u/∂x² + ∂²u/∂y²
ux1 = u[boundary(i-1, N), j]
ux2 = u[boundary(i+1, N), j]
uy1 = u[i, boundary(j-1, N)]
uy2 = u[i, boundary(j+1, N)]
Δu = (ux1 + ux2 + uy1 + uy2 - 4u[i, j]) / dxy^2
# ∂²v/∂x² + ∂²v/∂y²
vx1 = v[boundary(i-1, N), j]
vx2 = v[boundary(i+1, N), j]
vy1 = v[i, boundary(j-1, N)]
vy2 = v[i, boundary(j+1, N)]
Δv = (vx1 + vx2 + vy1 + vy2 - 4v[i, j]) / dxy^2
# ∂u/∂t
du[i, j] = p2 + uv - (p1 + 1) * u[i, j] + p3 * Δu + bruss_f(x, y, t)
# ∂v/∂t
dv[i, j] = p1 * u[i, j] - uv + p4 * Δv
end
end
function get_u0(N)
    xyd = range(0.0, stop=1.0, length=N)
    state = zeros(2, N, N)
    @inbounds for I in CartesianIndices((N, N))
        x = xyd[I[1]]
        y = xyd[I[2]]
        state[1, I] = 22 * (y * (1 - y))^(3 / 2)
        state[2, I] = 27 * (x * (1 - x))^(3 / 2)
    end
    return state[:] # flatten the array
end
function get_p(params, N)
    @assert length(params) == 4 "Expected 4 parameters, got $(length(params))"
    p = zeros(4, N, N)
    @inbounds for I in CartesianIndices((N, N))
        p[:, I] = params
    end
    return p[:] # flatten the array
end
function jac_sparsity_adtypes(u0, func, p)
    du0 = similar(u0)
    return ADTypes.jacobian_sparsity((du, u) -> func(du, u, p, 0.0), du0, u0, TracerSparsityDetector())
end
function jac_sparsity_symbolics(u0, func, p)
    du0 = similar(u0)
    return float.(Symbolics.jacobian_sparsity((du, u) -> func(du, u, p, 0.0), du0, u0))
end
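# Optional (not part of the benchmark): inspect how sparse the detected Jacobian
# pattern is for a small grid; each state entry couples only to itself, the
# other species in the same cell, and its four grid neighbours.
using SparseArrays # for nnz
let N = 4
    S = jac_sparsity_adtypes(get_u0(N), (du, u, p, t) -> f(du, u, p, t, Val(N)),
        get_p([3.4, 1.0, 10.0, 10.0], N))
    println("Jacobian pattern for N = $N: size = $(size(S)), nnz = $(nnz(S))")
end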
# Problem definition
abstol=1e-5
reltol=1e-5
tspan = (0.0, 10.0)
algo, algo_name = Rodas5(), "Rodas5"
# Sensitivity algorithms list
sensealg_ga = GaussAdjoint()
sensealg_gaf = GaussAdjoint(; autodiff=true, autojacvec=false) # FiniteDifferences when autojacvec=false
sensealg_gad = GaussAdjoint(; autodiff=true, autojacvec=true) # ForwardDiff when autojacvec=true
sensealg_gae = GaussAdjoint(; autodiff=true, autojacvec=EnzymeVJP())
sensealg_ia = InterpolatingAdjoint()
sensealg_iaf = InterpolatingAdjoint(; autodiff=true, autojacvec=false)
sensealg_iad = InterpolatingAdjoint(; autodiff=true, autojacvec=true)
sensealg_iae = InterpolatingAdjoint(; autodiff=true, autojacvec=EnzymeVJP())
sensealg_qa = QuadratureAdjoint()
sensealg_qaf = QuadratureAdjoint(; autodiff=true, autojacvec=false)
sensealg_qad = QuadratureAdjoint(; autodiff=true, autojacvec=true)
sensealg_qae = QuadratureAdjoint(; autodiff=true, autojacvec=EnzymeVJP())
# Define the sensitivity algorithms to be benchmarked
config = [
    (framework = AutoZygote(), sensealg = sensealg_ga, name = "GaussAdjoint"),
    (framework = AutoZygote(), sensealg = sensealg_gaf, name = "GaussAdjoint FiniteDifferences"),
    #(framework = AutoZygote(), sensealg = sensealg_gad, name = "GaussAdjoint ForwardDiff"), # ERROR: autojacvec choice true is not supported by GaussAdjoint
    (framework = AutoZygote(), sensealg = sensealg_gae, name = "GaussAdjoint EnzymeVJP"),
    (framework = AutoZygote(), sensealg = sensealg_ia, name = "InterpolatingAdjoint"),
    #(framework = AutoZygote(), sensealg = sensealg_iaf, name = "InterpolatingAdjoint FiniteDifferences"), # does not work with Rodas5; throws a ForwardDiff error for some reason
    #(framework = AutoZygote(), sensealg = sensealg_iad, name = "InterpolatingAdjoint ForwardDiff"), # does not work with Rodas5
    (framework = AutoZygote(), sensealg = sensealg_iae, name = "InterpolatingAdjoint EnzymeVJP"),
    (framework = AutoZygote(), sensealg = sensealg_qa, name = "QuadratureAdjoint"),
    (framework = AutoZygote(), sensealg = sensealg_qaf, name = "QuadratureAdjoint FiniteDifferences"),
    (framework = AutoZygote(), sensealg = sensealg_qad, name = "QuadratureAdjoint ForwardDiff"),
    (framework = AutoZygote(), sensealg = sensealg_qae, name = "QuadratureAdjoint EnzymeVJP"),
    (framework = AutoForwardDiff(), sensealg = nothing, name = "ForwardDiff"),
]
# Has to be in global scope so that the @benchmark macro can capture it.
sens_test(framework, sensealg, p, loss) = DifferentiationInterface.gradient(p -> loss(p, sensealg), framework, p)
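# For example, a single gradient evaluation for one configuration looks like
# (with `p` and `loss` as defined inside the benchmark loop below):
#   sens_test(AutoZygote(), QuadratureAdjoint(; autodiff=true, autojacvec=EnzymeVJP()), p, loss)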
# config, results, algo = deserialize("data/sensitivities_brusselator_benchmark_$(algo_name)_post.jls") # uncomment only to load the results from a previous run
mkpath("data") # make sure the output directory exists before the serialize calls below
results = Dict{Int, Any}()
@time for N in 2:12
    println("### Running benchmark for algo = $(algo_name), N = $N ###")
    p = get_p([3.4, 1.0, 10.0, 10.0], N)
    u0 = get_u0(N)
    f_ode(du, u, p, t) = f(du, u, p, t, Val(N))
    jac_sparsity = jac_sparsity_adtypes(u0, f_ode, p)
    fun = ODEFunction(f_ode, jac_prototype=jac_sparsity)
    prob = ODEProblem(fun, u0, tspan, p)
    function loss(p, sensealg)
        prob_remake = remake(prob, u0 = eltype(p).(u0), p = p)
        sum(solve(prob_remake, algo, saveat=0.1, sensealg=sensealg, abstol=abstol, reltol=reltol, maxiters=1_000_000_000))
    end
    function bench_sens(framework, sensealg)
        bench = @benchmark sens_test($framework, $sensealg, $p, $loss)
        m = mean(bench)
        s = std(bench)
        return (mean = m.time * 1e-9, std = s.time * 1e-9) # BenchmarkTools reports times in ns
    end
    # warm-up: compile all of these once
    println("Warming up...")
    @time for cfg in config
        println(" Warming up $(cfg.name)")
        sens_test(cfg.framework, cfg.sensealg, p, loss)
    end
    println("Warming up complete!")
    res = Array{Any}(undef, length(config))
    println("Benchmarking ...")
    @time for (i, cfg) in enumerate(config)
        println(" Benchmarking $(cfg.name) with N = $N")
        res[i] = bench_sens(cfg.framework, cfg.sensealg)
    end
    results[N] = res
    serialize("data/sensitivities_brusselator_benchmark_$(algo_name)_post.jls", (config, results, algo)) # save after every N; somehow pwd() is project root
    println("Results saved for N = $N")
end
serialize("data/sensitivities_brusselator_benchmark_$(algo_name)_post.jls", (config, results, algo)) # final save (same file the deserialize line above reads)
results
function plot_results(config, results, algo_name)
    markers = [:circle, :square, :utriangle, :star5]
    fig = plot(title="Brusselator (algo = $algo_name)",
        xlabel="params",
        ylabel="Time (s)",
        legend=:bottomright,
        palette=:twelvebitrainbow,
        size=(1200, 900),
        left_margin=5mm,
        yscale=:log10,
        xticks = collect(0:20:1000),
        yticks = collect(10.0 .^ (-2:2)), # y ticks at every power of ten
    )
    Ns = sort(collect(keys(results)))
    params = @. Ns^2 * 4 # 4 parameters per grid cell
    for (i, cfg) in enumerate(config)
        times = [results[N][i].mean for N in Ns]
        stds = [results[N][i].std for N in Ns]
        label = cfg.name
        plot!(fig, params, times;
            label=label, lw=2, ribbon=stds, fillalpha=0.3, # ribbon shows the std of each timing
            marker=markers[(i-1) % length(markers) + 1], markersize=3, markerstrokewidth=0
        )
    end
    return fig
end
plt = plot_results(config, results, algo_name)
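For question 1b, one check is to compare the gradients themselves on a small grid: every sensealg configuration should agree with plain ForwardDiff up to the solver tolerances, otherwise something is wired up wrong. A sketch (names like loss_check are just for illustration):

N_check = 4
p_check = get_p([3.4, 1.0, 10.0, 10.0], N_check)
u0_check = get_u0(N_check)
f_check(du, u, p, t) = f(du, u, p, t, Val(N_check))
prob_check = ODEProblem(f_check, u0_check, tspan, p_check)
loss_check(p, sensealg) = sum(solve(remake(prob_check, u0 = eltype(p).(u0_check), p = p),
    algo; saveat=0.1, sensealg=sensealg, abstol=abstol, reltol=reltol))
g_ref = DifferentiationInterface.gradient(p -> loss_check(p, nothing), AutoForwardDiff(), p_check)
for cfg in config
    cfg.sensealg === nothing && continue # skip the plain ForwardDiff entry
    g = DifferentiationInterface.gradient(p -> loss_check(p, cfg.sensealg), cfg.framework, p_check)
    println(rpad(cfg.name, 40), "max abs deviation from ForwardDiff: ", maximum(abs.(g .- g_ref)))
end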