SciMLSensitivity fails on GPU?

Hi all,

I’m running into a problem when trying to compute gradients of an ODE solution on the GPU (with Metal.jl) using SciMLSensitivity.jl.

Here’s a minimal working example:

using Random
using Statistics   # provides mean, used in loss
using Lux
using OrdinaryDiffEq
using SciMLSensitivity
using Zygote
using Metal

# device = Lux.cpu_device()
device = Lux.gpu_device()

velocity_test(x, p, t) = p * x

X₀ = rand(Float32, 2, 4) |> device
p  = rand(Float32, 2, 2) |> device
tspan = (zero(eltype(X₀)), one(eltype(X₀)))

function loss(p)
    problem = ODEProblem(velocity_test, X₀, tspan, p)
    sol = solve(problem)
    X = last(sol.u)
    return mean(X)
end

loss(p)                       # OK on CPU and GPU
Zygote.gradient(loss, p)      # OK on CPU, fails on GPU

On CPU, everything works fine.

On the GPU, the forward solve (loss(p)) works without issues.

But when I call Zygote.gradient(loss, p) I get:

Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore should be avoided.

I’ve tried changing the sensitivity algorithm like this:

function loss(p)
    problem = ODEProblem(velocity_test, X₀, tspan, p)
    sol = solve(problem, Tsit5();
        sensealg = BacksolveAdjoint(autojacvec=ZygoteVJP())    # Change here
    )
    X = last(sol.u)
    return mean(X)
end

But I now get a different error:

InvalidIRError: ... Metal does not support Float64 values, try using Float32 instead

even though all of my inputs (X₀, p, tspan) are explicitly Float32.

Is this a known issue with SciMLSensitivity + GPU + Zygote? Is there a sensealg that works on the GPU? Any pointers would be very welcome!

What is the stack trace associated with the InvalidIRError? There might be a place in OrdinaryDiffEq where we use a Float64 by accident.


Thanks for the reply!
The full stack trace is too long to paste, so I’ve removed some of the very long arguments:

InvalidIRError: compiling MethodInstance for (::Metal.var"#broadcast_linear#204")(::MtlDeviceVector{Float32, 1}, ::Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{1, Metal.PrivateStorage}, Tuple{Base.OneTo{Int64}}, typeof(muladd), Tuple{Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{1, Metal.PrivateStorage}, Nothing, typeof(DiffEqBase.ODE_DEFAULT_NORM), Tuple{Base.Broadcast.Extruded{MtlDeviceVector{Float32, 1}, Tuple{Bool}, Tuple{Int64}}, Float32}}, Float64, Float64}}) resulted in invalid LLVM IR
Reason: unsupported use of double value
Reason: unsupported use of double value
Reason: unsupported use of double value
Stacktrace:
 [1] Float64
   @ ./float.jl:341
 [2] convert
   @ ./number.jl:7
 [3] _promote
   @ ./promotion.jl:384
 [4] promote
   @ ./promotion.jl:406
 [5] muladd
   @ ./promotion.jl:481
 [6] _broadcast_getindex_evalf
   @ ./broadcast.jl:678
 [7] _broadcast_getindex
   @ ./broadcast.jl:651
 [8] getindex
   @ ./broadcast.jl:610
 [9] broadcast_linear
   @ ~/.julia/packages/Metal/N2ABH/src/broadcast.jl:89
Reason: unsupported use of double value
Stacktrace:
 [1] muladd
   @ ./float.jl:496
 [2] muladd
   @ ./promotion.jl:481
 [3] _broadcast_getindex_evalf
   @ ./broadcast.jl:678
 [4] _broadcast_getindex
   @ ./broadcast.jl:651
 [5] getindex
   @ ./broadcast.jl:610
 [6] broadcast_linear
   @ ~/.julia/packages/Metal/N2ABH/src/broadcast.jl:89
Reason: unsupported use of double value
Stacktrace:
 [1] muladd
   @ ./float.jl:496
 [2] muladd
   @ ./promotion.jl:481
 [3] _broadcast_getindex_evalf
   @ ./broadcast.jl:678
 [4] _broadcast_getindex
   @ ./broadcast.jl:651
 [5] getindex
   @ ./broadcast.jl:610
 [6] broadcast_linear
   @ ~/.julia/packages/Metal/N2ABH/src/broadcast.jl:89
Reason: unsupported use of double value
Stacktrace:
 [1] Float32
   @ ./float.jl:338
 [2] convert
   @ ./number.jl:7
 [3] setindex!
   @ ~/.julia/packages/Metal/N2ABH/src/device/array.jl:105
 [4] broadcast_linear
   @ ~/.julia/packages/Metal/N2ABH/src/broadcast.jl:89
Reason: unsupported use of double value
Stacktrace:
 [1] Float64
   @ ./float.jl:341
 [2] convert
   @ ./number.jl:7
 [3] _promote
   @ ./promotion.jl:384
 [4] promote
   @ ./promotion.jl:406
 [5] muladd
   @ ./promotion.jl:481
 [6] _broadcast_getindex_evalf
   @ ./broadcast.jl:678
 [7] _broadcast_getindex
   @ ./broadcast.jl:651
 [8] getindex
   @ ./broadcast.jl:610
 [9] broadcast_linear
   @ ~/.julia/packages/Metal/N2ABH/src/broadcast.jl:89
Reason: unsupported use of double value
Stacktrace:
 [1] muladd
   @ ./float.jl:496
 [2] muladd
   @ ./promotion.jl:481
 [3] _broadcast_getindex_evalf
   @ ./broadcast.jl:678
 [4] _broadcast_getindex
   @ ./broadcast.jl:651
 [5] getindex
   @ ./broadcast.jl:610
 [6] broadcast_linear
   @ ~/.julia/packages/Metal/N2ABH/src/broadcast.jl:89
Reason: unsupported use of double value
Stacktrace:
 [1] muladd
   @ ./float.jl:496
 [2] muladd
   @ ./promotion.jl:481
 [3] _broadcast_getindex_evalf
   @ ./broadcast.jl:678
 [4] _broadcast_getindex
   @ ./broadcast.jl:651
 [5] getindex
   @ ./broadcast.jl:610
 [6] broadcast_linear
   @ ~/.julia/packages/Metal/N2ABH/src/broadcast.jl:89
Reason: unsupported use of double value
Stacktrace:
 [1] Float32
   @ ./float.jl:338
 [2] convert
   @ ./number.jl:7
 [3] setindex!
   @ ~/.julia/packages/Metal/N2ABH/src/device/array.jl:105
 [4] broadcast_linear
   @ ~/.julia/packages/Metal/N2ABH/src/broadcast.jl:89
Hint: catch this exception as `err` and call `code_typed(err; interactive = true)` to introspect the erronous code with Cthulhu.jl

Stacktrace:
  [1] check_ir(job::GPUCompiler.CompilerJob{GPUCompiler.MetalCompilerTarget, Metal.MetalCompilerParams}, args::LLVM.Module)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/OGnEB/src/validation.jl:167
  [2] macro expansion
    @ ~/.julia/packages/GPUCompiler/OGnEB/src/driver.jl:381 [inlined]
  [3] emit_llvm(job::GPUCompiler.CompilerJob; toplevel::Bool, libraries::Bool, optimize::Bool, cleanup::Bool, validate::Bool, only_entry::Bool)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/OGnEB/src/utils.jl:110
  [4] emit_llvm
    @ ~/.julia/packages/GPUCompiler/OGnEB/src/utils.jl:108 [inlined]
  [5] codegen(output::Symbol, job::GPUCompiler.CompilerJob; toplevel::Bool, libraries::Bool, optimize::Bool, cleanup::Bool, validate::Bool, strip::Bool, only_entry::Bool, parent_job::Nothing)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/OGnEB/src/driver.jl:100
  [6] codegen(output::Symbol, job::GPUCompiler.CompilerJob)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/OGnEB/src/driver.jl:82
  [7] compile(target::Symbol, job::GPUCompiler.CompilerJob; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/OGnEB/src/driver.jl:79
  [8] compile
    @ ~/.julia/packages/GPUCompiler/OGnEB/src/driver.jl:74 [inlined]
  [9] (::Metal.var"#157#165"{GPUCompiler.CompilerJob{GPUCompiler.MetalCompilerTarget, Metal.MetalCompilerParams}})(ctx::LLVM.Context)
    @ Metal ~/.julia/packages/Metal/N2ABH/src/compiler/compilation.jl:108
 [10] JuliaContext(f::Metal.var"#157#165"{GPUCompiler.CompilerJob{GPUCompiler.MetalCompilerTarget, Metal.MetalCompilerParams}}; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/OGnEB/src/driver.jl:34
 [11] JuliaContext(f::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/OGnEB/src/driver.jl:25
 [12] macro expansion
    @ ~/.julia/packages/Metal/N2ABH/src/compiler/compilation.jl:107 [inlined]
 [13] macro expansion
    @ ~/.julia/packages/ObjectiveC/TgrW6/src/os.jl:264 [inlined]
 [14] compile(job::GPUCompiler.CompilerJob)
    @ Metal ~/.julia/packages/Metal/N2ABH/src/compiler/compilation.jl:105
 [15] actual_compilation(cache::Dict{Any, Any}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{GPUCompiler.MetalCompilerTarget, Metal.MetalCompilerParams}, compiler::typeof(Metal.compile), linker::typeof(Metal.link))
    @ GPUCompiler ~/.julia/packages/GPUCompiler/OGnEB/src/execution.jl:237
 [16] cached_compilation(cache::Dict{Any, Any}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{GPUCompiler.MetalCompilerTarget, Metal.MetalCompilerParams}, compiler::Function, linker::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/OGnEB/src/execution.jl:151
 [17] macro expansion
    @ ~/.julia/packages/Metal/N2ABH/src/compiler/execution.jl:189 [inlined]
 [18] macro expansion
    @ ./lock.jl:273 [inlined]
 [19] mtlfunction(f::Metal.var"#broadcast_linear#204", tt::Type{Tuple{MtlDeviceVector{Float32, 1}, Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{1, Metal.PrivateStorage}, Tuple{Base.OneTo{Int64}}, typeof(muladd), Tuple{Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{1, Metal.PrivateStorage}, Nothing, typeof(DiffEqBase.ODE_DEFAULT_NORM), Tuple{Base.Broadcast.Extruded{MtlDeviceVector{Float32, 1}, Tuple{Bool}, Tuple{Int64}}, Float32}}, Float64, Float64}}}}; name::Nothing, kwargs::@Kwargs{})
    @ Metal ~/.julia/packages/Metal/N2ABH/src/compiler/execution.jl:184
 [20] mtlfunction
    @ ~/.julia/packages/Metal/N2ABH/src/compiler/execution.jl:182 [inlined]
 [21] macro expansion
    @ ~/.julia/packages/Metal/N2ABH/src/compiler/execution.jl:85 [inlined]
 [22] _copyto!
    @ ~/.julia/packages/Metal/N2ABH/src/broadcast.jl:95 [inlined]
 [23] materialize!
    @ ~/.julia/packages/Metal/N2ABH/src/broadcast.jl:43 [inlined]
 [24] materialize!
    @ ./broadcast.jl:880 [inlined]
 [25] fast_materialize!
    @ ~/.julia/packages/FastBroadcast/wfdTr/src/FastBroadcast.jl:279 [inlined]
 [26] ode_determine_initdt(u0::MtlVector{Float32, Metal.PrivateStorage}, t::Float32, tdir::Float32, dtmax::Float32, abstol::Float64, reltol::Float64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), prob::ODEProblem{MtlVector{Float32, Metal.PrivateStorage}, Tuple{Float32, Float32}, true, MtlMatrix{Float32, Metal.PrivateStorage}, ODEFunction{true, true, [...]
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/vS7Uo/src/initdt.jl:26
 [27] auto_dt_reset!
    @ ~/.julia/packages/OrdinaryDiffEqCore/vS7Uo/src/integrators/integrator_interface.jl:431 [inlined]
 [28] handle_dt!(integrator::OrdinaryDiffEqCore.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), [...]
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/vS7Uo/src/solve.jl:647
 [29] __init(prob::ODEProblem{MtlVector{Float32, Metal.PrivateStorage}, Tuple{Float32, Float32}, true, MtlMatrix{Float32, Metal.PrivateStorage}, [...]
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/vS7Uo/src/solve.jl:609
 [30] __init (repeats 5 times)
    @ ~/.julia/packages/OrdinaryDiffEqCore/vS7Uo/src/solve.jl:11 [inlined]
 [31] #__solve#62
    @ ~/.julia/packages/OrdinaryDiffEqCore/vS7Uo/src/solve.jl:6 [inlined]
 [32] __solve
    @ ~/.julia/packages/OrdinaryDiffEqCore/vS7Uo/src/solve.jl:1 [inlined]
 [33] solve_call(_prob::ODEProblem{MtlVector{Float32, Metal.PrivateStorage}, Tuple{Float32, Float32}, true, MtlMatrix{Float32, Metal.PrivateStorage}, ODEFunction{true, true, [...]
    @ DiffEqBase ~/.julia/packages/DiffEqBase/PbBEl/src/solve.jl:635
 [34] solve_call
    @ ~/.julia/packages/DiffEqBase/PbBEl/src/solve.jl:592 [inlined]
 [35] #solve_up#44
    @ ~/.julia/packages/DiffEqBase/PbBEl/src/solve.jl:1128 [inlined]
 [36] solve_up
    @ ~/.julia/packages/DiffEqBase/PbBEl/src/solve.jl:1106 [inlined]
 [37] #solve#42
    @ ~/.julia/packages/DiffEqBase/PbBEl/src/solve.jl:1043 [inlined]
 [38] _adjoint_sensitivities(sol::ODESolution{Float32, 3, Vector{MtlMatrix{Float32, Metal.PrivateStorage}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{MtlMatrix{Float32, Metal.PrivateStorage}}}, Nothing, ODEProblem{MtlMatrix{Float32, Metal.PrivateStorage}, Tuple{Float32, Float32}, false, MtlMatrix{Float32, Metal.PrivateStorage}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(velocity_test), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, OrdinaryDiffEqCore.InterpolationData{ODEFunction{false, SciMLBase.AutoSpecialize, typeof(velocity_test), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Vector{MtlMatrix{Float32, Metal.PrivateStorage}}, Vector{Float32}, Vector{Vector{MtlMatrix{Float32, Metal.PrivateStorage}}}, Nothing, OrdinaryDiffEqTsit5.Tsit5ConstantCache, Nothing}, SciMLBase.DEStats, Nothing, Nothing, Nothing, Nothing}, sensealg::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, alg::Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}; t::Vector{Float32}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{Float32}, corfunc_analytical::Nothing, callback::Nothing, kwargs::@Kwargs{verbose::Bool})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/7UEpc/src/sensitivity_interface.jl:448
 [39] _adjoint_sensitivities
    @ ~/.julia/packages/SciMLSensitivity/7UEpc/src/sensitivity_interface.jl:405 [inlined]
 [40] #adjoint_sensitivities#63
    @ ~/.julia/packages/SciMLSensitivity/7UEpc/src/sensitivity_interface.jl:401 [inlined]
 [41] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#323"{@Kwargs{}, Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, MtlMatrix{Float32, Metal.PrivateStorage}, MtlMatrix{Float32, Metal.PrivateStorage}, SciMLBase.ChainRulesOriginator, Tuple{}, Colon, @NamedTuple{}})(Δ::ODESolution{Float32, 3, Vector{MtlMatrix{Float32, Metal.PrivateStorage}}, Nothing, Nothing, Vector{Float32}, Nothing, Nothing, ODEProblem{MtlMatrix{Float32, Metal.PrivateStorage}, Tuple{Float32, Float32}, false, MtlMatrix{Float32, Metal.PrivateStorage}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(velocity_test), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, SciMLBase.LinearInterpolation{Vector{Float32}, Vector{MtlMatrix{Float32, Metal.PrivateStorage}}}, Nothing, Nothing, Nothing, Nothing, Nothing})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/7UEpc/src/concrete_solve.jl:627
 [42] ZBack
    @ ~/.julia/packages/Zygote/TWpme/src/compiler/chainrules.jl:212 [inlined]
 [43] (::Zygote.var"#295#296"{Tuple{NTuple{4, Nothing}, Tuple{Nothing}}, Zygote.ZBack{SciMLSensitivity.var"#adjoint_sensitivity_backpass#323"{@Kwargs{}, Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, MtlMatrix{Float32, Metal.PrivateStorage}, MtlMatrix{Float32, Metal.PrivateStorage}, SciMLBase.ChainRulesOriginator, Tuple{}, Colon, @NamedTuple{}}}})(Δ::ODESolution{Float32, 3, Vector{MtlMatrix{Float32, Metal.PrivateStorage}}, Nothing, Nothing, Vector{Float32}, Nothing, Nothing, ODEProblem{MtlMatrix{Float32, Metal.PrivateStorage}, Tuple{Float32, Float32}, false, MtlMatrix{Float32, Metal.PrivateStorage}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(velocity_test), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, SciMLBase.LinearInterpolation{Vector{Float32}, Vector{MtlMatrix{Float32, Metal.PrivateStorage}}}, Nothing, Nothing, Nothing, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/lib/lib.jl:205
 [44] (::Zygote.var"#2169#back#297"{Zygote.var"#295#296"{Tuple{NTuple{4, Nothing}, Tuple{Nothing}}, Zygote.ZBack{SciMLSensitivity.var"#adjoint_sensitivity_backpass#323"{@Kwargs{}, Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, MtlMatrix{Float32, Metal.PrivateStorage}, MtlMatrix{Float32, Metal.PrivateStorage}, SciMLBase.ChainRulesOriginator, Tuple{}, Colon, @NamedTuple{}}}}})(Δ::ODESolution{Float32, 3, Vector{MtlMatrix{Float32, Metal.PrivateStorage}}, Nothing, Nothing, Vector{Float32}, Nothing, Nothing, ODEProblem{MtlMatrix{Float32, Metal.PrivateStorage}, Tuple{Float32, Float32}, false, MtlMatrix{Float32, Metal.PrivateStorage}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(velocity_test), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, SciMLBase.LinearInterpolation{Vector{Float32}, Vector{MtlMatrix{Float32, Metal.PrivateStorage}}}, Nothing, Nothing, Nothing, Nothing, Nothing})
    @ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
 [45] #solve#42
    @ ~/.julia/packages/DiffEqBase/PbBEl/src/solve.jl:1043 [inlined]
 [46] (::Zygote.Pullback{Tuple{DiffEqBase.var"##solve#42", BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, Nothing, Nothing, Val{true}, @Kwargs{}, typeof(solve), ODEProblem{MtlMatrix{Float32, Metal.PrivateStorage}, Tuple{Float32, Float32}, false, MtlMatrix{Float32, Metal.PrivateStorage}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(velocity_test), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing,[...]Nothing, Nothing, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [47] #295
    @ ~/.julia/packages/Zygote/TWpme/src/lib/lib.jl:205 [inlined]
 [48] (::Zygote.var"#2169#back#297"{Zygote.var"#295#296"{Tuple{NTuple{7, Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{DiffEqBase.var"##solve#42", BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, [...]
    @ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
 [49] solve
    @ ~/.julia/packages/DiffEqBase/PbBEl/src/solve.jl:1033 [inlined]
 [50] (::Zygote.Pullback{Tuple{typeof(Core.kwcall), @NamedTuple{sensealg::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}}, typeof(solve), ODEProblem{MtlMatrix{Float32, Metal.PrivateStorage}, [...]
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [51] loss
    @ ~/Documents/PhD/Projects/RiemannianFlows.jl/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_Y130sZmlsZQ==.jl:20 [inlined]
 [52] (::Zygote.Pullback{Tuple{typeof(loss), MtlMatrix{Float32, Metal.PrivateStorage}}, Tuple{Zygote.ZBack{SciMLBaseChainRulesCoreExt.var"#ODEProblemAdjoint#14"}, [...]
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [53] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{typeof(loss), MtlMatrix{Float32, Metal.PrivateStorage}}, Tuple{Zygote.ZBack{SciMLBaseChainRulesCoreExt.var"#ODEProblemAdjoint#14"}, Zygote.Pullback{Tuple{typeof(Core.kwcall), @NamedTuple{autojacvec::ZygoteVJP}, Type{BacksolveAdjoint}}, Any}, Zygote.var"#2013#back#208"{typeof(identity)}, Zygote.var"#1986#back#198"{Zygote.var"#194#197"{Zygote.Context{false}, GlobalRef, MtlMatrix{Float32, Metal.PrivateStorage}}}, Zygote.Pullback{Tuple{typeof(last), Vector{MtlMatrix{Float32, Metal.PrivateStorage}}}, Tuple{Zygote.ZBack{ChainRules.var"#getindex_pullback#670"{Vector{MtlMatrix{Float32, Metal.PrivateStorage}}, Tuple{Int64}, Tuple{ChainRulesCore.NoTangent}}}, Zygote.Pullback{Tuple{typeof(lastindex), Vector{MtlMatrix{Float32, Metal.PrivateStorage}}}, Tuple{Zygote.Pullback{Tuple{typeof(last), Base.OneTo{Int64}}, Tuple{Zygote.ZBack{Zygote.var"#convert_pullback#334"}, Zygote.var"#2180#back#307"{Zygote.var"#back#306"{:stop, Zygote.Context{false}, Base.OneTo{Int64}, Int64}}}}, Zygote.ZBack{Returns{Tuple{ChainRulesCore.NoTangent, ChainRulesCore.NoTangent, ChainRulesCore.NoTangent}}}, Zygote.Pullback{Tuple{Type{IndexLinear}}, Tuple{}}}}}}, Zygote.ZBack{ChainRules.var"#mean_pullback#890"{Int64, ChainRules.var"#sum_pullback#720"{Colon, MtlMatrix{Float32, Metal.PrivateStorage}, ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{element::ChainRulesCore.ProjectTo{Float32, @NamedTuple{}}, axes::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(Core.kwcall), @NamedTuple{sensealg::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}}, typeof(solve), ODEProblem{MtlMatrix{Float32, Metal.PrivateStorage}, Tuple{Float32, Float32}, false, MtlMatrix{Float32, Metal.PrivateStorage}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(velocity_test), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, 
Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}}, Any}, Zygote.Pullback{Tuple{Type{Tsit5}}, Tuple{}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:sensealg,)}}, Tuple{BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}}}, Tuple{Zygote.var"#2220#back#319"{Zygote.Jnew{@NamedTuple{sensealg::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}}, Nothing, true}}}}, SciMLSensitivity.var"#345#back#163"{SciMLSensitivity.var"#solu_adjoint#162"{ODESolution{Float32, 3, Vector{MtlMatrix{Float32, Metal.PrivateStorage}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{MtlMatrix{Float32, Metal.PrivateStorage}}}, Nothing, ODEProblem{MtlMatrix{Float32, Metal.PrivateStorage}, Tuple{Float32, Float32}, false, MtlMatrix{Float32, Metal.PrivateStorage}, ODEFunction{false, SciMLBase.AutoSpecialize, typeof(velocity_test), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, OrdinaryDiffEqCore.InterpolationData{ODEFunction{false, SciMLBase.AutoSpecialize, typeof(velocity_test), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Vector{MtlMatrix{Float32, Metal.PrivateStorage}}, Vector{Float32}, Vector{Vector{MtlMatrix{Float32, Metal.PrivateStorage}}}, Nothing, OrdinaryDiffEqTsit5.Tsit5ConstantCache, Nothing}, SciMLBase.DEStats, Nothing, Nothing, Nothing, Nothing}}}, Zygote.var"#1986#back#198"{Zygote.var"#194#197"{Zygote.Context{false}, GlobalRef, Tuple{Float32, Float32}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:autojacvec,)}}, Tuple{ZygoteVJP}}, 
Tuple{Zygote.var"#2220#back#319"{Zygote.Jnew{@NamedTuple{autojacvec::ZygoteVJP}, Nothing, true}}}}, Zygote.var"#2013#back#208"{typeof(identity)}, Zygote.Pullback{Tuple{Type{ZygoteVJP}}, Tuple{}}}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface.jl:91
 [54] gradient(f::Function, args::MtlMatrix{Float32, Metal.PrivateStorage})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface.jl:148

Oh, apparently you need to set abstol and reltol as Float32 as well. We should make those default to the element type of u too…

Edit: it looks like we already do this. I’m very confused as to what’s going wrong…
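Frames [26] and [38] of the trace show abstol::Float64 and reltol::Float64 reaching the adjoint solve even though u0, t and dtmax are Float32, and the rejected kernel broadcasts muladd over ODE_DEFAULT_NORM(...) with two Float64 scalars. The CPU-only sketch below illustrates the promotion involved. It assumes 1e-6 and 1e-3 as the Float64 default tolerances and uses abs as a scalar stand-in for DiffEqBase.ODE_DEFAULT_NORM; it is an illustration, not the solver's actual code.

```julia
# CPU-only illustration of numeric promotion; no GPU or SciML packages needed.
u0 = rand(Float32, 4)               # Float32 state, like the MtlVector in the trace

abstol64, reltol64 = 1e-6, 1e-3     # Float64 literals (assumed default tolerances)
abstol32, reltol32 = 1f-6, 1f-3     # Float32 literals

# Same shape as the failing kernel: muladd(norm(u0), reltol, abstol),
# with abs standing in for the scalar norm.
sk64 = muladd.(abs.(u0), reltol64, abstol64)
sk32 = muladd.(abs.(u0), reltol32, abstol32)

println(eltype(sk64))   # Float64: the Float32 value is promoted inside muladd
println(eltype(sk32))   # Float32: no double-precision arithmetic anywhere

# Writing into a Float32 buffer does not avoid the Float64 arithmetic;
# the result is only converted back to Float32 when it is stored.
buf = similar(u0)
buf .= muladd.(abs.(u0), reltol64, abstol64)
println(eltype(buf))    # Float32
```

On Metal, that Float64 intermediate is what the compiler rejects (note the Float64 convert frames and the Float32 setindex! frame in the trace), which is consistent with Float32 tolerances making the error go away.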


One thing worth asking: are you using up-to-date package versions?


Thanks, this indeed worked! I didn’t spot it because the forward ODE solve still worked and only the gradient computation through SciMLSensitivity failed.

I probably also had outdated package versions, but even after updating I get the same error unless I set the tolerances to Float32.
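For reference, a sketch of the change described above. The thread does not show exactly where the tolerances were set; this assumes they are passed as keyword arguments to solve, which SciMLSensitivity presumably forwards to the adjoint solve. The values 1f-6 and 1f-3 are illustrative Float32 tolerances, and the sketch defaults to cpu_device() so it runs anywhere; swap in gpu_device() (with Metal loaded) to reproduce the original setting.

```julia
using Statistics   # mean
using Lux, OrdinaryDiffEq, SciMLSensitivity, Zygote
# using Metal      # needed when switching to Lux.gpu_device() on Apple GPUs

device = Lux.cpu_device()   # replace with Lux.gpu_device() to run on Metal

velocity_test(x, p, t) = p * x

X₀ = rand(Float32, 2, 4) |> device
p  = rand(Float32, 2, 2) |> device
tspan = (zero(eltype(X₀)), one(eltype(X₀)))

# Tolerances typed to match the state (Float32), so no Float64 scalars
# should reach the broadcast kernels in the forward or the adjoint solve.
const ABSTOL = 1f-6
const RELTOL = 1f-3

function loss(p)
    problem = ODEProblem(velocity_test, X₀, tspan, p)
    sol = solve(problem, Tsit5();
        abstol = ABSTOL, reltol = RELTOL,
        sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP()))
    return mean(last(sol.u))
end

loss(p)
(dp,) = Zygote.gradient(loss, p)
```

Deriving the tolerances from the state type instead, e.g. eltype(X₀)(1e-6), would keep them consistent if the precision of X₀ ever changes.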
