Correct preallocation for different Autodiff libraries

I have a fairly complex Julia package that I would like to tune for performance. As I understand it, the key is to make the code type-stable and to reuse memory, especially in expensive functions that are called repeatedly. However, I am stuck on how to set up and pass down correctly typed allocations for several AD libraries at once.

The actual package is huge, so I tried to come up with an MWE that keeps the spirit. The following code compiles and runs on Julia 1.6:

using BandedMatrices
using DiffEqFlux
using OrdinaryDiffEq
using DiffEqSensitivity
using Flux
using GalacticOptim
using LinearAlgebra
using Zygote

#band building stuff
function Vband(n,l,u,Δx,v)
  A = BandedMatrix{Float64}(undef, (n,n), (l,u))
  A[band(0)] .= 0
  A[band(1)] .= -1/(2*Δx)
  A[band(-1)] .= 1/(2*Δx)
  A[1,1:2] .= 0
  A[n,n] = -1/(2*Δx)
  A[n,n-1] = 1/(2*Δx)
  A = A.*v
  return A
end

function build_band(npoints,L,v)
  Δx = L/(npoints-1)
  n = npoints
  VB = Vband(n,1,1,Δx,v)
  return VB
end

#ODE system
RHS =  FastChain(FastDense(2,9,tanh,bias = false),FastDense(9,1,bias = false))

function eqsys(du,u,θ,t,n,Φ,BandMat,Band_c)
  c =  @view u[1:n]
  dcdt =  @view du[1:n]
  dqdt =  @view du[n+1:2*n]
  dqdt .= vec(RHS(reshape(u,n,:)',θ))
  mul!(Band_c,BandMat,c)  # fill the preallocated cache with BandMat*c before it is read
  @. dcdt = Band_c - Φ*dqdt
  dqdt[1] = 0
  dcdt[1] = 0
end

#wrapper with closure &  preallocation
function fun(p,u0,n,Φ,BandMat,save_times)
  Band_c = Vector{Real}(undef,n)
  ode_closure = ODEFunction((du,u,p,t) -> eqsys(du,u,p,t,n,Φ,BandMat,Band_c))
  prob = ODEProblem(ode_closure,u0,tspan,p)
  solve(prob,KenCarp4(autodiff=false),saveat = save_times,sensealg = QuadratureAdjoint(autojacvec=ReverseDiffVJP(true)))
end

#callback & loss
cb = function (p,l)
  # returning true stops the optimization early
  if l < 1e-3
      return true
  end
  return false
end

function loss(p,u0,n,Φ,BandMat,save_times,data)
  sol = fun(p,u0,n,Φ,BandMat,save_times)
  diff_vec = vec(sol[n,:].-data)
  loss = sqrt(sum(abs2,diff_vec))
  return loss
end

#some toy paras & data  (will probably never fit)
n = 20
L = 10
Epsilon = 0.36
Φ = zeros(n)
Φ .= (1-Epsilon)/Epsilon

c_in = 1
save_times  = collect(0:0.2*pi:20)
data = save_times./20

u0 = zeros(2*n)
u0[1] = c_in
v = 2.8

tᵢ = save_times[1]
tₑ = save_times[end]
tspan = (tᵢ,tₑ)

#initialize & optimize
BandMat = build_band(n,L,v)
p = Float64.(initial_params(RHS)).*1e-3
optprob = OptimizationFunction((p,x) -> loss(p,u0,n,Φ,BandMat,save_times,data), GalacticOptim.AutoZygote())
prob = GalacticOptim.OptimizationProblem(optprob, p)
res = GalacticOptim.solve(prob,ADAM(1e-2),cb = cb,maxiters = 100)

Since eqsys(du,u,θ,t,n,Φ,BandMat,Band_c) is the most frequently evaluated function, it has been the focus of my performance optimization. In particular, I would like to allocate a cache vector Ax for the result of mul!(Ax,A,x), and this is where I run into trouble, because the different AD modes/libraries that get called work with different types. As I understand from this post, it’s best to use some kind of nested architecture with additional closures in it.

Therefore I allocate Band_c = Vector{Real}(undef,n) in the outer function fun. However, I have not managed to give it the correct concrete type for each AD mode.
As far as I understand with my limited knowledge of Julia, allocating a Vector{AbstractType} as I do won’t help with speed, since the compiler cannot specialize on the element type and every access inside eqsys has to be resolved at runtime.
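To make the concern about Vector{Real} concrete, here is a minimal sketch in Base Julia (no packages). The variable names are only for illustration; it shows that a container with an abstract element type stores references rather than inline values.

```julia
# Base Julia only; the names below are illustrative.
abstract_cache = Vector{Real}(undef, 4)     # elements are references to boxed values
concrete_cache = Vector{Float64}(undef, 4)  # elements are stored inline

# Real is abstract, Float64 is concrete.
println(isconcretetype(eltype(abstract_cache)))
println(isconcretetype(eltype(concrete_cache)))

# An unassigned slot of the abstract container has no value behind it yet,
# so reading it would throw; the Float64 slots always hold some bit pattern.
println(isassigned(abstract_cache, 1))
println(isassigned(concrete_cache, 1))
```

So besides being slow, a Vector{Real}(undef,n) has to be written before it is read.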

So I have two questions:

  1. How can I concretise the preallocated type without any loss of generality?
  2. Is this “nested” approach the most efficient way to go, or is there another way that avoids rebuilding the ODEProblem each time?
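For question 1, one pattern I am aware of is to derive the cache's element type from the arrays that are actually passed in, instead of hard-coding it. Below is a minimal sketch using only Base and LinearAlgebra. make_cache is a hypothetical helper, not part of any package, and the sketch only covers cases where the AD number type flows through u or p (forward-mode dual numbers, for example). It does not address tracked-array reverse mode.

```julia
using LinearAlgebra

# Hypothetical helper: allocate a length-n cache whose element type is the
# promotion of the state and parameter element types.
make_cache(u, p, n) = similar(u, promote_type(eltype(u), eltype(p)), n)

A = [2.0 0.0; 0.0 3.0]   # stand-in for the banded matrix
u = [1.0, 2.0]           # stand-in for the state vector
p = Float32[0.5]         # stand-in for the parameters

cache = make_cache(u, p, 2)
mul!(cache, A, u)        # writes A*u into cache instead of creating a new vector

println(eltype(cache))
println(cache)
```

The cache still has a concrete element type, but that type is decided by whatever the caller passes in.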

Pre-allocating doesn’t make sense when reverse mode is involved.

Why not?
I guess I’m misunderstanding something, but I don’t know what.
My understanding was that preallocating arrays for mul! is essential for fast DiffEq code; at least that’s what I read in the tutorials.

Reverse-mode automatic differentiation environments are different because, by design, they require a form of caching: you can try to fight against it, but the derivative calculations will still have to allocate. To see why, see the lecture notes on how it works.
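As a rough illustration of that caching, here is a toy tape-based reverse mode for scalars. All names (Tape, Tracked, track, gradient!) are hypothetical and this is not how Zygote or ReverseDiff are implemented; it only shows that every operation in the forward pass has to store something for the backward pass.

```julia
# Toy reverse mode: every operation pushes a pullback closure onto a tape.
struct Tape
    backward::Vector{Function}
end
Tape() = Tape(Function[])

mutable struct Tracked
    value::Float64
    grad::Float64
    tape::Tape
end
track(x, tape) = Tracked(x, 0.0, tape)

function Base.:*(a::Tracked, b::Tracked)
    out = Tracked(a.value * b.value, 0.0, a.tape)
    # The closure keeps a, b and out alive until the backward pass.
    # This per-operation storage is the caching that cannot be avoided.
    push!(a.tape.backward, () -> begin
        a.grad += b.value * out.grad
        b.grad += a.value * out.grad
    end)
    return out
end

function Base.sin(a::Tracked)
    out = Tracked(sin(a.value), 0.0, a.tape)
    push!(a.tape.backward, () -> (a.grad += cos(a.value) * out.grad))
    return out
end

# Seed the output gradient and replay the tape in reverse order.
function gradient!(out::Tracked)
    out.grad = 1.0
    for pullback in reverse(out.tape.backward)
        pullback()
    end
    return nothing
end

tape = Tape()
x = track(2.0, tape)
y = track(3.0, tape)
z = sin(x * y)                    # forward pass records two operations
gradient!(z)

println(length(tape.backward))    # number of recorded operations
println(x.grad ≈ 3.0 * cos(6.0))  # d/dx sin(x*y) = y*cos(x*y)
```

The tape grows with the number of operations in the forward pass, so a preallocated output buffer in the user's function does not remove the allocations made by the AD system itself.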