Implementing Three-Dimensional Interpolation, ODEProblem, and Gradient Computation with sensealg=GaussAdjoint(autojacvec=EnzymeVJP())

Hello, I am currently trying to solve an ordinary differential equation (ODE) driven by interpolated data inputs. The problem can be pictured as a snow layer: the snowpack is the intermediate state, snowfall is the input (supply), and snowmelt is the consumption. The state equation can be expressed as:

D(snowpack) ~ snowfall - melt

Here, snowfall is derived from observed forcing data: the precipitation and temperature series are interpolated in advance so they can be evaluated at arbitrary times, and those values determine the snowfall. Melt is computed from the snowpack together with the temperature. The constructed ODE function is as follows:

using Zygote
using OrdinaryDiffEq
using SciMLSensitivity
using MLUtils
using ComponentArrays
using DataInterpolations

step_func(x) = (tanh(5.0 * x) + 1.0) * 0.5

tmp_func = (inputs, states, pas) -> begin
    prcp = inputs[1]
    temp = inputs[2]
    lday = inputs[3]
    snowpack = states[1]
    Tmin = pas.params.Tmin
    Df = pas.params.Df
    Tmax = pas.params.Tmax
    snowfall = step_func(Tmin - temp) * prcp
    melt = step_func(temp - Tmax) * min(snowpack, Df * (temp - Tmax))
    return [snowfall - melt]
end


Df, Tmax, Tmin = 2.67, 0.17, -2.09
ps = ComponentVector(params=(Df=Df, Tmax=Tmax, Tmin=Tmin))
ps_axes = getaxes(ps)
u0 = zeros(eltype(ps), 1)

node_input = rand(3, 20)
timeidx = collect(1:20)
itpfunc = LinearInterpolation(node_input, timeidx)

function ode_func!(du, u, p, t)
    du .= tmp_func(itpfunc(t), u, ComponentVector(p, ps_axes))
end

tspan = (1.0, 20.0)
prob = ODEProblem(ode_func!, u0, tspan, Vector(ps))
sol = solve(prob, Tsit5(), sensealg=BacksolveAdjoint(autojacvec=EnzymeVJP()))

Zygote.gradient(Vector(ps)) do p
    prob = ODEProblem(ode_func!, u0, tspan, p)
    sol = solve(prob, Tsit5(), sensealg=BacksolveAdjoint(autojacvec=EnzymeVJP())) |> Array
    return sum(sol)
end
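As an aside on the model itself, step_func above is a smooth, differentiable stand-in for a Heaviside step, which is what makes the thresholds (snowfall below Tmin, melt above Tmax) compatible with gradient-based sensitivity analysis. A minimal check of its behavior:

```julia
# step_func is a smooth approximation of the unit step:
# ≈0 for clearly negative arguments, 0.5 at zero, ≈1 for clearly positive ones.
step_func(x) = (tanh(5.0 * x) + 1.0) * 0.5

println(step_func(-2.0))  # very close to 0
println(step_func(0.0))   # exactly 0.5
println(step_func(2.0))   # very close to 1
```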

At this point, I want to extend the dimensionality of this problem. Imagine that I have multiple snowpack modules, say 10. The input then becomes three-dimensional, node_input = rand(3, 10, 20), where the second dimension corresponds to the number of snowpack modules. The initial state u0 of the ODE problem likewise becomes a 1×10 matrix. The specific code is as follows:

tmp_func2 = (inputs, states, pas) -> begin
    prcp = inputs[1]
    temp = inputs[2]
    lday = inputs[3]
    snowpack = states[1]
    Tmin = pas.params.Tmin
    Df = pas.params.Df
    Tmax = pas.params.Tmax
    snowfall = @. step_func(Tmin - temp) * prcp
    melt = @. step_func(temp - Tmax) * min(snowpack, Df * (temp - Tmax))
    return stack([snowfall .- melt], dims=1)
end

Df, Tmax, Tmin = 2.674548848, 0.175739196, -2.092959084
ps = ComponentVector(params=(Df=Df, Tmax=Tmax, Tmin=Tmin))
ps_axes = getaxes(ps)
u0 = rand(1, 10)

node_input = rand(3, 10, 20)
itpfuncs = LinearInterpolation.(eachslice(node_input, dims=1), Ref(1:20)) # raises an error
tmp_input_func = (t) -> ntuple(i -> itpfuncs[i](t), length(itpfuncs))
itpfunc = LinearInterpolation(reshape(node_input, :, 20), 1:20) # reshaping first seems to work

function multi_ode_func!(du, u, p, t)
    tmp_input = reshape(itpfunc(t), 3, 10)
    du .= tmp_func2(eachslice(tmp_input, dims=1), eachslice(u, dims=1), ComponentVector(p, ps_axes))
end

tspan = (1.0, 20.0)
prob = ODEProblem(multi_ode_func!, u0, tspan, Vector(ps))
sol = solve(prob, Tsit5(), sensealg=GaussAdjoint(autojacvec=EnzymeVJP())) |> Array

Zygote.gradient(Vector(ps)) do p
    prob = ODEProblem(multi_ode_func!, u0, tspan, p)
    sol = solve(prob, Tsit5(), sensealg=GaussAdjoint(autojacvec=EnzymeVJP())) |> Array |> sum
end
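For reference, the reshape trick used above is safe because Julia arrays are column-major: column t of the reshaped 30×20 matrix holds exactly the flattened 3×10 slice node_input[:, :, t], so reshaping the interpolated 30-vector back to 3×10 recovers the original layout. A quick Base-only check:

```julia
# Column-major consistency of the reshape-before-interpolate trick:
# flat[:, t] is exactly vec(node_input[:, :, t]), so interpolating columns of
# `flat` and reshaping back to 3×10 is equivalent to slicing the 3D array.
node_input = rand(3, 10, 20)
flat = reshape(node_input, :, 20)  # 30×20, shares data with node_input
for t in 1:20
    @assert reshape(flat[:, t], 3, 10) == node_input[:, :, t]
end
```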

This error occurred when calculating the gradient:

ERROR: Enzyme compilation failed due to an internal error.
 Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl
 To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)

Illegal replace ficticious phi for:   %_replacementA677 = phi {} addrspace(10)* , !dbg !765 of   %493 = call noalias nonnull "enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Pointer, [-1,8,-1]:Float@double}" {} addrspace(10)* @jl_genericmemory_copy_slice({} addrspace(10)* %787, i64 %492, i64 %403) #57, !dbg !612

Stacktrace:
  [1] copy
    @ .\array.jl:350
  [2] unaliascopy
    @ .\abstractarray.jl:1516
  [3] unalias
    @ .\abstractarray.jl:1500
  [4] broadcast_unalias
    @ .\broadcast.jl:941
  [5] preprocess
    @ .\broadcast.jl:948
  [6] preprocess_args
    @ .\broadcast.jl:950
  [7] preprocess
    @ .\broadcast.jl:947
  [8] override_bc_copyto!
    @ D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler\interpreter.jl:798
  [9] copyto!
    @ .\broadcast.jl:920
 [10] override_bc_materialize
    @ D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler\interpreter.jl:854
 [11] #83
    @ e:\JlCode\HydroModels\dev\base\tmp_test.jl:41
 [12] #83
    @ e:\JlCode\HydroModels\dev\base\tmp_test.jl:0

Stacktrace:
  [1] julia_error(msg::String, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\errors.jl:384
  [2] julia_error(cstr::Cstring, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\errors.jl:210
  [3] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
    @ Enzyme.API D:\Julia\packages\packages\Enzyme\g1jMR\src\api.jl:269
  [4] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…}) 
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:1707
  [5] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:4655
  [6] codegen
    @ D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:3441 [inlined]
  [7] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:5515
  [8] _thunk
    @ D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:5515 [inlined]
  [9] cached_compilation
    @ D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:5567 [inlined]
 [10] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:5678
 [11] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)     
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:5863
 [12] runtime_generic_augfwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::var"#83#84", df::Nothing, primal_1::RowSlices{…}, shadow_1_1::RowSlices{…}, primal_2::RowSlices{…}, shadow_2_1::RowSlices{…}, primal_3::ComponentVector{…}, shadow_3_1::ComponentVector{…}) 
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\rules\jitrules.jl:445
 [13] multi_ode_func!
    @ e:\JlCode\HydroModels\dev\base\tmp_test.jl:56
Some type information was truncated. Use `show(err)` to see complete types.

Illegal replace ficticious phi for:   %_replacementA677 = phi {} addrspace(10)* , !dbg !765 of   %493 = call noalias nonnull "enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Pointer, [-1,8,-1]:Float@double}" {} addrspace(10)* @jl_genericmemory_copy_slice({} addrspace(10)* %787, i64 %492, i64 %403) #57, !dbg !612

But using ZygoteVJP works:

Zygote.gradient(Vector(ps)) do p
    prob = ODEProblem(multi_ode_func!, u0, tspan, p)
    sol = solve(prob, Tsit5(), sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP())) |> Array |> sum
end

#output ([2.363584938462766, 210.78169299208085, 4.2921775834541405e-7],)

Although ZygoteVJP meets my needs, I have found in my own usage that EnzymeVJP computes gradients much faster, so I would like to get this working under EnzymeVJP. From the errors it appears to be a compilation issue, but as a Julia user without a computer-science background, interpreting it is challenging for me. Additionally, I found the following form of input handling more convenient:

itpfuncs = LinearInterpolation.(eachslice(node_input, dims=1), Ref(1:20)) # raises error
tmp_input_func = (t) -> ntuple(i -> itpfuncs[i](t), length(itpfuncs))

However, Enzyme does not seem to support this form well either.
In summary, I would sincerely appreciate any help in resolving this issue. My environment is as follows:

Julia Version 1.11.2
Commit 5e9a32e7af (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 24 × 12th Gen Intel(R) Core(TM) i9-12900HX
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, alderlake)
Threads: 1 default, 0 interactive, 1 GC (on 24 virtual cores)
Environment:
  JULIA_DEPOT_PATH = D:\Julia\packages
  JULIA_PKG_SERVER = https://mirrors.pku.edu.cn/julia/
  JULIA_EDITOR = code
  JULIA_NUM_THREADS =
(dev) pkg> st Zygote OrdinaryDiffEq SciMLSensitivity MLUtils ComponentArrays 
  [b0b7db55] ComponentArrays v0.15.26
  [82cc6244] DataInterpolations v8.0.0
  [f1d291b0] MLUtils v0.4.8
  [1dea7af3] OrdinaryDiffEq v6.93.0
  [1ed8b502] SciMLSensitivity v7.76.0
  [e88e6eb3] Zygote v0.7.6

Open an issue on DataInterpolations.jl about Enzyme compatibility. It should be possible to make an MWE that does not involve the ODE solver but calls Enzyme directly. If you make that MWE and post an issue, it will be much easier to solve this.

Following your suggestion, I wrote an MWE. I then discovered that the error occurs when the computation is delegated to a separate function, but not when the same computation is copied inline into the calling function. The code is as follows:


matrices = rand(3, 3, 10)
t = collect(1:10)
multi_interps = LinearInterpolation(reshape(matrices, :, 10), t)

function tmp_func3(inputs, p)
    prcp = inputs[1]
    temp = inputs[2]
    lday = inputs[3]
    Tmin = p[1]
    Df = p[2]
    Tmax = p[3]
    snowfall = prcp .* (temp .- Tmin) .* Df
    melt = lday .* (temp .- Tmax)
    return stack([snowfall .- melt], dims=1) |> sum
end

function f(x, p)
    inputs = eachslice(reshape(multi_interps(x), 3, 3), dims=1)
    prcp = inputs[1]
    temp = inputs[2]
    lday = inputs[3]
    Tmin = p[1]
    Df = p[2]
    Tmax = p[3]
    snowfall = prcp .* (temp .- Tmin) .* Df
    melt = lday .* (temp .- Tmax)
    return stack([snowfall .- melt], dims=1) |> sum
    # return tmp_func3(inputs, p)
end

ps = [2.3, 3.4, 5.0]
ps_d = zeros(eltype(ps), length(ps))
grad = Enzyme.autodiff(Reverse, f, Active, Active(3.5), Duplicated(ps, ps_d))[1]

This code works correctly as written, but if I replace the inline computation with the commented-out line return tmp_func3(inputs, p), this error occurs:

ERROR: Enzyme compilation failed due to an internal error.
 Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl
 To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)     

Illegal replace ficticious phi for:   %_replacementA955 = phi {} addrspace(10)* , !dbg !649 of   %577 = call noalias nonnull "enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Pointer, [-1,8,-1]:Float@double}" {} addrspace(10)* @jl_genericmemory_copy_slice({} addrspace(10)* %843, i64 %576, i64 %468) #56, !dbg !498

Stacktrace:
  [1] copy
    @ .\array.jl:350
  [2] unaliascopy
    @ .\abstractarray.jl:1516
  [3] unalias
    @ .\abstractarray.jl:1500
  [4] broadcast_unalias
    @ .\broadcast.jl:941
  [5] preprocess
    @ .\broadcast.jl:948
  [6] preprocess_args
    @ .\broadcast.jl:951
  [7] preprocess_args
    @ .\broadcast.jl:950
  [8] preprocess
    @ .\broadcast.jl:947
  [9] override_bc_copyto!
    @ D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler\interpreter.jl:798
 [10] copyto!
    @ .\broadcast.jl:920
 [11] override_bc_materialize
    @ D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler\interpreter.jl:854
 [12] tmp_func3
    @ e:\JlCode\HydroModels\dev\bug\test_enzyme_interp.jl:38
 [13] tmp_func3
    @ e:\JlCode\HydroModels\dev\bug\test_enzyme_interp.jl:0

Stacktrace:
  [1] julia_error(msg::String, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\errors.jl:384
  [2] julia_error(cstr::Cstring, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\errors.jl:210
  [3] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
    @ Enzyme.API D:\Julia\packages\packages\Enzyme\g1jMR\src\api.jl:269
  [4] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:1707
  [5] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:4655
  [6] codegen
    @ D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:3441 [inlined]
  [7] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)       
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:5515
  [8] _thunk
    @ D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:5515 [inlined]
  [9] cached_compilation
    @ D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:5567 [inlined]
 [10] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:5678
 [11] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:5863
 [12] runtime_generic_augfwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(tmp_func3), df::Nothing, primal_1::RowSlices{…}, shadow_1_1::RowSlices{…}, primal_2::Vector{…}, shadow_2_1::Vector{…})
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\rules\jitrules.jl:445
 [13] f
    @ e:\JlCode\HydroModels\dev\bug\test_enzyme_interp.jl:55 [inlined]
 [14] augmented_julia_f_117721wrap
    @ e:\JlCode\HydroModels\dev\bug\test_enzyme_interp.jl:0
 [15] top-level scope
    @ e:\JlCode\HydroModels\dev\bug\test_enzyme_interp.jl:61
Some type information was truncated. Use `show(err)` to see complete types.

Illegal replace ficticious phi for:   %_replacementA955 = phi {} addrspace(10)* , !dbg !649 of   %577 = call noalias nonnull "enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Pointer, [-1,8,-1]:Float@double}" {} addrspace(10)* @jl_genericmemory_copy_slice({} addrspace(10)* %843, i64 %576, i64 %468) #56, !dbg !498

Therefore, I inlined the computation directly in the ODE function instead of calling a helper, and this resolved the gradient computation for the ODE problem. Another odd observation: when the called function itself has a computational error (for example, omitting @. for elementwise broadcasting), the process crashes outright.
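To illustrate the workaround, here is a minimal Base-only sketch of the inlined right-hand side. The plain parameter vector p = [Tmin, Df, Tmax] and the precomputed tmp_input matrix are simplifications of the original ComponentVector/interpolation setup, used here only to show the inlining pattern; with the helper call removed, no eachslice (RowSlices) views cross a function boundary, which is where the failing stack traces point.

```julia
# Smooth threshold function, as in the original model.
step_func(x) = (tanh(5.0 * x) + 1.0) * 0.5

# Inlined RHS: the body of tmp_func2 is written directly here instead of
# being delegated to a helper that receives eachslice views.
# p = [Tmin, Df, Tmax] is a simplification of the ComponentVector parameters.
function multi_ode_func_inlined!(du, u, p, t, tmp_input)
    prcp = view(tmp_input, 1, :)       # row 1: precipitation per module
    temp = view(tmp_input, 2, :)       # row 2: temperature per module
    snowpack = view(u, 1, :)           # current state per module
    Tmin, Df, Tmax = p[1], p[2], p[3]
    snowfall = @. step_func(Tmin - temp) * prcp
    melt = @. step_func(temp - Tmax) * min(snowpack, Df * (temp - Tmax))
    du[1, :] .= snowfall .- melt
    return nothing
end
```

In the real ODE function, tmp_input would come from reshape(itpfunc(t), 3, 10) as before; only the helper call is removed.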

I now believe this is most likely an issue related to DataInterpolations.jl, and I will post this MWE to its issue tracker.