Optimizing the gradient computation time of large linear UDEs

I have a high-dimensional UDE model that includes a linear component. An MWE of the linear component of the model is:

using ForwardDiff, DifferentialEquations, SciMLSensitivity

trainingData = rand(100,4)
p0 = rand(100,100)

function nn!(du,u,p,t)
	diffs = [u[j]-u[i] for i in 1:100, j in 1:100]
	du .= vec(sum(p.*diffs,dims=2))
end

function predict(p)
	prob = ODEProblem(nn!,trainingData[:,1],(1.,4.),p)
	Array(solve(prob,saveat=1.))
end

function loss(p)
	pred = predict(p)
	sum(abs2,pred .- trainingData)
end

@time ForwardDiff.gradient(loss,p0);
 27.527524 seconds (1.45 M allocations: 122.377 GiB, 33.60% gc time, 0.32% compilation time: 100% of which was recompilation)

In the actual code it takes far longer: after several minutes, a single gradient still hasn't been computed.
Could somebody advise me on ways to optimize this computational time to be more reasonable?


You have 10^4 parameters — computing the gradient by forward-mode AD effectively involves solving the ODE 10^4 times. Using a reverse-mode (“adjoint”) algorithm, in contrast, will effectively involve solving the ODE one additional time to get the gradient.

You should probably use an adjoint/reverse method in this regime, not ForwardDiff.
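For the MWE above, the switch can be as small as differentiating with Zygote and passing a `sensealg` to `solve`. A minimal sketch (the solver and adjoint choices here are illustrative, not prescriptive; note the `du .=` so the RHS actually mutates `du`):

```julia
using Zygote, DifferentialEquations, SciMLSensitivity

trainingData = rand(100, 4)
p0 = rand(100, 100)

function nn!(du, u, p, t)
	diffs = [u[j] - u[i] for i in 1:100, j in 1:100]
	du .= vec(sum(p .* diffs, dims = 2))  # .= so du is actually mutated
end

function predict(p)
	prob = ODEProblem(nn!, trainingData[:, 1], (1.0, 4.0), p)
	# Explicitly request an adjoint (reverse-mode) sensitivity algorithm:
	Array(solve(prob, Tsit5(); saveat = 1.0,
	            sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))))
end

loss(p) = sum(abs2, predict(p) .- trainingData)

g = Zygote.gradient(loss, p0)[1]  # one forward solve + one adjoint solve
```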

A classic reference on adjoint-method (reverse-mode/backpropagation) differentiation of ODEs (and generalizations thereof) is Cao et al. (2003) (pdf). See also the SciMLSensitivity.jl package’s documentation on reverse-mode AD for adjoint-method sensitivity analysis with DifferentialEquations.jl, along with Chris Rackauckas’s notes from 18.337. There is also a nice YouTube lecture on adjoint sensitivity of ODEs, again using similar notation.


Also, since it’s a linear ODE, you can specialize this to simply use u(t,p) = exp(A(p)*t)*u0, which then has a very simple derivative. SciMLSensitivity currently won’t specialize on linear ODEs (though it should in the future), so for now it’s quite straightforward to derive yourself. As Steven says, though, you will want to do this using the adjoint, as forward mode has O(np) scaling vs. the O(n+p) scaling of the adjoint.
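For the MWE at the top of the thread, the right-hand side `sum(p.*diffs,dims=2)` is linear in `u`, so the closed form can be written down directly. A sketch, derived by hand from that MWE (double-check the signs against your own model):

```julia
using LinearAlgebra

# du[i]/dt = sum_j p[i,j]*(u[j] - u[i]) = (p*u)[i] - u[i]*sum_j p[i,j],
# i.e. du/dt = A*u with:
A(p) = p - Diagonal(vec(sum(p, dims = 2)))

# Closed-form solution starting from u0 at time t0:
u(t, p, u0; t0 = 1.0) = exp(A(p) * (t - t0)) * u0
```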


Beware that differentiating a matrix exponential (with respect to parameters of the matrix) is not as simple as many people expect (but ChainRules.jl and hence Zygote.jl can do it). See also Differentiating random walk probability w.t.r. rate of jump - #14 by stevengj
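A tiny sanity check that the rule exists in Zygote (an illustrative example using a 2×2 rotation generator, where the answer is known analytically):

```julia
using Zygote, LinearAlgebra

# exp([0 x; -x 0]) = [cos(x) sin(x); -sin(x) cos(x)], so the sum is 2cos(x)
f(x) = sum(exp([0.0 x; -x 0.0]))

g = Zygote.gradient(f, 0.3)[1]  # analytically: -2sin(0.3)
```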


Yeah, I meant the code is relatively straightforward. If you put that into Zygote it should just work. It won’t work with ForwardDiff.jl, though, for the reason you mention: the scaling-and-squaring algorithm used in Base is not generically differentiable, and ForwardDiff is missing a specialized rule for it.

BTW, this reminds me: for this kind of function I’d normally recommend ExponentialUtilities.jl with expv instead of exp, but there’s a missing derivative there for various reasons. Do you happen to know of a better trick than the one mentioned in Add ChainRules rules · Issue #40 · SciML/ExponentialUtilities.jl · GitHub?
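For readers following along: expv(t, A, b) approximates exp(t*A)*b via a Krylov method without ever forming exp(t*A), which is why it is attractive for large systems. Forward evaluation works fine; it is only the derivative rules that are missing. A small sketch:

```julia
using ExponentialUtilities, LinearAlgebra

A = randn(10, 10)
b = randn(10)
t = 0.5

y = expv(t, A, b)   # Krylov approximation of exp(t*A)*b
y ≈ exp(t * A) * b  # agrees with the dense computation
```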

Thanks to both of you for your help! Two comments regarding the discussion:

  • When trying to use Zygote, I get the following error on the MWE:
f = LLVM.Function("julia__mapreducedim__5966")
(gty, inst, v) = (LLVM.IntegerType[LLVM.IntegerType(i64)], LLVM.PHIInst(%129 = phi double addrspace(13)* [ poison, %L89.L110.loopexit_crit_edge.us125.unr-lcssa.us.1.L110.us129.us.1_crit_edge ], [ %113, %L89.L110.loopexit_crit_edge.us125.unr-lcssa.us.1.thread ], [ %119, %L93.us121.epil.us.1 ]), LLVM.PoisonValue(0x000000006a63fb20))
f = LLVM.Function("julia__mapreducedim__6969")
(gty, inst, v) = (LLVM.IntegerType[LLVM.IntegerType(i64)], LLVM.PHIInst(%129 = phi double addrspace(13)* [ poison, %L89.L110.loopexit_crit_edge.us125.unr-lcssa.us.1.L110.us129.us.1_crit_edge ], [ %113, %L89.L110.loopexit_crit_edge.us125.unr-lcssa.us.1.thread ], [ %119, %L93.us121.epil.us.1 ]), LLVM.PoisonValue(0x00000000086ed390))
┌ Warning: EnzymeVJP tried and failed in the automated AD choice algorithm with the following error. (To turn off this printing, add `verbose = false` to the `solve` call)
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Rm4xX/src/concrete_solve.jl:23
AssertionError: false
ERROR: UndefRefError: access to undefined reference

followed by a long stacktrace.

  • Unfortunately my model is not just linear, but the linear part is the one that takes up most of the memory, as it has thousands of parameters.
    A small neural network makes predictions over subsets of the data. A more representative MWE of what I am trying to do would be:
using Zygote, DifferentialEquations, SciMLSensitivity, Lux, Random, ComponentArrays
rng = Random.default_rng()

trainingData = rand(100,4)
p0 = rand(100,100)
chain = Lux.Chain(Lux.Dense(4,5),Lux.Dense(5,4))
ps, st = Lux.setup(rng, chain)

p = ComponentVector(model_params = ps, connectivityMatrix = p0)

function nn!(du,u,p,t)
	nns = reduce(vcat,[first(chain(u[((i-1)*4+1):((i-1)*4+4)],p.model_params,st)) for i in 1:25])
	diffs = [u[j]-u[i] for i in 1:100, j in 1:100]
	du .= nns .+ vec(sum(p.connectivityMatrix.*diffs,dims=2))
end

function predict(p)
	prob = ODEProblem(nn!,trainingData[:,1],(1.,4.),p)
	Array(solve(prob,saveat=1.))
end

function loss(p)
	pred = predict(p)
	sum(abs2,pred .- trainingData)
end

@time Zygote.gradient(loss,p);
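One workaround when the automated VJP choice blows up like this is to pin `sensealg` explicitly so EnzymeVJP is never attempted. A sketch on a toy linear problem (the specific adjoint/VJP combination is an assumption; see the SciMLSensitivity docs for the full menu of options):

```julia
using Zygote, DifferentialEquations, SciMLSensitivity

u0 = rand(4)
p0 = rand(4, 4)
lin!(du, u, p, t) = (du .= p * u)

function predict(p)
	prob = ODEProblem(lin!, u0, (1.0, 4.0), p)
	# Pin the VJP backend instead of letting solve auto-select one:
	Array(solve(prob, Tsit5(); saveat = 1.0,
	            sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP())))
end

g = Zygote.gradient(p -> sum(abs2, predict(p)), p0)[1]
```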

That error message is coming from Enzyme, and seemingly an older version (that code has been substantially improved since).

What is your package status [and thus version of Enzyme]?

Have you tried rewriting it in terms of a matrix exponential, as Chris recommended?

Thanks William! I knew I’d seen that error before. My status was

Status `~/.julia/environments/v1.9/Project.toml`
⌃ [7da242da] Enzyme v0.11.7
⌃ [e88e6eb3] Zygote v0.6.65

I upgraded and now it is

Status `/central/home/jarroyoe/.julia/environments/v1.8/Project.toml`
  [7da242da] Enzyme v0.11.12
  [e88e6eb3] Zygote v0.6.68

and I am getting a different error:

ERROR: BoundsError: attempt to access 4-element StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64} at index [0]

I just did, and there is a significant performance boost when computing loss(p0) and ForwardDiff.gradient(loss,p0). I am getting errors on the adjoint, though.

using Zygote, DifferentialEquations, SciMLSensitivity

trainingData = rand(100,4)
p0 = rand(100,100)

function predict(p)
	A = [i==j ? -sum(p[i,:]) : p[i,j] for i in 1:100, j in 1:100]
	reduce(hcat,[trainingData[:,1]'*exp.(A*i) for i in 0:3]')
end

function loss(p)
	pred = predict(p)
	sum(abs2,pred .- trainingData)
end

@time Zygote.gradient(loss,p0);

yields:

ERROR: MethodError: no method matching adjoint(::Nothing)
Closest candidates are:
  adjoint(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted}) at /central/software/julia/1.8.5/share/julia/stdlib/v1.8/LinearAlgebra/src/qr.jl:517
  adjoint(::Union{LinearAlgebra.Cholesky, LinearAlgebra.CholeskyPivoted}) at /central/software/julia/1.8.5/share/julia/stdlib/v1.8/LinearAlgebra/src/cholesky.jl:558
  adjoint(::LinearAlgebra.Hessenberg) at /central/software/julia/1.8.5/share/julia/stdlib/v1.8/LinearAlgebra/src/hessenberg.jl:424

I thought you wanted a matrix exponential? This is an elementwise exponential, which is very much not the same thing as exp(A*i).

(Also, exp(A*i) for i = 0:3 is the same as exp(A)^i for i = 0:3, but computing exp(A) once and re-using it will be much faster. I’m also a little confused by why you are multiplying the initial condition as a row vector on the left, though of course it is possible to do this if you have transposed your system matrix. Make sure that your new calculation matches your ODE solution … it looks quite different from what you wrote before at first glance!)
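Concretely, the reuse Steven describes looks like this (a sketch on a small random system; exp(A) is computed once and the powers are built up by repeated multiplication):

```julia
using LinearAlgebra

A = randn(4, 4)
u0 = randn(4)

E = exp(A)                      # one matrix exponential, computed once
cols = [u0]                     # i = 0: exp(A*0)*u0 == u0
for i in 1:3
	push!(cols, E * cols[end])  # exp(A*i)*u0 = E * (exp(A*(i-1))*u0)
end
pred = reduce(hcat, cols)       # 4 columns: t = 0, 1, 2, 3
```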

It appears you are trying to calculate nothing' (== adjoint(nothing)) somewhere. The stacktrace will tell you where this is being called.

Thanks for pointing that out, this was me messing up the math! This code works flawlessly for the linear case.

using Zygote, DifferentialEquations, SciMLSensitivity

trainingData = rand(100,4)
p0 = rand(100,100)

function predict(p)
	A = [i==j ? -sum(p[i,:]) : p[i,j] for i in 1:100, j in 1:100]
	reduce(hcat,[exp(A*i)*trainingData[:,1] for i in 0:3])
end

function loss(p)
	pred = predict(p)
	sum(abs2,pred .- trainingData)
end

@time Zygote.gradient(loss,p0);