Greetings!
My question is about automatic differentiation (AD) through a DataInterpolations.jl interpolant that is used inside an ODE right-hand side. The function is defined as follows:
using DataInterpolations, OrdinaryDiffEq, SciMLSensitivity, Zygote
# Grid from 0 to 20 in steps of 0.1 on which the signal is tabulated.
time = 0.:0.1:20
# Tabulate sin.(ps .* time), wrap it in a cubic-spline interpolant, integrate a
# two-row ODE forced by that interpolant, and return the sum of all saved values.
# Note: the inner RHS reuses the name `ps` for the ODE parameter slot, which in this
# version holds a function of t rather than the outer scalar `ps`.
function temp(ps)
# Signal samples, one per entry of the global `time`.
x = sin.(ps .* time)
# Interpolant built from the samples `x` and the grid `time`.
x_interp = CubicSpline(x, time)
# Closure around the interpolant; this closure is what is passed as the ODE parameter.
x_t = t -> x_interp(t)
# In-place RHS: row 1 of du gets row 2 of u; row 2 of du gets -1 + ps(t),
# where `ps` is called as a function of the current time.
function system_augment!(du, u, ps, t)
du[1, :] .= u[2, :]'
du[2, :] .= (-1 .+ 1 .* ps(t))'
end
# Forwarding wrapper that is handed to ODEProblem.
system!(du, u, ps, t) = system_augment!(du, u, ps, t)
# Random 2x1 initial state, time span (0, 2). The parameter argument is the closure
# `x_t`, not an array.
prob = ODEProblem{true, SciMLBase.FullSpecialize}(system!, rand(2, 1), (0, 2), x_t)
# No solver algorithm is specified; the solution is saved every 0.1 time units.
sol = solve(prob, saveat = 0.1, )
# Scalar output: sum over every saved state entry.
out = sum(Array(sol))
return out
end
# Plain forward evaluation (no differentiation involved).
temp(0.1)
Now calculating the gradient,
gradient(temp, 0.1) # Fails; the error and stack trace are shown below
ERROR: `p` is not a SciMLStructure. This is required for adjoint sensitivity analysis. For more information,
see the documentation on SciMLStructures.jl for the definition of the SciMLStructures interface.
In particular, adjoint sensitivities only applies to `Tunable`.
Stacktrace:
[1] automatic_sensealg_choice(prob::ODEProblem{…}, u0::Matrix{…}, p::@NamedTuple{…}, verbose::Bool, repack::Functors.var"#3#6"{…})
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/7UEpc/src/concrete_solve.jl:87
[2] _concrete_solve_adjoint(::ODEProblem{…}, ::Nothing, ::Nothing, ::Matrix{…}, ::Function, ::SciMLBase.ChainRulesOriginator; verbose::Bool, kwargs::@Kwargs{…})
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/7UEpc/src/concrete_solve.jl:274
[3] _concrete_solve_adjoint
@ ~/.julia/packages/SciMLSensitivity/7UEpc/src/concrete_solve.jl:245 [inlined]
[4] #_solve_adjoint#66
@ ~/.julia/packages/DiffEqBase/PbBEl/src/solve.jl:1584 [inlined]
[5] _solve_adjoint
@ ~/.julia/packages/DiffEqBase/PbBEl/src/solve.jl:1557 [inlined]
[6] #rrule#4
@ ~/.julia/packages/DiffEqBase/PbBEl/ext/DiffEqBaseChainRulesCoreExt.jl:26 [inlined]
[7] rrule
@ ~/.julia/packages/DiffEqBase/PbBEl/ext/DiffEqBaseChainRulesCoreExt.jl:22 [inlined]
[8] rrule
@ ~/.julia/packages/ChainRulesCore/U6wNx/src/rules.jl:144 [inlined]
[9] chain_rrule_kw
@ ~/.julia/packages/Zygote/TWpme/src/compiler/chainrules.jl:236 [inlined]
[10] macro expansion
@ ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0 [inlined]
[11] _pullback
@ ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:91 [inlined]
[12] _apply
@ ./boot.jl:946 [inlined]
[13] adjoint
@ ~/.julia/packages/Zygote/TWpme/src/lib/lib.jl:202 [inlined]
[14] _pullback
@ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
[15] #solve#42
@ ~/.julia/packages/DiffEqBase/PbBEl/src/solve.jl:1043 [inlined]
[16] _pullback(::Zygote.Context{…}, ::DiffEqBase.var"##solve#42", ::Nothing, ::Nothing, ::Nothing, ::Val{…}, ::@Kwargs{…}, ::typeof(solve), ::ODEProblem{…})
@ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
[17] _apply
@ ./boot.jl:946 [inlined]
[18] adjoint
@ ~/.julia/packages/Zygote/TWpme/src/lib/lib.jl:202 [inlined]
[19] _pullback
@ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
[20] solve
@ ~/.julia/packages/DiffEqBase/PbBEl/src/solve.jl:1033 [inlined]
[21] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(solve), ::ODEProblem{…})
@ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
[22] temp
@ ~/Documents/Julia/Custom_Lux_DONN/Neural_Feedback_VDP_Network.jl:447 [inlined]
[23] _pullback(ctx::Zygote.Context{false}, f::typeof(temp), args::Float64)
@ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
[24] pullback(f::Function, cx::Zygote.Context{false}, args::Float64)
@ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface.jl:90
[25] pullback
@ ~/.julia/packages/Zygote/TWpme/src/compiler/interface.jl:88 [inlined]
[26] gradient(f::Function, args::Float64)
@ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface.jl:147
[27] top-level scope
@ ~/Documents/Julia/Custom_Lux_DONN/Neural_Feedback_VDP_Network.jl:453
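However, the following version, which passes the sampled array as the parameter and indexes into it directly, works: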
# Grid from 0 to 20 in steps of 0.1 on which the signal is tabulated.
time = 0.:0.1:20
# Tabulate sin.(ps .* time), integrate a two-row ODE whose forcing is read from that
# table by index, and return the sum of all saved values.
# Note: the inner RHS reuses the name `ps` for the ODE parameter slot, which in this
# version holds the sampled array `x`.
function temp(ps)
# Signal samples, one per entry of the global `time`.
x = sin.(ps .* time)
# In-place RHS: row 1 of du gets row 2 of u; row 2 of du gets -1 + (one sample of ps).
function system_augment!(du, u, ps, t)
du[1, :] = u[2, :]'
# The index is 1 + t/0.1 truncated to an integer, so a single stored sample is
# selected per time with no interpolation between samples. The hard-coded 0.1
# equals the step of the global `time` grid.
du[2, :] = (-1 .+ 1 .* ps[[trunc(Int64, 1 + t/0.1)]])'
end
# Forwarding wrapper that is handed to ODEProblem.
system!(du, u, ps, t) = system_augment!(du, u, ps, t)
# Random 2x1 initial state, time span (0, 1). The parameter argument is the array `x`.
prob = ODEProblem{true, SciMLBase.FullSpecialize}(system!, rand(2, 1), (0,1), x)
# No solver algorithm is specified; the solution is saved every 0.1 time units.
sol = solve(prob, saveat = 0.1, )
# Scalar output: sum over every saved state entry.
out = sum(Array(sol))
return out
end
# Plain forward evaluation (no differentiation involved).
temp(0.1)
gradient(temp, 0.1) # (5188.061373048052,)
I thought this was interesting. The DataInterpolations usage seemed natural, whereas the latter version just seems weird. Is there any reason for the DataInterpolations version not to work?
Thanks very much