Hello, I have an ODE problem that uses DataInterpolations.jl to obtain input values for the system at specific time points in order to compute `dx`. I want to use Enzyme.jl to compute the gradients with respect to the parameters. There is a similar post here; however, I get a different error, so I'm not sure whether my issue is related. Below is an MWE:
```julia
using OrdinaryDiffEq
using Enzyme
using SciMLSensitivity
using DataInterpolations: ConstantInterpolation
using OrdinaryDiffEqLowOrderRK: RK4

function fun(dx, x, p, t, u_func)
    dx .= -p[1] * u_func(t)
    nothing
end

function test_fun(p, prob)
    prob = remake(prob, p=p)
    sol = solve(prob, RK4(), save_everystep=false)
    res = sol.u[2]
    return res[1]
end

function test_AD()
    p = [-1.0]
    dp = [0.0]
    x0 = [2.0]
    dt = 0.5
    n_samples = 2000
    t_end = n_samples*dt
    t_data = 0:dt:t_end-dt
    u = repeat([zeros(40); ones(40)],25,1)
    u_func = ConstantInterpolation(u', t_data)
    tspan = (0.0, t_end-dt)
    prob = ODEProblem{true}((dx,x,p,t) -> fun(dx,x,p,t,u_func), x0, tspan, p)
    dprob = Enzyme.make_zero(prob)
    Enzyme.autodiff(Enzyme.Reverse, test_fun, Active, Duplicated(p, dp), DuplicatedNoNeed(prob,dprob))
    @info dp
end

test_AD()
```
And here is the stack trace:
```
ERROR: LoadError: ArgumentError: cannot construct a value of type Union{} for return result
Stacktrace:
  [1] (::Core.TypeofBottom)(a::Int64)
    @ Core .\boot.jl:275
  [2] zero(::Type{Union{}})
    @ Base .\number.jl:310
  [3] zero(x::Vector{Union{}})
    @ Base .\abstractarray.jl:1205
  [4] make_zero(::Type{Vector{Union{}}}, seen::IdDict{Any, Any}, prev::Vector{Union{}}, ::Val{false})
    @ Enzyme.Compiler C:\Users\user\.julia\packages\Enzyme\13cYK\src\typeutils\make_zero.jl:42
  [5] make_zero
    @ C:\Users\user\.julia\packages\Enzyme\13cYK\src\typeutils\make_zero.jl:257 [inlined]
  [6] make_zero(::Type{var"#11#12"{…}}, seen::IdDict{Any, Any}, prev::var"#11#12"{ConstantInterpolation{…}}, ::Val{false})
    @ Enzyme.Compiler C:\Users\user\.julia\packages\Enzyme\13cYK\src\typeutils\make_zero.jl:257
  [7] make_zero(::Type{…}, seen::IdDict{…}, prev::ODEFunction{…}, ::Val{…})
    @ Enzyme.Compiler C:\Users\user\.julia\packages\Enzyme\13cYK\src\typeutils\make_zero.jl:257
  [8] make_zero
    @ C:\Users\user\.julia\packages\Enzyme\13cYK\src\typeutils\make_zero.jl:240 [inlined]
  [9] make_zero (repeats 2 times)
    @ C:\Users\user\.julia\packages\EnzymeCore\uEFFs\src\EnzymeCore.jl:587 [inlined]
 [10] test_AD()
    @ Main C:\Users\user\tmp\Enzyme_mwe.jl:34
 [11] top-level scope
    @ C:\Users\user\tmp\Enzyme_mwe.jl:40
in expression starting at C:\Users\user\tmp\Enzyme_mwe.jl:40
Some type information was truncated. Use `show(err)` to see complete types.
```
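The issue seems to be related to `make_zero` and `u_func`: frames [7] and [6] show `make_zero` recursing from the `ODEFunction` into the closure that captures the `ConstantInterpolation`, and frame [4] shows it then reaching a `Vector{Union{}}`, which it cannot zero. Any help is appreciated :)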
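For reference, the innermost frames ([1] to [3]) of the trace can be reproduced without Enzyme or the ODE at all. This is a minimal sketch that assumes, as frame [4] suggests, that something reachable from the captured `ConstantInterpolation` is an empty `Vector{Union{}}`; which field that is has not been checked here, and the vector `v` below is a stand-in for it:

```julia
# Stand-in (hypothetical) for the Vector{Union{}} that make_zero reaches in frame [4]:
# an empty vector whose element type is the bottom type Union{}.
v = Union{}[]

# zero(v) needs zero(Union{}), but no value of type Union{} can exist,
# so Julia throws instead of returning an empty vector (frames [3] -> [2] -> [1]).
err = try
    zero(v)
    nothing
catch e
    e
end

println(typeof(err))  # the ArgumentError from the top of the trace (Julia 1.10)
```

This only illustrates why `make_zero` fails once it reaches such a vector; it does not say why that vector exists inside the interpolation object.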
Package versions:

```
(tmp) pkg> st OrdinaryDiffEq Enzyme SciMLSensitivity DataInterpolations
Status `C:\Users\user\tmp\Project.toml`
  [82cc6244] DataInterpolations v8.10.0
⌃ [7da242da] Enzyme v0.13.152
⌃ [1dea7af3] OrdinaryDiffEq v7.0.0
⌃ [1ed8b502] SciMLSensitivity v7.111.0
  [1344f307] OrdinaryDiffEqLowOrderRK v2.1.0
```
I'm using the Julia LTS release (1.10.11).