DiffEqGPU.jl with CUDA: Error computing gradients through SDE solver

Hi Julia community!

I’m trying to fit a custom one-layer NN to an SDE by minimizing an MSE loss. Here is my code defining the model and computing the loss (noise and the sample features, samples_rff, are CUDA matrices defined globally):

using CUDA, Distributions, StochasticDiffEq, SciMLSensitivity, Statistics, Zygote

nu = 1.5
n_samples = 100
d = 4
l_list = ones((d,1))
samples_rff = rand(TDist(2*nu), (d, n_samples));
A = randn((d, n_samples))/sqrt(n_samples)

# Send the parameters to the GPU
A = cu(A)
samples_rff = cu(samples_rff)
l_list = cu(l_list)
weights = hcat(A, l_list)

function rff_model(X, A, sample_features, l)
    # X: matrix of size (d, N)
    # A: matrix of size (d, n_samples)
    # sample_features: matrix of size (d, n_samples)
    # l: matrix of size (d, 1)
    tau = l.^(-1)
    W = tau.*sample_features
    M = W'*X
    M = cos.(M) + sin.(M)
    return A*M
end

function drift_rff(dstate, state, p, t)
    A = p[:, 1:end-1]
    l = p[:, end:end]
    dstate .= rff_model(state, A, samples_rff, l)
end

function diffussion_rff(dstate,state,p,t)
    A = p[:, 1:end-1]
    l = p[:, end:end]
    dstate .= noise   # noise: global CUDA array defined elsewhere
end

function loss(weights, training_trajectories, initial_conditions, training_time, reg = 0.1)
    # weights: size (d, n_samples + 1), i.e. A (d, n_samples) with l (d, 1) as the last column
    # initial_conditions: size (d, m)
    # t_in and t_fin are globals defined elsewhere
    prob = SDEProblem(drift_rff, diffussion_rff, initial_conditions, (t_in, t_fin), weights)
    tmp_sol = solve(prob, SOSRI(), saveat=training_time);
    arrsol = CuArray(tmp_sol)
    #l = weights[:, end]
    return mean((arrsol - training_trajectories).^2)  #+ reg*sum(l.^(-1)) # mean(arrsol)
end

objective = weights -> loss(weights, training_trajectories, initial_conditions, training_times, 0.1)

val, grads = Zygote.withgradient(objective, weights)
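For reference, here is a CPU-only sketch of the same model that makes the array shapes explicit and checks that packing A and l into one parameter matrix and slicing them back out round-trips. It assumes plain `Array`s in place of `CuArray`s and random normal draws in place of the `TDist` samples; the helper names `pack`, `unpack` and `mse` are hypothetical.

```julia
using Random, Statistics

# CPU-only sketch of the random-Fourier-feature model from the post.
# Assumptions: plain Arrays replace CuArrays, random normals replace the
# TDist samples, and pack/unpack/mse are hypothetical helper names.
function rff_model(X, A, sample_features, l)
    # X: (d, N) states, A: (d, n_samples) weights,
    # sample_features: (d, n_samples), l: (d, 1) length scales
    W = sample_features ./ l          # same as l.^(-1) .* sample_features
    M = W' * X                        # (n_samples, N)
    return A * (cos.(M) .+ sin.(M))   # (d, N)
end

# Mirror weights = hcat(A, l_list) and p[:, 1:end-1] / p[:, end:end]
pack(A, l) = hcat(A, l)
unpack(p) = (p[:, 1:end-1], p[:, end:end])

# Mean squared error over every entry of the trajectory arrays
mse(pred, target) = mean(abs2, pred .- target)

Random.seed!(0)
d, n_samples, N = 4, 100, 3
A = randn(d, n_samples) / sqrt(n_samples)
l = ones(d, 1)
features = randn(d, n_samples)
X = randn(d, N)

p = pack(A, l)
A2, l2 = unpack(p)
Y = rff_model(X, A2, features, l2)
println(size(Y))   # prints (4, 3)
```

The broadcast `sample_features ./ l` relies on l being a (d, 1) column, which is why the original code slices it as `p[:, end:end]` rather than `p[:, end]`.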

Computing the loss works without a hitch. However, computing gradients with Zygote or ReverseDiff produces the following warnings and a subsequent error:

┌ Warning: Potential performance improvement omitted. ZygoteVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN = true. To turn off this printing, add verbose = false to the solve call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/concrete_solve.jl:99
┌ Warning: Potential performance improvement omitted. ReverseDiffVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN = true. To turn off this printing, add verbose = false to the solve call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/concrete_solve.jl:116
┌ Warning: Potential performance improvement omitted. TrackerVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN = true. To turn off this printing, add verbose = false to the solve call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/concrete_solve.jl:134
┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/concrete_solve.jl:144

ERROR: LoadError: MethodError: no method matching SciMLBase.SensitivityInterpolation(::Vector{Float64}, ::Vector{CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}})

Computing gradients with ForwardDiff also returns an error:

ERROR: LoadError: MethodError: randn!(::CUDA.RNG, ::CuArray{ForwardDiff.Dual{ForwardDiff.Tag{var"#47#48", Float32}, Float64, 12}, 2, CUDA.Mem.DeviceBuffer}) is ambiguous.

Candidates:
  randn!(rng::CUDA.RNG, A::AbstractArray{T}) where T
    @ CUDA ~/.julia/packages/CUDA/rXson/src/random.jl:255
  randn!(rng::Random.AbstractRNG, A::GPUArraysCore.AnyGPUArray)
    @ GPUArrays ~/.julia/packages/GPUArrays/dAUOE/src/host/random.jl:116

To resolve the ambiguity, try making one of the methods more specific, or adding a new method more specific than any of the existing applicable methods.

Does anyone know what is going on? Any workaround?

Thank you for your help!

This got caught up in the Julia v1.10 transition; it’s being worked out ASAP.

Ok, thanks! Could I downgrade to an older version where the error doesn’t happen? Or is this intrinsic to Julia 1.10?

If you do ]add RecursiveArrayTools@2 in the Pkg REPL, it should pull things back to a good spot, I think.

Unfortunately, this didn’t work, but thank you for the suggestion. To be clear, I’m using

val, grads = Zygote.withgradient(objective, weights)

to compute the gradients.

What did you hit with that?

Do you mean which error I got? Here it is (apologies for the wall of text):

┌ Warning: Potential performance improvement omitted. ZygoteVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN = true. To turn off this printing, add verbose = false to the solve call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/concrete_solve.jl:99

┌ Warning: Potential performance improvement omitted. ReverseDiffVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN = true. To turn off this printing, add verbose = false to the solve call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/concrete_solve.jl:116

┌ Warning: Potential performance improvement omitted. TrackerVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN = true. To turn off this printing, add verbose = false to the solve call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/concrete_solve.jl:134

┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/concrete_solve.jl:144
ERROR: LoadError: The algorithm is not compatible with the chosen noise type. Please see the documentation on the solver methods
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] __init(_prob::SDEProblem{…}, alg::SOSRI, timeseries_init::Vector{…}, ts_init::Vector{…}, ks_init::Type, recompile::Type{…}; saveat::Vector{…}, tstops::Vector{…}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_noise::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::CallbackSet{…}, dense::Bool, calck::Bool, dt::Float64, adaptive::Bool, gamma::Rational{…}, abstol::Float64, reltol::Float64, qmin::Rational{…}, qmax::Rational{…}, qsteady_min::Int64, qsteady_max::Int64, beta2::Nothing, beta1::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, delta::Rational{…}, maxiters::Int64, dtmax::Float64, dtmin::Float64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, force_dtmin::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, initialize_integrator::Bool, seed::UInt64, alias_u0::Bool, alias_jumps::Bool, kwargs::@Kwargs{})
@ StochasticDiffEq ~/.julia/packages/StochasticDiffEq/adUnd/src/solve.jl:111
[3] __solve(prob::SDEProblem{…}, alg::SOSRI, timeseries::Vector{…}, ts::Vector{…}, ks::Nothing, recompile::Type{…}; kwargs::@Kwargs{…})
@ StochasticDiffEq ~/.julia/packages/StochasticDiffEq/adUnd/src/solve.jl:6
[4] solve_call(_prob::SDEProblem{…}, args::SOSRI; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/7s9cb/src/solve.jl:608
[5] solve_up(prob::SDEProblem{…}, sensealg::Nothing, u0::CuArray{…}, p::CuArray{…}, args::SOSRI; kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/7s9cb/src/solve.jl:1057
[6] solve(prob::SDEProblem{…}, args::SOSRI; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/7s9cb/src/solve.jl:980
[7] _adjoint_sensitivities(sol::RODESolution{…}, sensealg::InterpolatingAdjoint{…}, alg::SOSRI; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Nothing, callback::Nothing, kwargs::@Kwargs{…})
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/sensitivity_interface.jl:432
[8] _adjoint_sensitivities
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/sensitivity_interface.jl:390 [inlined]
[9] #adjoint_sensitivities#63
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/sensitivity_interface.jl:386 [inlined]
[10] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#307"{…})(Δ::CuArray{…})
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/concrete_solve.jl:535
[11] ZBack
@ ~/.julia/packages/Zygote/WOy6z/src/compiler/chainrules.jl:211 [inlined]
[12] (::Zygote.var"#kw_zpullback#53"{…})(dy::CuArray{…})
@ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/chainrules.jl:237
[13] #291
@ Zygote ~/.julia/packages/Zygote/WOy6z/src/lib/lib.jl:206 [inlined]
[14] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{…}})(Δ::CuArray{Float64, 3, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71
[15] #solve#40
@ ~/.julia/packages/DiffEqBase/7s9cb/src/solve.jl:980 [inlined]
[16] (::Zygote.Pullback{Tuple{…}, Any})(Δ::CuArray{Float64, 3, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0
[17] #291
@ Zygote ~/.julia/packages/Zygote/WOy6z/src/lib/lib.jl:206 [inlined]
[18] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{…}})(Δ::CuArray{Float64, 3, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71
[19] solve
@ ~/.julia/packages/DiffEqBase/7s9cb/src/solve.jl:970 [inlined]
[20] (::Zygote.Pullback{Tuple{…}, Any})(Δ::CuArray{Float64, 3, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0
[21] loss
@ ~/SDE/SDELabquake/inference_rff/inference_rff.jl:178 [inlined]
[22] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0
[23] loss
@ ~/SDE/SDELabquake/inference_rff/inference_rff.jl:169 [inlined]
[24] #1
@ ~/SDE/SDELabquake/inference_rff/inference_rff.jl:184 [inlined]
[25] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0
[26] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:45
[27] withgradient(f::Function, args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:162
[28] macro expansion
@ ./timing.jl:279 [inlined]
[29] top-level scope
@ ~/SDE/SDELabquake/inference_rff/inference_rff.jl:269
[30] include(fname::String)
@ Base.MainInclude ./client.jl:489
[31] top-level scope
@ REPL[3]:1

The adjoint of a diagonal-noise SDE is a commutative-noise problem, and SOSRI only handles diagonal noise, so that’s why it throws that error.
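For context on the two noise types mentioned above, these are the standard textbook definitions, written for an SDE with a d×m noise matrix g. This is only a summary of the terms, not a derivation of the adjoint SDE.

```latex
\begin{aligned}
dX_t &= f(X_t, p, t)\,dt + g(X_t, p, t)\,dW_t, \qquad g \in \mathbb{R}^{d \times m},\\
\text{diagonal noise:}\quad & m = d,\quad g_{ij} = 0 \ \text{for } i \neq j,\\
\text{commutative noise:}\quad & \sum_{i=1}^{d} g_{i j_1}\,\frac{\partial g_{k j_2}}{\partial x_i}
  = \sum_{i=1}^{d} g_{i j_2}\,\frac{\partial g_{k j_1}}{\partial x_i}
  \quad \text{for all } j_1, j_2, k.
\end{aligned}
```

When g does not depend on the state, as with the constant noise vector in this post (additive noise), every derivative term vanishes and the commutativity condition holds trivially. SOSRA is StochasticDiffEq's stability-optimized method for additive noise.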

I apologize, but I don’t understand what you just said 😅 Does this mean that I cannot compute gradients with diagonal noise? What is the alternative?

EDIT: After some investigation, here is what I found. First, it seems that the solver I’m using (SOSRI) is not compatible with non-diagonal noise. Changing the solve call to

solve(prob,SOSRA(),saveat=training_time)

removes the error. By setting

SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true

I now get:

┌ Warning: Potential performance improvement omitted. TrackerVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN = true. To turn off this printing, add verbose = false to the solve call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/concrete_solve.jl:134
MethodError: no method matching drift_rff(::Tracker.TrackedMatrix{Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, ::Tracker.TrackedMatrix{Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, ::Float32)

Closest candidates are:
drift_rff(::Any, ::Any, ::Any, ::Float32)
@ Main ~/SDE/SDELabquake/inference_rff/inference_rff.jl:84
drift_rff(::Any, ::Any, ::Any, ::Any)
@ Main ~/SDE/SDELabquake/inference_rff/inference_rff.jl:84

┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/concrete_solve.jl:144
ERROR: LoadError: CuArray only supports element types that are allocated inline.
Any is not allocated inline

I would appreciate any insight for this, thanks!

@ChrisRackauckas here is the full error that I get:

┌ Warning: Potential performance improvement omitted. ZygoteVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN = true. To turn off this printing, add verbose = false to the solve call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/concrete_solve.jl:99
MethodError: no method matching drift_rff(::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Float32)

Closest candidates are:
drift_rff(::Any, ::Any, ::Any, ::Any)
@ Main ~/SDE/SDELabquake/inference_rff/inference_rff.jl:84

┌ Warning: Potential performance improvement omitted. ReverseDiffVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN = true. To turn off this printing, add verbose = false to the solve call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/concrete_solve.jl:116
MethodError: no method matching gradient(::SciMLSensitivity.var"#282#288"{SDEProblem{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, true, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Nothing, SDEFunction{true, SciMLBase.FullSpecialize, typeof(drift_rff), typeof(diffussion_rff), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, typeof(diffussion_rff), Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, Nothing}}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})

Closest candidates are:
gradient(::Any, ::Any)
@ ReverseDiff ~/.julia/packages/ReverseDiff/UJhiD/src/api/gradients.jl:21
gradient(::Any, ::Any, ::ReverseDiff.GradientConfig)
@ ReverseDiff ~/.julia/packages/ReverseDiff/UJhiD/src/api/gradients.jl:21

┌ Warning: Potential performance improvement omitted. TrackerVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN = true. To turn off this printing, add verbose = false to the solve call.
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/concrete_solve.jl:134
MethodError: no method matching drift_rff(::Tracker.TrackedMatrix{Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, ::Tracker.TrackedMatrix{Float32, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, ::Float32)

Closest candidates are:
drift_rff(::Any, ::Any, ::Any, ::Any)
@ Main ~/SDE/SDELabquake/inference_rff/inference_rff.jl:84

┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/concrete_solve.jl:144
ERROR: LoadError: CuArray only supports element types that are allocated inline.
Any is not allocated inline

Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] check_eltype(T::Type)
@ CUDA ~/.julia/packages/CUDA/rXson/src/array.jl:51
[3] CuArray{Any, 1, CUDA.Mem.DeviceBuffer}(::UndefInitializer, dims::Tuple{Int64})
@ CUDA ~/.julia/packages/CUDA/rXson/src/array.jl:66
[4] (CuArray{Any, 1})(::UndefInitializer, dims::Tuple{Int64})
@ CUDA ~/.julia/packages/CUDA/rXson/src/array.jl:147
[5] (CuArray{Any})(::UndefInitializer, dims::Tuple{Int64})
@ CUDA ~/.julia/packages/CUDA/rXson/src/array.jl:166
[6] similar(::Type{CuArray{Any}}, dims::Tuple{Int64})
@ Base ./abstractarray.jl:874
[7] similar(::Type{CuArray{Any}}, shape::Tuple{Base.OneTo{Int64}})
@ Base ./abstractarray.jl:873
[8] similar(bc::Base.Broadcast.Broadcasted{…}, ::Type{…})
@ CUDA ~/.julia/packages/CUDA/rXson/src/broadcast.jl:11
[9] copy
@ GPUArrays ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:42 [inlined]
[10] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{…}, Nothing, StochasticDiffEq.var"#90#91"{…}, Tuple{…}})
@ Base.Broadcast ./broadcast.jl:903
[11] map(::Function, ::CuArray{Int64, 1, CUDA.Mem.DeviceBuffer})
@ GPUArrays ~/.julia/packages/GPUArrays/dAUOE/src/host/broadcast.jl:89
[12] sde_interpolation(tvals::CuArray{…}, id::StochasticDiffEq.LinearInterpolationData{…}, idxs::Nothing, deriv::Type, p::CuArray{…}, continuity::Symbol)
@ StochasticDiffEq ~/.julia/packages/StochasticDiffEq/adUnd/src/dense.jl:114
[13] (::StochasticDiffEq.LinearInterpolationData{…})(tvals::CuArray{…}, idxs::Nothing, deriv::Type, p::CuArray{…}, continuity::Symbol)
@ StochasticDiffEq ~/.julia/packages/StochasticDiffEq/adUnd/src/interp_func.jl:7
[14] (::RODESolution{…})(t::CuArray{…}, ::Type{…}; idxs::Nothing, continuity::Symbol)
@ SciMLBase ~/.julia/packages/SciMLBase/8XHkk/src/solutions/rode_solutions.jl:66
[15] (::RODESolution{…})(t::CuArray{…}, ::Type{…})
@ SciMLBase ~/.julia/packages/SciMLBase/8XHkk/src/solutions/rode_solutions.jl:64
[16] _concrete_solve_adjoint(::SDEProblem{…}, ::SOSRA, ::InterpolatingAdjoint{…}, ::CuArray{…}, ::CuArray{…}, ::SciMLBase.ChainRulesOriginator; save_start::Bool, save_end::Bool, saveat::CuArray{…}, save_idxs::Nothing, kwargs::@Kwargs{…})
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/4Ah3r/src/concrete_solve.jl:389
[17] _concrete_solve_adjoint
@ ~/.julia/packages/SciMLSensitivity/4Ah3r/src/concrete_solve.jl:287 [inlined]
[18] #_concrete_solve_adjoint#291
@ ~/.julia/packages/SciMLSensitivity/4Ah3r/src/concrete_solve.jl:215 [inlined]
[19] _concrete_solve_adjoint
@ ~/.julia/packages/SciMLSensitivity/4Ah3r/src/concrete_solve.jl:200 [inlined]
[20] #_solve_adjoint#64
@ ~/.julia/packages/DiffEqBase/7s9cb/src/solve.jl:1512 [inlined]
[21] _solve_adjoint
@ ~/.julia/packages/DiffEqBase/7s9cb/src/solve.jl:1486 [inlined]
[22] #rrule#6
@ ~/.julia/packages/DiffEqBase/7s9cb/ext/DiffEqBaseChainRulesCoreExt.jl:25 [inlined]
[23] rrule
@ ~/.julia/packages/DiffEqBase/7s9cb/ext/DiffEqBaseChainRulesCoreExt.jl:21 [inlined]
[24] rrule
@ ~/.julia/packages/ChainRulesCore/zoCjl/src/rules.jl:140 [inlined]
[25] chain_rrule_kw
@ ~/.julia/packages/Zygote/WOy6z/src/compiler/chainrules.jl:235 [inlined]
[26] macro expansion
@ ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0 [inlined]
[27] _pullback(ctx::ZygoteRules.AContext, f::Any, args::Vararg{Any})
@ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:81 [inlined]
[28] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:838 [inlined]
[29] adjoint
@ ~/.julia/packages/Zygote/WOy6z/src/lib/lib.jl:203 [inlined]
[30] _pullback
@ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:66 [inlined]
[31] #solve#40
@ ~/.julia/packages/DiffEqBase/7s9cb/src/solve.jl:980 [inlined]
[32] _pullback(::Zygote.Context{…}, ::DiffEqBase.var"##solve#40", ::Nothing, ::Nothing, ::Nothing, ::Val{…}, ::@Kwargs{…}, ::typeof(solve), ::SDEProblem{…}, ::SOSRA)
@ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0
[33] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:838
[34] adjoint
@ ~/.julia/packages/Zygote/WOy6z/src/lib/lib.jl:203 [inlined]
[35] _pullback
@ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:66 [inlined]
[36] solve
@ ~/.julia/packages/DiffEqBase/7s9cb/src/solve.jl:970 [inlined]
[37] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(solve), ::SDEProblem{…}, ::SOSRA)
@ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0
[38] loss
@ ~/SDE/SDELabquake/inference_rff/inference_rff.jl:184 [inlined]
[39] _pullback(::Zygote.Context{…}, ::typeof(loss), ::CuArray{…}, ::CuArray{…}, ::CuArray{…}, ::CuArray{…}, ::CuArray{…}, ::CuArray{…}, ::Float64, ::UnitRange{…})
@ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0
[40] loss
@ ~/SDE/SDELabquake/inference_rff/inference_rff.jl:175 [inlined]
[41] _pullback(::Zygote.Context{…}, ::typeof(loss), ::CuArray{…}, ::CuArray{…}, ::CuArray{…}, ::CuArray{…}, ::CuArray{…}, ::CuArray{…}, ::Float64)
@ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0
[42] #1
@ ~/SDE/SDELabquake/inference_rff/inference_rff.jl:190 [inlined]
[43] _pullback(ctx::Zygote.Context{false}, f::var"#1#2", args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0
[44] pullback(f::Function, cx::Zygote.Context{false}, args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:44
[45] pullback(f::Function, cx::Zygote.Context{false}, args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:42 [inlined]
[46] withgradient(f::Function, args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:154
[47] macro expansion
@ ./timing.jl:279 [inlined]
[48] top-level scope
@ ~/SDE/SDELabquake/inference_rff/inference_rff.jl:269
[49] include(fname::String)
@ Base.MainInclude ./client.jl:489
[50] top-level scope
@ REPL[1]:1

Do you have an idea of what this means?

Those are performance warnings. It should still run though?

Your functions use mutation, so it’s falling back to AD methods that are not compatible with CUDA. I would recommend defining them out of place, as f(u,p,t) and g(u,p,t), so that it defaults to ZygoteVJP.
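To illustrate the out-of-place convention, here is a minimal sketch in plain Julia: drift and diffusion take (u, p, t) and return new arrays instead of writing into a preallocated du. The linear drift, the NamedTuple parameters, and the hand-written fixed-step Euler-Maruyama loop are all hypothetical stand-ins for illustration only; with StochasticDiffEq the two functions would instead be passed as SDEProblem(drift_oop, diffusion_oop, u0, tspan, p).

```julia
using Random

# Hypothetical out-of-place drift and diffusion with the (u, p, t) signature.
# Neither function mutates its inputs; each returns a fresh array.
drift_oop(u, p, t) = -p.theta .* u            # linear mean reversion (illustrative)
diffusion_oop(u, p, t) = p.sigma .* one.(u)   # additive (state-independent) noise

# Hypothetical fixed-step Euler-Maruyama loop for diagonal noise. The state
# is rebound to a new array each step rather than updated in place.
function euler_maruyama(f, g, u0, p, tspan, dt; rng = Random.default_rng())
    t0, t1 = tspan
    nsteps = round(Int, (t1 - t0) / dt)
    u = u0
    us = [u]
    for n in 1:nsteps
        t = t0 + (n - 1) * dt
        dW = sqrt(dt) .* randn(rng, size(u)...)   # independent increment per component
        u = u .+ f(u, p, t) .* dt .+ g(u, p, t) .* dW
        push!(us, u)
    end
    return us
end

p = (theta = 0.5, sigma = 0.1)
u0 = ones(4)
us = euler_maruyama(drift_oop, diffusion_oop, u0, p, (0.0, 1.0), 0.01; rng = MersenneTwister(1))
```

Because each function returns its result, the three-argument signature alone tells the caller which convention is in use; no output buffer is passed in.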

Unfortunately, that doesn’t seem to help. My code looks like this:

nu = 1.5
n_samples = 100
d = 4
l_list = ones((d,1))
samples_rff = rand(TDist(2*nu), (d, n_samples));
# Generate A from a normal distribution
A = randn((d, n_samples))/sqrt(n_samples)
# Send the parameters to the GPU
A = cu(A)
samples_rff = cu(samples_rff)
l_list = cu(l_list)
weights = hcat((A, l_list)...)

function rff_model(X, A, sample_features, l)
    tau = l.^(-1)
    W = tau.*sample_features
    W = sample_features
    M = W'*X
    M = cos.(M) + sin.(M)
    return A*M
end
forward_model(x, A, l) = rff_model(x, A, samples_rff, l)

# Define our drift and diffusion functions
function drift_rff(dstate, state, p, t)	
    A = p[:, 1:end-1]
    l = p[:, end:end]
    return forward_model(state, A, l)
end

function diffussion_rff(dstate,state,p,t)
# noise is a vector of size(4,1)
    return dstate .+ noise 
end

prob = SDEProblem{true}( drift_rff, diffussion_rff,  initial_conditions,(t_in, t_fin), weights)
function loss(weights, training_trajectories, initial_conditions, training_times)
    temp_prob = remake(prob, p = weights)
    tmp_sol = solve(temp_prob,SOSRA(),saveat=training_times,  force_dtmin = true);
    arrsol = CuArray(tmp_sol)
    return sum((arrsol - training_trajectories).^2) 
end
objective = weights -> loss(weights, training_trajectories, initial_conditions, training_times)
val, grads = Zygote.withgradient(objective, weights)

Is perhaps indexing the parameters like this

A = p[:, 1:end-1]
l = p[:, end:end]

forbidden? Or is using globally defined arguments to define my forward model a problem? My noise is a vector of size (4,1), could that be an issue? I’m somewhat confused because I feel like what I’m doing is fairly simple and computing the loss itself is not a problem.
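On the globals question: whether they cause this particular AD failure is not established in the thread, but a common Julia pattern is to capture such data in a closure instead of reading non-constant globals inside the drift and diffusion. Plain indexing such as p[:, 1:n] is generally differentiable with Zygote. The sketch below assumes out-of-place (u, p, t) functions, a length-d noise vector, and a hypothetical factory name, make_sde_functions; it is not code from the thread.

```julia
# Hypothetical factory (not from the thread): builds out-of-place drift and
# diffusion closures that capture the feature matrix and noise amplitudes,
# so neither function reads a non-constant global.
function make_sde_functions(sample_features, noise)
    n = size(sample_features, 2)              # number of random features
    function drift(u, p, t)
        A = p[:, 1:n]                         # plain indexing returns copies
        l = p[:, n+1:n+1]                     # (d, 1) length scales
        M = (sample_features ./ l)' * u
        return A * (cos.(M) .+ sin.(M))
    end
    # noise is assumed to be a length-d vector; broadcasting gives the
    # state's shape whether u is (d,) or (d, m)
    diffusion(u, p, t) = noise .* one.(u)
    return drift, diffusion
end

d, n, m = 4, 10, 3
f, g = make_sde_functions(randn(d, n), fill(0.1, d))
p = hcat(randn(d, n), ones(d, 1))
u = randn(d, m)
du = f(u, p, 0.0)     # (d, m) drift
gu = g(u, p, 0.0)     # (d, m) noise amplitudes
```

A noise array of size (d, 1), as in the post, would broadcast a length-d state up to a (d, 1) matrix, which is why the sketch uses a plain vector.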

Thank you for your help.

EDIT: after some more testing, I tried to define the problem out-of-place in the following way:

# Define our drift and diffusion functions
function drift_rff(state, p, t)	
    A = p[:, 1:n_samples]
    l = p[:, n_samples+1:n_samples+1]
    noise = p[:, end:end]
    return forward_model(state, A, l)
end
function diffussion_rff(state,p,t)
    A = p[:, 1:n_samples]
    l = p[:, n_samples+1:n_samples+1]
    noise = p[:, end:end]
    return noise
end

@show drift_rff(cu(u0), weights, 0.0)
@show diffussion_rff(cu(u0), weights, 0.0)
prob = SDEProblem{true}( drift_rff, diffussion_rff,  initial_conditions,(t_in, t_fin), weights)

but this throws an error:

ERROR: LoadError: Nonconforming functions detected. If a model function f is defined as in-place, then all constituent functions like jac and paramjac must be in-place (and vice versa with out-of-place). Detected that some overloads did not conform to the same convention as f.

This error pops up no matter how the diffusion is defined (even if it’s defined to be the drift itself!). I’m not sure what is going on. Note that my initial condition is either a vector of dimension d or a matrix of size d×n, so I’m solving a system of equations.
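One thing that may be worth checking (an assumption, not confirmed in this thread): DiffEq problems normally infer in-place versus out-of-place from how many arguments the functions accept, (du, u, p, t) versus (u, p, t), while the explicit type parameter in SDEProblem{true}(...) declares the in-place form regardless. The new drift_rff and diffussion_rff only have three-argument methods, so the {true} may conflict with them; constructing the problem without it would be the thing to try. The sketch below mimics an arity check with Base's hasmethod; it is an illustration, not SciML's actual implementation, and the function names are hypothetical.

```julia
# Illustration (not SciML's implementation) of telling the two conventions
# apart by arity: in-place takes (du, u, p, t), out-of-place takes (u, p, t).
takes_inplace_args(f) = hasmethod(f, Tuple{Any, Any, Any, Any})
takes_outofplace_args(f) = hasmethod(f, Tuple{Any, Any, Any})

# Hypothetical examples of each convention
drift_iip!(du, u, p, t) = (du .= -u; nothing)   # writes into du, returns nothing
drift_oop(u, p, t) = -u                         # returns a new array

du = zeros(3)
drift_iip!(du, ones(3), nothing, 0.0)
```

Under this reading, out-of-place functions like the ones above only satisfy the three-argument check, so declaring the problem as in-place asks for methods that do not exist.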