I am trying to optimize an ODE problem specified with ModelingToolkit. The objective function indexes into the ODESolution by a symbol. There is a Zygote adjoint for getindex(::AbstractVectorOfArray, i), and AbstractVectorOfArray is a supertype of ODESolution. That adjoint only handles integer indexing, and I get errors when this pullback from RecursiveArrayTools is called in the parameter optimization with a Num index. (That's when I started learning about AD, so I am not familiar with all the background.)
I read about the ChainRules package and think that I should define a more specialized reverse rule for ODESolution or AbstractTimeseriesSolution that can handle this case.
Now, in my trials (excerpt below), this rule is never called (Julia 1.6 on Ubuntu), although I restarted the REPL before each try.
How are the rules selected? What do I need to read and do?
... # setup of the problem, gradient is never called here
@variables t,x(t)
soltrue = solve(prob, Tsit5(), saveat = tsteps);
using ChainRulesCore
function ChainRulesCore.rrule(::typeof(getindex), VA::ODESolution, sym)
    function ODESolution_getindex_pullback(Δ)
        # debug output
        @show Δ
        @show length(VA)
        @show VA
        @show VA.u
        # convert symbol to index
        i = issymbollike(sym) ? sym_to_index(sym, VA) : sym
        @show i
        # similar to VectorOfArray: return zero for non-matching indices
        Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
        # one tangent per primal argument: getindex itself, VA, and sym
        (NO_FIELDS, Δ′, DoesNotExist())
    end
    # return the pullback function itself, not the result of calling it
    VA[sym], ODESolution_getindex_pullback
end
popt0 = [1.1]
f1(p) = soltrue[x][1] * p[1] # simple case; in reality p is used to generate a modified solution
f1(popt0)
#using Zygote
gr = Zygote.gradient(f1, popt0)
I still get the error from the pullback of the rule defined in RecursiveArrayTools:
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i)
    function AbstractVectorOfArray_getindex_adjoint(Δ)
        @info "AbstractVectorOfArray_getindex_adjoint" # debug info inserted by me
        @show typeof(VA)
        Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
        (Δ′,nothing)
    end
    VA[i],AbstractVectorOfArray_getindex_adjoint
end
[ Info: AbstractVectorOfArray_getindex_adjoint
typeof(VA) = ODESolution{Float64, 2, ...
...
ERROR: TypeError: non-boolean (Num) used in boolean context
Stacktrace:
[1] (::RecursiveArrayTools.var"#80#82"{Vector{Float64}, Num})(::Tuple{Vector{Float64}, Int64})
@ RecursiveArrayTools ./none:0
...
[6] Pullback
@ ./REPL[21]:1 [inlined]
[7] (::typeof(∂(f1)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface2.jl:0
[8] (::Zygote.var"#41#42"{typeof(∂(f1))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:41
[9] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/RxTZu/src/compiler/interface.jl:59
[10] top-level scope
@ REPL[24]:1
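To make the failure mode concrete, here is a minimal sketch in plain Julia, without any of the packages involved. It reuses the comprehension from the adjoint on a plain vector of vectors. The helper name onehot_cotangent and the types FakeSym and FakeEq are hypothetical stand-ins that I made up for illustration; FakeSym only mimics the one property of Num that matters here, namely that == returns a symbolic expression rather than a Bool.

```julia
# Stand-in for the saved states of a solution: a vector of state vectors.
u = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
Δ = [0.1, 0.2]  # incoming cotangent for one saved state

# Same comprehension as in the adjoint, wrapped in a helper (hypothetical name).
onehot_cotangent(u, i, Δ) = [(i == j ? Δ : zero(x)) for (x, j) in zip(u, 1:length(u))]

# Integer index: `i == j` is a Bool, so the ternary works.
Δ′ = onehot_cotangent(u, 2, Δ)

# Hypothetical symbolic-like index: `==` builds an expression object, not a Bool.
struct FakeSym
    name::Symbol
end
struct FakeEq
    lhs::FakeSym
    rhs::Int
end
Base.:(==)(a::FakeSym, b::Integer) = FakeEq(a, b)

# Symbolic-like index: the ternary receives a FakeEq instead of a Bool and throws.
err = try
    onehot_cotangent(u, FakeSym(:x), Δ)
    nothing
catch e
    e
end
println(typeof(err))
```

With the integer index, only the matching slot carries Δ and the others are zero. With the symbolic-like index, the condition of the ternary is not a Bool, which is the same kind of TypeError as in the stack trace above. This is why I think the symbol has to be converted to an integer index before the comparison.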