I have a fairly complex Julia package that I would like to tune for performance. As I understand it, the key is to make the code type-stable and to reuse memory, especially in expensive functions that are called repeatedly. However, I am stuck on how to create and pass down correctly typed preallocated buffers that work for several AD libraries at once.
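To make the type-stability part concrete, here is a pair of toy functions. They are hypothetical and not part of the package. For a `Float64` argument, `unstable` can return either that `Float64` or the `Int` literal `0`, so the compiler can only infer a `Union`. `stable` always returns the argument's own type. `Base.return_types` is an internal reflection helper; `@code_warntype` shows the same information interactively.

```julia
# Hypothetical toy functions illustrating type (in)stability.
unstable(x) = x > 0 ? x : 0        # Float64 input -> Float64 or Int
stable(x)   = x > 0 ? x : zero(x)  # Float64 input -> always Float64

println(Base.return_types(unstable, (Float64,)))  # a Union of Float64 and Int
println(Base.return_types(stable, (Float64,)))    # Float64 only
```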
The actual package is huge, so I tried to come up with an MWE that keeps its spirit. The following code compiles and runs on Julia 1.6:
```julia
using BandedMatrices
using DiffEqFlux
using OrdinaryDiffEq
using DiffEqSensitivity
using Flux
using GalacticOptim
using LinearAlgebra
using Zygote

# band building stuff
function Vband(n,l,u,Δx,v)
    A = BandedMatrix{Float64}(undef, (n,n), (l,u))
    A[band(0)] .= 0
    A[band(1)] .= -1/(2*Δx)
    A[band(-1)] .= 1/(2*Δx)
    A[1,1:2] .= 0
    A[n,n] = -1/(2*Δx)
    A[n,n-1] = 1/(2*Δx)
    A = A.*v
    return A
end

function build_band(npoints,L,v)
    Δx = L/(npoints-1)
    n = npoints
    VB = Vband(n,1,1,Δx,v)
    return VB
end

# ODE system
RHS = FastChain(FastDense(2,9,tanh,bias = false),FastDense(9,1,bias = false))

function eqsys(du,u,θ,t,n,Φ,BandMat,Band_c)
    c = @view u[1:n]
    dcdt = @view du[1:n]
    dqdt = @view du[n+1:2*n]
    dqdt .= vec(RHS(reshape(u,n,:)',θ))
    mul!(Band_c,BandMat,c)
    @. dcdt = Band_c - Φ*dqdt
    dqdt[1] = 0
    dcdt[1] = 0
end

# wrapper with closure & preallocation
function fun(p,u0,n,Φ,BandMat,save_times)
    Band_c = Vector{Real}(undef,n)
    ode_closure = ODEFunction((du,u,p,t) -> eqsys(du,u,p,t,n,Φ,BandMat,Band_c))
    prob = ODEProblem(ode_closure,u0,tspan,p)
    solve(prob,KenCarp4(autodiff=false),saveat = save_times,sensealg = QuadratureAdjoint(autojacvec=ReverseDiffVJP(true)))
end

# callback & loss
cb = function (p,l)
    println(l)
    if l < 1e-3
        return true
    else
        return false
    end
end

function loss(p,u0,n,Φ,BandMat,save_times,data)
    sol = fun(p,u0,n,Φ,BandMat,save_times)
    diff_vec = vec(sol[n,:].-data)
    loss = sqrt(sum(abs2,diff_vec))
    return loss
end

# some toy paras & data (will probably never fit)
n = 20
L = 10
Epsilon = 0.36
Φ = zeros(n)
Φ .= (1-Epsilon)/Epsilon
c_in = 1
save_times = collect(0:0.2*pi:20)
data = save_times./20
u0 = zeros(2*n)
u0[1] = c_in
v = 2.8
tᵢ = save_times[1]
tₑ = save_times[end]
tspan = (tᵢ,tₑ)

# initialize & optimize
BandMat = build_band(n,L,v)
p = Float64.(initial_params(RHS)).*1e-3
optprob = OptimizationFunction((p,x) -> loss(p,u0,n,Φ,BandMat,save_times,data), GalacticOptim.AutoZygote())
prob = GalacticOptim.OptimizationProblem(optprob, p)
res = GalacticOptim.solve(prob,ADAM(1e-2),cb = cb,maxiters = 100)
```
Since `eqsys(du,u,θ,t,n,Φ,BandMat,Band_c)` is the most frequently evaluated function, it has been the focus of my performance optimization. In particular, I would like to preallocate a cache vector `Ax` for the result of `mul!(Ax,A,x)`. This is where I run into trouble, because the different AD modes/libraries that get called work with different element types. As I understand from this post, it is best to use some kind of nested architecture with additional closures in it.
Therefore I allocate `Band_c = Vector{Real}(undef,n)` in the outer function `fun`. However, I cannot manage to give it the correct concrete type for each AD mode.
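One generic way to choose a concrete element type for the `mul!` output is to derive it from the inputs with `promote_type`. Below is a minimal sketch using a hypothetical helper, `product_buffer`. `BigFloat` stands in for the dual or tracked number types that AD packages pass in, so the block runs with only the `LinearAlgebra` standard library. On its own this still allocates on every call. It only shows how the type can be picked without hard-coding `Float64` or an abstract type.

```julia
using LinearAlgebra

# Hypothetical helper: allocate the output of mul!(Ax, A, x) with the element
# type the product actually produces, instead of Float64 or an abstract type.
function product_buffer(A::AbstractMatrix, x::AbstractVector)
    T = promote_type(eltype(A), eltype(x))
    return zeros(T, size(A, 1))
end

A = [1.0 2.0; 3.0 4.0]
x_f64 = [1.0, 1.0]
x_big = BigFloat[1, 1]               # stand-in for an AD number type

Ax_f64 = product_buffer(A, x_f64)    # Vector{Float64}
Ax_big = product_buffer(A, x_big)    # Vector{BigFloat}
mul!(Ax_f64, A, x_f64)               # in-place product into the concrete buffer
mul!(Ax_big, A, x_big)
```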
With my limited knowledge of Julia, I understand that allocating just a `Vector{AbstractType}` as I do won't help with speed. The compiler cannot infer a concrete element type, so it has to figure out at runtime what to do with the buffer inside `eqsys` at each call.
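The concern about abstract element types can be checked directly. A `Vector{Real}` may hold different concrete number types at the same time, so reading from it can only be inferred as `Real`. A `Vector{Float64}` yields a concrete `Float64`. `first_elem` is a hypothetical one-line function used only for this check, and `Base.return_types` is an internal reflection helper.

```julia
# Hypothetical accessor used only to inspect inference.
first_elem(v) = v[1]

v_abstract = Vector{Real}(undef, 3)
v_abstract[1] = 1          # an Int ...
v_abstract[2] = 2.5        # ... a Float64 ...
v_abstract[3] = 1//2       # ... and a Rational, all in one vector
v_concrete = zeros(3)      # Vector{Float64}

println(isconcretetype(eltype(v_abstract)))                 # false
println(isconcretetype(eltype(v_concrete)))                 # true
println(Base.return_types(first_elem, (Vector{Real},)))     # inferred as Real
println(Base.return_types(first_elem, (Vector{Float64},)))  # inferred as Float64
```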
So I have two questions:
- How can I give the preallocated buffer a concrete type without losing generality?
- Is this “nested” approach the most efficient way to go, or is there another way that avoids rebuilding the `ODEProblem` each time?
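One possible direction for both questions is sketched below under stated assumptions. This is not the package's method and not a library API. The RHS becomes a callable struct that carries the constant data together with a small cache. The cache hands out one buffer per element type, created lazily the first time that type appears. `EltypeCache`, `get_buffer` and `AdvectionRHS` are hypothetical names. The toy term `p[1] * c` replaces the neural network, and `BigFloat` stands in for AD number types. Keying on `eltype(u)` assumes the buffer only stores `A*c`, which depends on `u` alone. The cache is also not thread-safe. The PreallocationTools.jl package (`DiffCache`, `get_tmp`) addresses a similar problem for ForwardDiff and may be preferable in practice.

```julia
using LinearAlgebra

# Hypothetical cache: one buffer per element type, allocated on first use.
struct EltypeCache
    n::Int
    bufs::Dict{Type,Vector}
end
EltypeCache(n::Int) = EltypeCache(n, Dict{Type,Vector}())

function get_buffer(c::EltypeCache, u::AbstractVector{T}) where {T}
    buf = get!(() -> zeros(T, c.n), c.bufs, T)
    return buf::Vector{T}   # type assertion keeps the result inferable
end

# Hypothetical callable struct replacing the closure; it has the
# f(du, u, p, t) signature that ODEFunction expects.
struct AdvectionRHS{M<:AbstractMatrix}
    n::Int
    Φ::Vector{Float64}
    A::M
    cache::EltypeCache
end

function (f::AdvectionRHS)(du, u, p, t)
    n = f.n
    Φ = f.Φ
    p1 = p[1]
    c = @view u[1:n]
    dcdt = @view du[1:n]
    dqdt = @view du[n+1:2n]
    Ac = get_buffer(f.cache, u)   # same eltype as u, reused across calls
    mul!(Ac, f.A, c)
    @. dqdt = p1 * c              # toy term standing in for the neural network
    @. dcdt = Ac - Φ * dqdt
    return nothing
end

# Usage with a small tridiagonal operator and two number types.
n = 4
A = Matrix(Tridiagonal(fill(0.5, n - 1), zeros(n), fill(-0.5, n - 1)))
f = AdvectionRHS(n, fill(2.0, n), A, EltypeCache(n))

u = collect(1.0:2n)
du = similar(u)
f(du, u, [0.1], 0.0)     # first call with Float64 allocates a Float64 buffer

ub = BigFloat.(u)        # stand-in for an AD number type
dub = similar(ub)
f(dub, ub, [0.1], 0.0)   # allocates a separate BigFloat buffer
```

With the RHS as a struct, one option is to build the `ODEProblem` once and update only the parameters inside `loss` via `remake(prob; p = p)`. Whether that is measurably faster than rebuilding the problem in this setup is untested.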