AD using Enzyme in ODE problem with DataInterpolations

Hello, I have an ODE problem that uses DataInterpolations.jl to look up the system's input values at specific time points when computing dx. I want to use Enzyme.jl to compute the gradient of the solution with respect to the parameters. There is a similar post here, but I get a different error, so I'm not sure whether my issue is related. Below is an MWE:

using OrdinaryDiffEq
using Enzyme
using SciMLSensitivity
using DataInterpolations: ConstantInterpolation
using OrdinaryDiffEqLowOrderRK: RK4

function fun(dx, x, p, t, u_func)
    dx .= -p[1] * u_func(t)
    nothing
end

function test_fun(p, prob)
    prob = remake(prob, p=p)
    sol = solve(prob, RK4(), save_everystep=false)
    res = sol.u[2]
    return res[1]
end

function test_AD()
    p = [-1.0]
    dp = [0.0]
    x0 = [2.0]
    
    dt = 0.5
    n_samples = 2000
    t_end = n_samples*dt
    t_data = 0:dt:t_end-dt

    u = repeat([zeros(40); ones(40)],25,1)

    u_func = ConstantInterpolation(u', t_data)
    tspan = (0.0, t_end-dt)

    prob = ODEProblem{true}((dx,x,p,t) -> fun(dx,x,p,t,u_func), x0, tspan, p)
    dprob = Enzyme.make_zero(prob)

    Enzyme.autodiff(Enzyme.Reverse, test_fun, Active, Duplicated(p, dp), DuplicatedNoNeed(prob,dprob))
    @info dp
end

test_AD()

And here is the stack trace:

ERROR: LoadError: ArgumentError: cannot construct a value of type Union{} for return result
Stacktrace:
  [1] (::Core.TypeofBottom)(a::Int64)
    @ Core .\boot.jl:275
  [2] zero(::Type{Union{}})
    @ Base .\number.jl:310
  [3] zero(x::Vector{Union{}})
    @ Base .\abstractarray.jl:1205
  [4] make_zero(::Type{Vector{Union{}}}, seen::IdDict{Any, Any}, prev::Vector{Union{}}, ::Val{false})
    @ Enzyme.Compiler C:\Users\user\.julia\packages\Enzyme\13cYK\src\typeutils\make_zero.jl:42
  [5] make_zero
    @ C:\Users\user\.julia\packages\Enzyme\13cYK\src\typeutils\make_zero.jl:257 [inlined]
  [6] make_zero(::Type{var"#11#12"{…}}, seen::IdDict{Any, Any}, prev::var"#11#12"{ConstantInterpolation{…}}, ::Val{false})
    @ Enzyme.Compiler C:\Users\user\.julia\packages\Enzyme\13cYK\src\typeutils\make_zero.jl:257
  [7] make_zero(::Type{…}, seen::IdDict{…}, prev::ODEFunction{…}, ::Val{…})
    @ Enzyme.Compiler C:\Users\user\.julia\packages\Enzyme\13cYK\src\typeutils\make_zero.jl:257
  [8] make_zero
    @ C:\Users\user\.julia\packages\Enzyme\13cYK\src\typeutils\make_zero.jl:240 [inlined]
  [9] make_zero (repeats 2 times)
    @ C:\Users\user\.julia\packages\EnzymeCore\uEFFs\src\EnzymeCore.jl:587 [inlined]
 [10] test_AD()
    @ Main C:\Users\user\tmp\Enzyme_mwe.jl:34
 [11] top-level scope
    @ C:\Users\user\tmp\Enzyme_mwe.jl:40
in expression starting at C:\Users\user\tmp\Enzyme_mwe.jl:40
Some type information was truncated. Use `show(err)` to see complete types.

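The issue seems to be related to make_zero and u_func: according to the stack trace, Enzyme.make_zero(prob) recurses into the closure that captures the ConstantInterpolation and then fails while trying to zero a Vector{Union{}}. Any help is appreciated :)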
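To narrow down which captured field is the offending `Vector{Union{}}`, a small field walker along these lines might help. This is only a sketch: `find_bottom_arrays` is a hypothetical helper, not part of Enzyme or DataInterpolations, and it relies only on Base reflection (`fieldnames`, `getfield`, `isdefined`). It assumes the problematic array is reachable through struct fields and does not look inside array elements.

```julia
# Hypothetical diagnostic helper (not an Enzyme or DataInterpolations API).
# Recursively walks the fields of `x` and records the path of every array
# whose element type is Union{}, i.e. the kind of value that zero(...) cannot handle.
function find_bottom_arrays(x, path = "x"; depth = 0, maxdepth = 6, found = String[])
    if x isa AbstractArray
        # Arrays are leaves here: record them if eltype is Union{}, never descend.
        eltype(x) === Union{} && push!(found, path)
        return found
    end
    # Skip types and modules so we do not wander into compiler internals.
    (x isa Type || x isa Module) && return found
    if depth < maxdepth && isstructtype(typeof(x))
        for name in fieldnames(typeof(x))
            isdefined(x, name) || continue  # ignore uninitialized fields
            find_bottom_arrays(getfield(x, name), string(path, ".", name);
                               depth = depth + 1, maxdepth = maxdepth, found = found)
        end
    end
    return found
end

# Toy stand-in for an object with an empty, untyped cache field.
toy = (values = [1.0, 2.0], cache = Union{}[])
println(find_bottom_arrays(toy, "toy"))
```

In the MWE this could be called as `find_bottom_arrays(u_func, "u_func")` right after the interpolation is constructed. Which field it reports there is not something I have checked. Judging from frames [1] to [3] of the trace, calling `zero` on such an array is what raises the `ArgumentError`.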

The package versions:

(tmp) pkg> st OrdinaryDiffEq Enzyme SciMLSensitivity DataInterpolations
Status `C:\Users\user\tmp\Project.toml`
  [82cc6244] DataInterpolations v8.10.0
⌃ [7da242da] Enzyme v0.13.152
⌃ [1dea7af3] OrdinaryDiffEq v7.0.0
⌃ [1ed8b502] SciMLSensitivity v7.111.0
  [1344f307] OrdinaryDiffEqLowOrderRK v2.1.0

I’m using the Julia LTS release (1.10.11).