I’m trying to solve a parameter estimation problem for an SDE. The parameters form a small, one-layer neural network (technically, a random feature model). They are stored in a single matrix of size 4 × n, where I would like n to be in the hundreds. I need the gradients of a mean-square loss with respect to this matrix.
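```julia
using DifferentialEquations, SciMLSensitivity, Statistics  # SDEProblem/solve, AD through solve, mean

# drift_rff, diffussion_rff, t_in, t_fin and dt are defined elsewhere in the script.
function loss(p, noise, training_trajectories, initial_conditions)
    parameters = hcat(p, noise)  # append the noise column(s) to the 4 × n network weights
    temp_prob = SDEProblem(drift_rff, diffussion_rff, initial_conditions, (t_in, t_fin), parameters)
    tmp_sol = solve(temp_prob, saveat = dt, verbose = false, abstol = 1e-1, reltol = 1e-1)
    arrsol = Array(tmp_sol)
    return mean((arrsol - training_trajectories) .^ 2)  # mean-square error
end

objective = parameters -> loss(parameters, noise, target, initial_conditions)
```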
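The post does not show `drift_rff` or `diffussion_rff`. For anyone trying to reproduce the problem, here is one possible shape for them. This is a minimal sketch, not the poster's code, assuming a 4-dimensional state, random Fourier features with frozen weights, and `noise` being a single 4 × 1 column of diagonal noise amplitudes. The names `nfeat`, `W`, `b` and `features` are hypothetical.

```julia
using LinearAlgebra, Random

# HYPOTHETICAL stand-ins for drift_rff / diffussion_rff (not the poster's code).
# Assumptions: state dimension 4, nfeat random features, and the SDE parameter
# matrix is hcat(A, σ), with A the trainable 4 × nfeat readout and σ a 4 × 1
# column of diagonal noise amplitudes.
const nfeat = 25
const rng = MersenneTwister(0)
const W = randn(rng, nfeat, 4)       # frozen random frequencies (nfeat × 4)
const b = 2π .* rand(rng, nfeat)     # frozen random phases (length nfeat)

# Random Fourier features: φ(u) = sqrt(2 / nfeat) .* cos.(W * u .+ b)
features(u) = sqrt(2 / nfeat) .* cos.(W * u .+ b)

# In-place drift f(du, u, p, t): du = A * φ(u), using the first nfeat columns of p.
function drift_rff(du, u, p, t)
    A = view(p, :, 1:nfeat)
    du .= A * features(u)
    return nothing
end

# In-place diagonal diffusion g(du, u, p, t): constant amplitudes from the last column of p.
function diffussion_rff(du, u, p, t)
    du .= view(p, :, nfeat + 1)
    return nothing
end
```

Under these assumptions, `hcat(p, noise)` with `size(p) == (4, nfeat)` and `size(noise) == (4, 1)` produces a parameter matrix that both functions can index as written.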
Computing the loss is fast (8 ms wall-clock time). Computing the gradients with ForwardDiff works, but it is slow (2.5 seconds with n = 25):
```julia
using ForwardDiff, DiffResults

result = DiffResults.GradientResult(p_opt)  # preallocate storage for the value and gradient
result = ForwardDiff.gradient!(result, objective, p_opt)
grads = DiffResults.gradient(result)
```
When computing the gradients with ReverseDiff (or Zygote), there is no issue with a small number of parameters, although it is still slow:
```julia
using ReverseDiff, DiffResults

result = DiffResults.GradientResult(p_opt)  # preallocate storage for the value and gradient
result = ReverseDiff.gradient!(result, objective, p_opt)
grads = DiffResults.gradient(result)
```
When the number of parameters becomes modestly large (n = 20, i.e. an 80-entry matrix), I get the following warnings followed by an error (only with ReverseDiff or Zygote; ForwardDiff works fine):
```
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/U36Ed/src/utils.jl:59
warning: didn’t implement memmove, using memcpy as fallback which can result in errors
warning: didn’t implement memmove, using memcpy as fallback which can result in errors
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/U36Ed/src/utils.jl:59
┌ Warning: Automatic AD choice of autojacvec failed in ODE adjoint, failing back to ODE adjoint + numerical vjp
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/sensitivity_interface.jl:381
ERROR: LoadError: MethodError: no method matching forwarddiffs_model_time(::Nothing)
Closest candidates are:
forwarddiffs_model_time(::Union{OrdinaryDiffEq.OrdinaryDiffEqRosenbrockAdaptiveAlgorithm, OrdinaryDiffEq.OrdinaryDiffEqRosenbrockAlgorithm})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/mhqg9/src/alg_utils.jl:22
@ SciMLBase ~/.julia/packages/SciMLBase/g3O0b/src/alg_traits.jl:32
```
Any idea what the issue might be, or how to speed up the gradient computation?
I’m using the latest versions of Julia and of all these packages. I’m on Linux, and the error occurs on two different machines.