Implementing Three-Dimensional Interpolation, ODEProblem, and Gradient Computation with sensealg=GaussAdjoint(autojacvec=EnzymeVJP())

Hello, I am trying to solve an ordinary differential equation (ODE) whose inputs are interpolated data. The problem can be pictured as a snow layer: the snowpack is the state variable, snowfall is the supply (input), and snowmelt is the consumption (output). The state equation can be expressed as:

D(snowpack) ~ snowfall - melt

Here, snowfall is derived from observed inputs. Precipitation and temperature data are interpolated in advance so that their values are available at any time t, and these values are then used to calculate snowfall. The melt is calculated from the snowpack combined with the temperature. The constructed ODE function is as follows:

using Zygote
using OrdinaryDiffEq
using SciMLSensitivity
using MLUtils
using ComponentArrays
using DataInterpolations

step_func(x) = (tanh(5.0 * x) + 1.0) * 0.5

tmp_func = (inputs, states, pas) -> begin
    prcp = inputs[1]
    temp = inputs[2]
    lday = inputs[3]
    snowpack = states[1]
    Tmin = pas.params.Tmin
    Df = pas.params.Df
    Tmax = pas.params.Tmax
    snowfall = step_func(Tmin - temp) * prcp
    melt = step_func(temp - Tmax) * min(snowpack, Df * (temp - Tmax))
    return [snowfall - melt]
end


Df, Tmax, Tmin = 2.67, 0.17, -2.09
ps = ComponentVector(params=(Df=Df, Tmax=Tmax, Tmin=Tmin))
ps_axes = getaxes(ps)
u0 = zeros(eltype(ps), 1)

node_input = rand(3, 20)
timeidx = collect(1:20)
itpfunc = LinearInterpolation(node_input, timeidx)

function ode_func!(du, u, p, t)
    du .= tmp_func(itpfunc(t), u, ComponentVector(p, ps_axes))
end

tspan = (1.0, 20.0)
prob = ODEProblem(ode_func!, u0, tspan, Vector(ps))
sol = solve(prob, Tsit5(), sensealg=BacksolveAdjoint(autojacvec=EnzymeVJP()))

Zygote.gradient(Vector(ps)) do p
    prob = ODEProblem(ode_func!, u0, tspan, p)
    sol = solve(prob, Tsit5(), sensealg=BacksolveAdjoint(autojacvec=EnzymeVJP())) |> Array
    return sum(sol)
end
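For reference, the right-hand side implemented by tmp_func above can be written out as equations. This is only a transcription of the code, with symbols chosen here for readability: S is snowpack, P is prcp, T is temp, and σ is step_func. The third input, lday, is read but does not enter the equations.

```latex
\begin{align}
\sigma(x) &= \tfrac{1}{2}\bigl(\tanh(5x) + 1\bigr) \\
\text{snowfall} &= \sigma(T_{\min} - T)\, P \\
\text{melt} &= \sigma(T - T_{\max})\, \min\bigl(S,\; D_f\,(T - T_{\max})\bigr) \\
\frac{dS}{dt} &= \text{snowfall} - \text{melt}
\end{align}
```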

Next, I want to extend the dimensionality of this problem. Suppose I have multiple snowpack modules, for example 10. The input then becomes three-dimensional: node_input=rand(3, 10, 20), where the second dimension corresponds to the number of snowpack modules. Likewise, the initial value u0 of the ODE problem becomes a 1×10 matrix. The specific code is as follows:

tmp_func2 = (inputs, states, pas) -> begin
    prcp = inputs[1]
    temp = inputs[2]
    lday = inputs[3]
    snowpack = states[1]
    Tmin = pas.params.Tmin
    Df = pas.params.Df
    Tmax = pas.params.Tmax
    snowfall = @. step_func(Tmin - temp) * prcp
    melt = @. step_func(temp - Tmax) * min(snowpack, Df * (temp - Tmax))
    return stack([snowfall .- melt], dims=1)
end

Df, Tmax, Tmin = 2.674548848, 0.175739196, -2.092959084
ps = ComponentVector(params=(Df=Df, Tmax=Tmax, Tmin=Tmin))
ps_axes = getaxes(ps)
u0 = rand(1, 10)

node_input = rand(3, 10, 20)
itpfuncs = LinearInterpolation.(eachslice(node_input, dims=1), Ref(1:20)) # raises an error
tmp_input_func = (t) -> ntuple(i -> itpfuncs[i](t), length(itpfuncs))
itpfunc = LinearInterpolation(reshape(node_input, :, 20), 1:20) # reshaping before interpolating seems OK

function multi_ode_func!(du, u, p, t)
    tmp_input = reshape(itpfunc(t), 3, 10)
    du .= tmp_func2(eachslice(tmp_input, dims=1), eachslice(u, dims=1), ComponentVector(p, ps_axes))
end

tspan = (1.0, 20.0)
prob = ODEProblem(multi_ode_func!, u0, tspan, Vector(ps))
sol = solve(prob, Tsit5(), sensealg=GaussAdjoint(autojacvec=EnzymeVJP())) |> Array

Zygote.gradient(Vector(ps)) do p
    prob = ODEProblem(multi_ode_func!, u0, tspan, p)
    sol = solve(prob, Tsit5(), sensealg=GaussAdjoint(autojacvec=EnzymeVJP())) |> Array |> sum
end

This error occurred when calculating the gradient:

ERROR: Enzyme compilation failed due to an internal error.
 Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl
 To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)

Illegal replace ficticious phi for:   %_replacementA677 = phi {} addrspace(10)* , !dbg !765 of   %493 = call noalias nonnull "enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Pointer, [-1,8,-1]:Float@double}" {} addrspace(10)* @jl_genericmemory_copy_slice({} addrspace(10)* %787, i64 %492, i64 %403) #57, !dbg !612

Stacktrace:
  [1] copy
    @ .\array.jl:350
  [2] unaliascopy
    @ .\abstractarray.jl:1516
  [3] unalias
    @ .\abstractarray.jl:1500
  [4] broadcast_unalias
    @ .\broadcast.jl:941
  [5] preprocess
    @ .\broadcast.jl:948
  [6] preprocess_args
    @ .\broadcast.jl:950
  [7] preprocess
    @ .\broadcast.jl:947
  [8] override_bc_copyto!
    @ D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler\interpreter.jl:798
  [9] copyto!
    @ .\broadcast.jl:920
 [10] override_bc_materialize
    @ D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler\interpreter.jl:854
 [11] #83
    @ e:\JlCode\HydroModels\dev\base\tmp_test.jl:41
 [12] #83
    @ e:\JlCode\HydroModels\dev\base\tmp_test.jl:0

Stacktrace:
  [1] julia_error(msg::String, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\errors.jl:384
  [2] julia_error(cstr::Cstring, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\errors.jl:210
  [3] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
    @ Enzyme.API D:\Julia\packages\packages\Enzyme\g1jMR\src\api.jl:269
  [4] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…}) 
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:1707
  [5] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:4655
  [6] codegen
    @ D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:3441 [inlined]
  [7] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:5515
  [8] _thunk
    @ D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:5515 [inlined]
  [9] cached_compilation
    @ D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:5567 [inlined]
 [10] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:5678
 [11] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)     
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:5863
 [12] runtime_generic_augfwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::var"#83#84", df::Nothing, primal_1::RowSlices{…}, shadow_1_1::RowSlices{…}, primal_2::RowSlices{…}, shadow_2_1::RowSlices{…}, primal_3::ComponentVector{…}, shadow_3_1::ComponentVector{…}) 
    @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\rules\jitrules.jl:445
 [13] multi_ode_func!
    @ e:\JlCode\HydroModels\dev\base\tmp_test.jl:56
Some type information was truncated. Use `show(err)` to see complete types.

Illegal replace ficticious phi for:   %_replacementA677 = phi {} addrspace(10)* , !dbg !765 of   %493 = call noalias nonnull "enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Pointer, [-1,8,-1]:Float@double}" {} addrspace(10)* @jl_genericmemory_copy_slice({} addrspace(10)* %787, i64 %492, i64 %403) #57, !dbg !612

However, using ZygoteVJP works:

Zygote.gradient(Vector(ps)) do p
    prob = ODEProblem(multi_ode_func!, u0, tspan, p)
    sol = solve(prob, Tsit5(), sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP())) |> Array |> sum
end

#output ([2.363584938462766, 210.78169299208085, 4.2921775834541405e-7],)

Although ZygoteVJP meets my needs, I have found in my own usage that EnzymeVJP computes gradients much faster, so I would like this to work correctly with EnzymeVJP. Judging from the error, it looks like a compilation issue, but as a Julia user without a computer science background I find it hard to interpret. Additionally, I find the following form of input more convenient:

itpfuncs = LinearInterpolation.(eachslice(node_input, dims=1), Ref(1:20)) # raises error
tmp_input_func = (t) -> ntuple(i -> itpfuncs[i](t), length(itpfuncs))

However, Enzyme does not seem to support this form well either.
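As a side check, the two input forms should give the same numbers, since linear interpolation acts elementwise. The sketch below compares them using DataInterpolations only (no Enzyme, no ODE solver). The names flat_at and sliced_at are made up for this illustration, and the slices are copied with node_input[i, :, :] rather than passed as views, which is an assumption made here to keep the example simple.

```julia
using DataInterpolations

node_input = rand(3, 10, 20)
t_nodes = collect(1.0:20.0)

# Form 1: flatten the first two dimensions, interpolate once, reshape afterwards.
itp_flat = LinearInterpolation(reshape(node_input, :, 20), t_nodes)
flat_at(t) = reshape(itp_flat(t), 3, 10)

# Form 2: one interpolator per input variable (prcp, temp, lday).
# Each slice is copied into its own 10×20 matrix.
itps = [LinearInterpolation(node_input[i, :, :], t_nodes) for i in 1:3]
sliced_at(t) = stack([itp(t) for itp in itps], dims=1)

t = 7.3
@assert size(flat_at(t)) == (3, 10)
@assert flat_at(t) ≈ sliced_at(t)
```

If the assertions pass, choosing between the two forms is a matter of convenience and AD compatibility rather than of results.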
In summary, I would greatly appreciate any help in resolving this issue. My environment is as follows:

Julia Version 1.11.2
Commit 5e9a32e7af (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 24 × 12th Gen Intel(R) Core(TM) i9-12900HX
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, alderlake)
Threads: 1 default, 0 interactive, 1 GC (on 24 virtual cores)
Environment:
  JULIA_DEPOT_PATH = D:\Julia\packages
  JULIA_PKG_SERVER = https://mirrors.pku.edu.cn/julia/
  JULIA_EDITOR = code
  JULIA_NUM_THREADS =
(dev) pkg> st Zygote OrdinaryDiffEq SciMLSensitivity MLUtils ComponentArrays 
  [b0b7db55] ComponentArrays v0.15.26
  [82cc6244] DataInterpolations v8.0.0
  [f1d291b0] MLUtils v0.4.8
  [1dea7af3] OrdinaryDiffEq v6.93.0
  [1ed8b502] SciMLSensitivity v7.76.0
  [e88e6eb3] Zygote v0.7.6

Please open an issue on DataInterpolations.jl about Enzyme compatibility. It should be possible to make a minimal working example (MWE) that leaves out the ODE solver and calls Enzyme directly. If you make that MWE and post it in the issue, this will be much easier to solve.
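A minimal sketch of what such an MWE might look like, under stated assumptions: the parameters are passed as a plain vector p = [Df, Tmax, Tmin] instead of a ComponentVector, the interpolator is passed as an explicit argument instead of a global, and the helper names rhs, rhs! and enzyme_vjp are invented for this sketch. Whether it reproduces the same internal error has not been checked here and may depend on package versions.

```julia
using Enzyme
using DataInterpolations

step_func(x) = (tanh(5.0 * x) + 1.0) * 0.5

# Same broadcasted right-hand side as tmp_func2 in the post, but with
# p = [Df, Tmax, Tmin] as a plain vector (assumption, to drop ComponentArrays).
function rhs(inputs, states, p)
    prcp, temp = inputs[1], inputs[2]
    snowpack = states[1]
    Df, Tmax, Tmin = p[1], p[2], p[3]
    snowfall = @. step_func(Tmin - temp) * prcp
    melt = @. step_func(temp - Tmax) * min(snowpack, Df * (temp - Tmax))
    return stack([snowfall .- melt], dims=1)
end

# In-place wrapper mirroring multi_ode_func!, with the interpolator as an argument.
function rhs!(du, u, p, itp, t)
    inputs = reshape(itp(t), 3, 10)
    du .= rhs(eachslice(inputs, dims=1), eachslice(u, dims=1), p)
    return nothing
end

# Reverse-mode call without any ODE solver: seeds d_du with ones and
# accumulates the vector-Jacobian product with respect to p into d_p.
function enzyme_vjp(itp, u, p, t)
    du = zeros(size(u))
    d_du = ones(size(u))
    d_u = zero(u)
    d_p = zero(p)
    Enzyme.autodiff(Reverse, rhs!, Const,
        Duplicated(du, d_du), Duplicated(u, d_u), Duplicated(p, d_p),
        Const(itp), Const(t))
    return d_p
end

node_input = rand(3, 10, 20)
itp = LinearInterpolation(reshape(node_input, :, 20), collect(1.0:20.0))
u = rand(1, 10)
p = [2.67, 0.17, -2.09]

# Primal sanity check: the right-hand side itself evaluates without Enzyme.
du = zeros(1, 10)
rhs!(du, u, p, itp, 1.5)
@assert all(isfinite, du)
```

Calling enzyme_vjp(itp, u, p, 1.5) is the step that exercises Enzyme. If it fails, setting Enzyme.Compiler.VERBOSE_ERRORS[] = true beforehand, as the error message suggests, should give the full log to attach to the issue.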