Hello, I am trying to solve an ordinary differential equation (ODE) driven by interpolated input data. The problem can be pictured as a snowpack: the snowpack is the state variable, snowfall is the input (source), and snowmelt is the loss (sink). The state equation is:
D(snowpack) ~ snowfall - melt
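For reference, the two flux terms implemented in `tmp_func` below can be written out as follows. Here $S$ is the snowpack, $P$ the precipitation, $T$ the temperature, and $H$ the smoothed step function `step_func`, a differentiable stand-in for a hard on/off switch:

$$
\begin{aligned}
\frac{dS}{dt} &= \text{snowfall} - \text{melt},\\
\text{snowfall} &= H(T_{\min} - T)\,P,\\
\text{melt} &= H(T - T_{\max})\,\min\bigl(S,\; D_f\,(T - T_{\max})\bigr),\\
H(x) &= \tfrac{1}{2}\bigl(\tanh(5x) + 1\bigr).
\end{aligned}
$$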
Snowfall is computed from observed inputs: precipitation and temperature are interpolated in advance so that their values are available at any time, and snowfall is then derived from them. Melt is computed from the snowpack together with the temperature. The ODE function is constructed as follows:
```julia
using Zygote
using OrdinaryDiffEq
using SciMLSensitivity
using MLUtils
using ComponentArrays
using DataInterpolations

step_func(x) = (tanh(5.0 * x) + 1.0) * 0.5

tmp_func = (inputs, states, pas) -> begin
    prcp = inputs[1]
    temp = inputs[2]
    lday = inputs[3]
    snowpack = states[1]
    Tmin = pas.params.Tmin
    Df = pas.params.Df
    Tmax = pas.params.Tmax
    snowfall = step_func(Tmin - temp) * prcp
    melt = step_func(temp - Tmax) * min(snowpack, Df * (temp - Tmax))
    return [snowfall - melt]
end

Df, Tmax, Tmin = 2.67, 0.17, -2.09
ps = ComponentVector(params=(Df=Df, Tmax=Tmax, Tmin=Tmin))
ps_axes = getaxes(ps)
u0 = zeros(eltype(ps), 1)
node_input = rand(3, 20)
timeidx = collect(1:20)
itpfunc = LinearInterpolation(node_input, timeidx)

function ode_func!(du, u, p, t)
    du .= tmp_func(itpfunc(t), u, ComponentVector(p, ps_axes))
end

tspan = (1.0, 20.0)
prob = ODEProblem(ode_func!, u0, tspan, Vector(ps))
sol = solve(prob, Tsit5(), sensealg=BacksolveAdjoint(autojacvec=EnzymeVJP()))

Zygote.gradient(Vector(ps)) do p
    prob = ODEProblem(ode_func!, u0, tspan, p)
    sol = solve(prob, Tsit5(), sensealg=BacksolveAdjoint(autojacvec=EnzymeVJP())) |> Array
    return sum(sol)
end
```
Next, I want to extend this problem to multiple snowpack modules, for example 10. The input then becomes three-dimensional, `node_input = rand(3, 10, 20)`, where the second dimension indexes the snowpack modules, and the initial state `u0` of the ODE problem becomes a 1×10 matrix. The code is as follows:
```julia
tmp_func2 = (inputs, states, pas) -> begin
    prcp = inputs[1]
    temp = inputs[2]
    lday = inputs[3]
    snowpack = states[1]
    Tmin = pas.params.Tmin
    Df = pas.params.Df
    Tmax = pas.params.Tmax
    snowfall = @. step_func(Tmin - temp) * prcp
    melt = @. step_func(temp - Tmax) * min(snowpack, Df * (temp - Tmax))
    return stack([snowfall .- melt], dims=1)
end

Df, Tmax, Tmin = 2.674548848, 0.175739196, -2.092959084
ps = ComponentVector(params=(Df=Df, Tmax=Tmax, Tmin=Tmin))
ps_axes = getaxes(ps)
u0 = rand(1, 10)
node_input = rand(3, 10, 20)
itpfuncs = LinearInterpolation.(eachslice(node_input, dims=1), Ref(1:20)) # raises an error
tmp_input_func = (t) -> ntuple(i -> itpfuncs[i](t), length(itpfuncs))
itpfunc = LinearInterpolation(reshape(node_input, :, 20), 1:20) # reshaping first seems to work

function multi_ode_func!(du, u, p, t)
    tmp_input = reshape(itpfunc(t), 3, 10)
    du .= tmp_func2(eachslice(tmp_input, dims=1), eachslice(u, dims=1), ComponentVector(p, ps_axes))
end

tspan = (1.0, 20.0)
prob = ODEProblem(multi_ode_func!, u0, tspan, Vector(ps))
sol = solve(prob, Tsit5(), sensealg=GaussAdjoint(autojacvec=EnzymeVJP())) |> Array

Zygote.gradient(Vector(ps)) do p
    prob = ODEProblem(multi_ode_func!, u0, tspan, p)
    sol = solve(prob, Tsit5(), sensealg=GaussAdjoint(autojacvec=EnzymeVJP())) |> Array |> sum
end
```
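The flattened interpolation `itpfunc` works because linear interpolation treats every entry independently and Julia's `reshape` is column-major, so interpolating the 30×20 matrix and reshaping the result to 3×10 should give the same values as interpolating the 3×10×20 array directly. A small pure-Julia check of that equivalence, using a hypothetical helper `lerp_last` (not part of DataInterpolations) and assuming integer time knots 1, 2, …, 20 as in the post:

```julia
# Hypothetical helper, not DataInterpolations: linear interpolation along the
# last dimension of `A` at time `t`, assuming knots at 1, 2, ..., size(A, ndims(A)).
function lerp_last(A::AbstractArray, t::Real)
    d = ndims(A)
    n = size(A, d)
    i = clamp(floor(Int, t), 1, n - 1)  # index of the left knot
    w = t - i                           # weight of the right knot
    left = selectdim(A, d, i)
    right = selectdim(A, d, i + 1)
    return (1 - w) .* left .+ w .* right
end

node_input = rand(3, 10, 20)
t = 7.3

# Interpolate the 3×10×20 array directly -> 3×10 matrix.
direct = lerp_last(node_input, t)

# Flatten to 30×20, interpolate -> length-30 vector, reshape back to 3×10.
flat = reshape(lerp_last(reshape(node_input, :, 20), t), 3, 10)

@assert direct ≈ flat
```

So the reshape in `multi_ode_func!` only changes the memory layout seen by the interpolator, not the values that reach `tmp_func2`.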
The `Zygote.gradient` call fails with the following error:
```
ERROR: Enzyme compilation failed due to an internal error.
Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl
To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)
Illegal replace ficticious phi for: %_replacementA677 = phi {} addrspace(10)* , !dbg !765 of %493 = call noalias nonnull "enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Pointer, [-1,8,-1]:Float@double}" {} addrspace(10)* @jl_genericmemory_copy_slice({} addrspace(10)* %787, i64 %492, i64 %403) #57, !dbg !612
Stacktrace:
 [1] copy
   @ .\array.jl:350
 [2] unaliascopy
   @ .\abstractarray.jl:1516
 [3] unalias
   @ .\abstractarray.jl:1500
 [4] broadcast_unalias
   @ .\broadcast.jl:941
 [5] preprocess
   @ .\broadcast.jl:948
 [6] preprocess_args
   @ .\broadcast.jl:950
 [7] preprocess
   @ .\broadcast.jl:947
 [8] override_bc_copyto!
   @ D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler\interpreter.jl:798
 [9] copyto!
   @ .\broadcast.jl:920
 [10] override_bc_materialize
   @ D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler\interpreter.jl:854
 [11] #83
   @ e:\JlCode\HydroModels\dev\base\tmp_test.jl:41
 [12] #83
   @ e:\JlCode\HydroModels\dev\base\tmp_test.jl:0
Stacktrace:
 [1] julia_error(msg::String, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
   @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\errors.jl:384
 [2] julia_error(cstr::Cstring, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
   @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\errors.jl:210
 [3] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
   @ Enzyme.API D:\Julia\packages\packages\Enzyme\g1jMR\src\api.jl:269
 [4] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
   @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:1707
 [5] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
   @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:4655
 [6] codegen
   @ D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:3441 [inlined]
 [7] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
   @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:5515
 [8] _thunk
   @ D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:5515 [inlined]
 [9] cached_compilation
   @ D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:5567 [inlined]
 [10] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
   @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:5678
 [11] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
   @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\compiler.jl:5863
 [12] runtime_generic_augfwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::var"#83#84", df::Nothing, primal_1::RowSlices{…}, shadow_1_1::RowSlices{…}, primal_2::RowSlices{…}, shadow_2_1::RowSlices{…}, primal_3::ComponentVector{…}, shadow_3_1::ComponentVector{…})
   @ Enzyme.Compiler D:\Julia\packages\packages\Enzyme\g1jMR\src\rules\jitrules.jl:445
 [13] multi_ode_func!
   @ e:\JlCode\HydroModels\dev\base\tmp_test.jl:56
Some type information was truncated. Use `show(err)` to see complete types.
Illegal replace ficticious phi for: %_replacementA677 = phi {} addrspace(10)* , !dbg !765 of %493 = call noalias nonnull "enzyme_type"="{[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Pointer, [-1,8,-1]:Float@double}" {} addrspace(10)* @jl_genericmemory_copy_slice({} addrspace(10)* %787, i64 %492, i64 %403) #57, !dbg !612
```
However, switching to `ZygoteVJP` (here with `BacksolveAdjoint`) works:
```julia
Zygote.gradient(Vector(ps)) do p
    prob = ODEProblem(multi_ode_func!, u0, tspan, p)
    sol = solve(prob, Tsit5(), sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP())) |> Array |> sum
end
# output: ([2.363584938462766, 210.78169299208085, 4.2921775834541405e-7],)
```
Although `ZygoteVJP` meets my needs, I have found in my own use that `EnzymeVJP` computes gradients much faster, so I would like to get this working with `EnzymeVJP`. The error looks like an Enzyme compilation issue, and as a Julia user without a computer-science background I find it hard to interpret. In addition, I find the following form for the inputs more convenient:
```julia
itpfuncs = LinearInterpolation.(eachslice(node_input, dims=1), Ref(1:20)) # raises an error
tmp_input_func = (t) -> ntuple(i -> itpfuncs[i](t), length(itpfuncs))
```
However, Enzyme does not seem to support this form well either.
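One thing that may be worth trying, offered as an untested hypothesis rather than a confirmed fix: the first stack trace points at `broadcast_unalias` → `copy` inside the broadcasts of `tmp_func2`, which operate on the row views produced by `eachslice`. Writing the right-hand side as a plain scalar loop that reads the interpolated 3×N inputs and writes straight into `du` avoids those views, the broadcasts and the `stack` allocation. The sketch below uses a hypothetical factory `make_multi_rhs`; it assumes `p` is the plain vector `Vector(ps)`, i.e. `[Df, Tmax, Tmin]`, and that `input_at(t)` returns a 3×N matrix such as `reshape(itpfunc(t), 3, 10)`. It has only been checked against the broadcast formulation, not under Enzyme.

```julia
# Smoothed Heaviside step, identical to the one in the post.
step_func(x) = (tanh(5.0 * x) + 1.0) * 0.5

# Hypothetical factory (not from any package): builds an in-place ODE
# right-hand side for N snowpack modules stored as a 1×N state matrix `u`.
# Assumptions: `input_at(t)` returns a 3×N matrix with rows (prcp, temp, lday),
# and `p` is the plain vector [Df, Tmax, Tmin], the same order as `Vector(ps)`.
function make_multi_rhs(input_at)
    return function (du, u, p, t)
        Df, Tmax, Tmin = p[1], p[2], p[3]
        x = input_at(t)
        # Scalar loop: no eachslice views, no broadcasting, no stack.
        for j in axes(u, 2)
            prcp = x[1, j]
            temp = x[2, j]      # row 3 (lday) is unused, as in tmp_func2
            snowpack = u[1, j]
            snowfall = step_func(Tmin - temp) * prcp
            melt = step_func(temp - Tmax) * min(snowpack, Df * (temp - Tmax))
            du[1, j] = snowfall - melt
        end
        return nothing
    end
end

# Quick check with a constant 3×4 input standing in for the interpolation.
x = rand(3, 4)
rhs! = make_multi_rhs(t -> x)
u = rand(1, 4)
du = similar(u)
rhs!(du, u, [2.67, 0.17, -2.09], 0.0)
```

With the objects from the post, this would be plugged in as `ODEProblem(make_multi_rhs(t -> reshape(itpfunc(t), 3, 10)), u0, tspan, Vector(ps))`. Whether `GaussAdjoint(autojacvec=EnzymeVJP())` then compiles is not verified here.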
I would be very grateful if someone could help me resolve this. My environment is as follows:
```
Julia Version 1.11.2
Commit 5e9a32e7af (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 24 × 12th Gen Intel(R) Core(TM) i9-12900HX
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, alderlake)
Threads: 1 default, 0 interactive, 1 GC (on 24 virtual cores)
Environment:
  JULIA_DEPOT_PATH = D:\Julia\packages
  JULIA_PKG_SERVER = https://mirrors.pku.edu.cn/julia/
  JULIA_EDITOR = code
  JULIA_NUM_THREADS =
```

```
(dev) pkg> st Zygote OrdinaryDiffEq SciMLSensitivity MLUtils ComponentArrays
  [b0b7db55] ComponentArrays v0.15.26
  [82cc6244] DataInterpolations v8.0.0
  [f1d291b0] MLUtils v0.4.8
  [1dea7af3] OrdinaryDiffEq v6.93.0
  [1ed8b502] SciMLSensitivity v7.76.0
  [e88e6eb3] Zygote v0.7.6
```