ReverseDiff/Zygote with SDE from DifferentialEquations fails to compute gradients beyond a certain number of parameters


I’m trying to solve a parameter estimation problem for an SDE. The parameters form a small, one-layer neural network (technically, a random feature model) and are contained in a single matrix of size 4 × n, where I would like n to be in the hundreds. I want to compute the gradient of a mean-square loss with respect to this matrix.

    function loss(p, noise, training_trajectories, initial_conditions)
        parameters = hcat((p, noise)...)
        temp_prob = SDEProblem(drift_rff, diffussion_rff, initial_conditions, (t_in, t_fin), parameters)
        tmp_sol = solve(temp_prob, saveat = dt, verbose = false, abstol = 1e-1, reltol = 1e-1)
        arrsol = Array(tmp_sol)
        return mean((arrsol - training_trajectories) .^ 2)
    end

    objective = parameters -> loss(parameters, noise, target, initial_conditions)

Computing the loss is fast (8 ms wallclock time). Computing the gradients with ForwardDiff works but is slow (2.5 seconds with n = 25):

    result = DiffResults.GradientResult(p_opt) # Preallocate gradient vector
    result = ForwardDiff.gradient!(result,objective, p_opt);
    grads = DiffResults.gradient(result)

When computing the gradients with ReverseDiff (or Zygote), there is no issue with a small number of parameters (although this is still slow):

    result = DiffResults.GradientResult(p_opt) # Preallocate gradient vector
    result = ReverseDiff.gradient!(result,objective, p_opt);
    grads = DiffResults.gradient(result)

When the number of parameters becomes modestly large (n = 20), I get the following warning and subsequent error (only with ReverseDiff or Zygote; ForwardDiff works fine):

    ┌ Warning: Using fallback BLAS replacements, performance may be degraded
    └ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/U36Ed/src/utils.jl:59
    warning: didn’t implement memmove, using memcpy as fallback which can result in errors
    warning: didn’t implement memmove, using memcpy as fallback which can result in errors
    ┌ Warning: Using fallback BLAS replacements, performance may be degraded
    └ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/U36Ed/src/utils.jl:59
    ┌ Warning: Automatic AD choice of autojacvec failed in ODE adjoint, falling back to ODE adjoint + numerical vjp
    └ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/sensitivity_interface.jl:381
    ERROR: LoadError: MethodError: no method matching forwarddiffs_model_time(::Nothing)

    Closest candidates are:
      forwarddiffs_model_time(::Union{OrdinaryDiffEq.OrdinaryDiffEqRosenbrockAdaptiveAlgorithm, OrdinaryDiffEq.OrdinaryDiffEqRosenbrockAlgorithm})
        @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/mhqg9/src/alg_utils.jl:22
        @ SciMLBase ~/.julia/packages/SciMLBase/g3O0b/src/alg_traits.jl:32

Any idea what might be the issue or how to speed up the computation time?

I’m using the latest versions of all these packages and of Julia. I’m on Linux, and this error happens on two different machines.

It’s swapping the default at that point and there’s an issue from the v1.10 updates. I think the bugs around that will get fixed within the next week, but it’s been a bit of a process for multiple reasons.
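In the meantime, one possible workaround (an untested sketch, assuming the SciMLSensitivity adjoint interface applies to your solve) is to pin the vjp choice yourself instead of relying on the automatic selection:

```julia
using SciMLSensitivity  # provides InterpolatingAdjoint, ReverseDiffVJP

# Passing an explicit sensealg bypasses the automatic autojacvec selection
# that is failing above; temp_prob and dt are assumed to be defined as in
# the original loss function.
tmp_sol = solve(temp_prob;
                saveat = dt, verbose = false, abstol = 1e-1, reltol = 1e-1,
                sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true)))
```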

It sounds like you’re concatenating a bunch of different parameters for the same SDE? It would be much better to use the GPU ensemble solvers then. See Using the EnsembleGPUKernel SDE solvers for the expectation of SDEs · DiffEqGPU.jl and some of the other tutorials on giving the runs different parameters. This would make the differentiation a lot quicker, since the derivatives are sparse across the different parameter sets.
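For reference, the per-trajectory parameter pattern looks like this with the standard ensemble interface (a CPU sketch with toy placeholder drift/diffusion; on a GPU, EnsembleGPUKernel from DiffEqGPU would take the place of EnsembleThreads()):

```julia
using DifferentialEquations

f(u, p, t) = p[1] * u   # toy drift, stands in for the real model
g(u, p, t) = 0.1        # toy additive noise
prob = SDEProblem(f, g, 0.5, (0.0, 1.0), [1.0])

# Give each trajectory its own parameter set via prob_func
param_sets = [[0.5], [1.0], [1.5]]
prob_func = (prob, i, repeat) -> remake(prob, p = param_sets[i])
ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)

sol = solve(ensemble_prob, SOSRI(), EnsembleThreads(),
            trajectories = length(param_sets), saveat = 0.1)
```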


Thank you for your answer! If it’s a bug, then I guess I’ll just wait until it’s fixed.

As for your second point, my SDE is a 4-dimensional system. The drift function is essentially a fixed non-linearity applied to the input, lifting it to some higher dimension n. I then apply a 4 × n matrix to this output, and these weights are what I want to learn. The noise is additive, and I don’t wish to learn it.
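Concretely, the drift is something like the following (a hypothetical sketch: the names drift_rff!, Ω, b, and the cosine non-linearity are illustrative stand-ins, and the parameter matrix is laid out as in the hcat in my loss function, weights first):

```julia
using Random

n = 100                                 # number of random features
Ω = randn(MersenneTwister(0), n, 4)     # fixed random projection (not learned)
b = 2π .* rand(MersenneTwister(1), n)   # fixed random phases (not learned)

# In-place drift: lift u to n features with the fixed non-linearity,
# then project back to 4 dimensions with the learnable 4 × n weights,
# which occupy the first n columns of the parameter matrix p.
function drift_rff!(du, u, p, t)
    W = @view p[:, 1:n]
    du .= W * cos.(Ω * u .+ b)
end
```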

I might check out the GPU tutorials. I initially didn’t want to use GPUs because they come with their own difficulties in my experience, and the number of parameters I wish to learn seemed fairly modest (between 400 and 2000), but I guess I underestimated the computational cost.