AD through DataInterpolations, ODE: Gradient calculation, parameter -> time series -> DataInterpolation

Greetings!

My question is about AD through DataInterpolations inside an ODE solve. The function is defined as follows:

using DataInterpolations, OrdinaryDiffEq, SciMLSensitivity, Zygote

time = 0.:0.1:20
function temp(ps)
    x = sin.(ps .* time)
    x_interp = CubicSpline(x, time)
    x_t = t -> x_interp(t)
    function system_augment!(du, u, ps, t)
        du[1, :] .= u[2, :]'  
        du[2, :] .= (-1 .+ 1 .* ps(t))'
    end
    system!(du, u, ps, t) = system_augment!(du, u, ps, t)
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(system!, rand(2, 1), (0, 2), x_t)
    sol = solve(prob, saveat = 0.1)
    out = sum(Array(sol))
    return out  
end
temp(0.1)

Now calculating the gradient:

gradient(temp, 0.1) # Fails, the error is given by

ERROR: `p` is not a SciMLStructure. This is required for adjoint sensitivity analysis. For more information,
see the documentation on SciMLStructures.jl for the definition of the SciMLStructures interface.
In particular, adjoint sensitivities only applies to `Tunable`.


Stacktrace:
  [1] automatic_sensealg_choice(prob::ODEProblem{…}, u0::Matrix{…}, p::@NamedTuple{…}, verbose::Bool, repack::Functors.var"#3#6"{…})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/7UEpc/src/concrete_solve.jl:87
  [2] _concrete_solve_adjoint(::ODEProblem{…}, ::Nothing, ::Nothing, ::Matrix{…}, ::Function, ::SciMLBase.ChainRulesOriginator; verbose::Bool, kwargs::@Kwargs{…})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/7UEpc/src/concrete_solve.jl:274
  [3] _concrete_solve_adjoint
    @ ~/.julia/packages/SciMLSensitivity/7UEpc/src/concrete_solve.jl:245 [inlined]
  [4] #_solve_adjoint#66
    @ ~/.julia/packages/DiffEqBase/PbBEl/src/solve.jl:1584 [inlined]
  [5] _solve_adjoint
    @ ~/.julia/packages/DiffEqBase/PbBEl/src/solve.jl:1557 [inlined]
  [6] #rrule#4
    @ ~/.julia/packages/DiffEqBase/PbBEl/ext/DiffEqBaseChainRulesCoreExt.jl:26 [inlined]
  [7] rrule
    @ ~/.julia/packages/DiffEqBase/PbBEl/ext/DiffEqBaseChainRulesCoreExt.jl:22 [inlined]
  [8] rrule
    @ ~/.julia/packages/ChainRulesCore/U6wNx/src/rules.jl:144 [inlined]
  [9] chain_rrule_kw
    @ ~/.julia/packages/Zygote/TWpme/src/compiler/chainrules.jl:236 [inlined]
 [10] macro expansion
    @ ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0 [inlined]
 [11] _pullback
    @ ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:91 [inlined]
 [12] _apply
    @ ./boot.jl:946 [inlined]
 [13] adjoint
    @ ~/.julia/packages/Zygote/TWpme/src/lib/lib.jl:202 [inlined]
 [14] _pullback
    @ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
 [15] #solve#42
    @ ~/.julia/packages/DiffEqBase/PbBEl/src/solve.jl:1043 [inlined]
 [16] _pullback(::Zygote.Context{…}, ::DiffEqBase.var"##solve#42", ::Nothing, ::Nothing, ::Nothing, ::Val{…}, ::@Kwargs{…}, ::typeof(solve), ::ODEProblem{…})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [17] _apply
    @ ./boot.jl:946 [inlined]
 [18] adjoint
    @ ~/.julia/packages/Zygote/TWpme/src/lib/lib.jl:202 [inlined]
 [19] _pullback
    @ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
 [20] solve
    @ ~/.julia/packages/DiffEqBase/PbBEl/src/solve.jl:1033 [inlined]
 [21] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(solve), ::ODEProblem{…})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [22] temp
    @ ~/Documents/Julia/Custom_Lux_DONN/Neural_Feedback_VDP_Network.jl:447 [inlined]
 [23] _pullback(ctx::Zygote.Context{false}, f::typeof(temp), args::Float64)
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [24] pullback(f::Function, cx::Zygote.Context{false}, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface.jl:90
 [25] pullback
    @ ~/.julia/packages/Zygote/TWpme/src/compiler/interface.jl:88 [inlined]
 [26] gradient(f::Function, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface.jl:147
 [27] top-level scope
    @ ~/Documents/Julia/Custom_Lux_DONN/Neural_Feedback_VDP_Network.jl:453

But this version, which passes the sampled vector x as the ODE parameter and indexes it by time step, works:

time = 0.:0.1:20
function temp(ps)
    x = sin.(ps .* time)
    function system_augment!(du, u, ps, t)
        du[1, :] = u[2, :]'  
        du[2, :] = (-1 .+ 1 .* ps[[trunc(Int64, 1 + t/0.1)]])'
    end
    system!(du, u, ps, t) = system_augment!(du, u, ps, t)
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(system!, rand(2, 1), (0, 1), x)
    sol = solve(prob, saveat = 0.1)
    out = sum(Array(sol))
    return out  
end
temp(0.1)

gradient(temp, 0.1) # (5188.061373048052,)

I thought this was interesting. The DataInterpolations usage seemed natural, while the latter just seems weird. Is there any reason for the DataInterpolations version not to work?

Thanks very much

In your first example the line x_t = t -> x_interp(t) seems redundant; x_t is a trivial wrapper of x_interp.

The problem could be with Zygote support in DataInterpolations.jl, which has never been brought to a solid level, iirc.
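
Separately from Zygote coverage, the error message itself points at `p`: in the first example the ODE parameter is the closure x_t, while in the working example it is the plain vector x, which adjoint sensitivity analysis can presumably treat as tunable. Here is a minimal sketch under that reading, not a confirmed fix: pass the sampled values as `p` and build the spline from `p` inside the right-hand side. The names `tgrid`, `rhs!` and `temp_array_p` are made up for illustration, the state is a plain 2-vector with a fixed `u0` instead of `rand(2, 1)` so that results are repeatable, and rebuilding the spline on every call is slow. Only the forward solve is shown; whether `gradient` then succeeds still depends on how well the AD backend handles the spline construction.

```julia
using DataInterpolations, OrdinaryDiffEq

# Sample times (named tgrid here to avoid shadowing Base.time)
const tgrid = collect(0.0:0.1:20.0)

# p is a plain Vector of samples; the interpolant is rebuilt from it on each call
function rhs!(du, u, p, t)
    x_interp = CubicSpline(p, tgrid)
    du[1] = u[2]
    du[2] = -1 + x_interp(t)
    return nothing
end

function temp_array_p(ps; u0 = [0.0, 1.0])
    x = sin.(ps .* tgrid)                        # parameter -> time series
    prob = ODEProblem(rhs!, u0, (0.0, 2.0), x)   # the array x is the ODE parameter
    sol = solve(prob, Tsit5(); saveat = 0.1)
    return sum(Array(sol))
end

temp_array_p(0.1)
```

If this route is taken, the next thing to try would be `gradient(temp_array_p, 0.1)` with SciMLSensitivity and Zygote loaded; I have not verified that it passes.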

Ohhh, thanks very much for the clarification. I think I read somewhere that the focus is on developing Enzyme for most purposes. I will see what Enzyme does.
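
Whichever AD backend is tried, a finite-difference reference is a cheap way to check the number it returns. The sketch below is only that, a sanity check and not a replacement for AD. It assumes a deterministic variant of `temp` (the original uses `rand(2, 1)` as the initial state, so two calls with the same `ps` differ) and tight solver tolerances; `temp_fixed`, `central_diff`, `tsamples` and the step `h` are hypothetical names and values chosen for illustration.

```julia
using DataInterpolations, OrdinaryDiffEq

# Deterministic variant of `temp`: fixed u0 and tight tolerances, so that the
# difference between nearby parameter values is not swamped by solver error.
function temp_fixed(ps; u0 = [0.0, 1.0], tsamples = collect(0.0:0.1:20.0))
    x_interp = CubicSpline(sin.(ps .* tsamples), tsamples)
    function rhs!(du, u, p, t)
        du[1] = u[2]
        du[2] = -1 + p(t)      # p is the interpolant, as in the first example
        return nothing
    end
    prob = ODEProblem(rhs!, u0, (0.0, 2.0), x_interp)
    sol = solve(prob, Tsit5(); saveat = 0.1, abstol = 1e-10, reltol = 1e-10)
    return sum(Array(sol))
end

# Central difference: (f(p + h) - f(p - h)) / (2h)
central_diff(f, p; h = 1e-5) = (f(p + h) - f(p - h)) / (2h)

fd_grad = central_diff(temp_fixed, 0.1)
```

Compare `fd_grad` with what the AD call returns for the same function and the same `u0`; a large mismatch suggests the AD result should not be trusted yet.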

Ohh, yes, thank you for this.