I’m trying to model graph neural ODEs (Graph NODEs) by combining GraphNeuralNetworks.jl and OrdinaryDiffEq.jl. I want to learn both the neural network parameters and the weights of the edges, so I have to manually modify the Flux parameters during prediction. When I run the following MWE:
using Graphs, GraphNeuralNetworks, Flux, OrdinaryDiffEq, ComponentArrays, Zygote, SciMLSensitivity
# Observation times (an integer range); also used as the ODE time span and `saveat`.
time = 1:10

# Initial state: 3 nodes x 3 features, stored flat (reshaped to 3x3 inside the ODE).
x0 = rand(9)

# Synthetic observations: one flattened 9-element state per time point.
obs = rand(9, 10)

# Complete directed graph on 3 nodes.
fullGraph = GNNGraph(complete_digraph(3))

# Two graph-convolution layers that take edge weights into account.
layer1 = GCNConv(3 => 10, tanh; use_edge_weight=true)
layer2 = GCNConv(10 => 3; use_edge_weight=true)
chain = GNNChain(layer1, layer2)

# Trainable parameters: one random weight per edge plus both layer weight matrices.
# The container's element type is Float32, so every value ends up stored as Float32
# even though the layer weights are passed through `f64` first.
pinit = ComponentArray{Float32}(;
    weights=rand(ne(fullGraph)),
    layer1=f64(layer1.weight),
    layer2=f64(layer2.weight),
)
# Integrate the graph neural ODE for the parameter container `p` and return the
# saved trajectory as a plain array.
#
# `p` is read through its fields `weights` (edge weights), `layer1` and `layer2`
# (weight matrices of the two layers of the global `chain`).
function predict(p)
# Rebuild the complete 3-node digraph and attach the edge weights from `p`.
# This assignment creates a local `fullGraph`; the global of the same name is not modified.
fullGraph = GNNGraph(complete_digraph(3))
fullGraph = set_edge_weight(fullGraph,p.weights)
# Overwrite the weight matrices of the global `chain` in place with the values in `p`.
# NOTE(review): Zygote generally does not support differentiating through in-place
# array mutation -- confirm whether these `.=` writes can stay once the ODE adjoint runs.
chain.layers[1].weight .= p.layer1
chain.layers[2].weight .= p.layer2
# In-place right-hand side of the ODE. Its own `p` argument shadows the outer `p`
# and is not used in the body; the network and graph are captured from the enclosing scope.
function nn!(du,u,p,t)
# View the flat 9-element state as a 3x3 matrix for the GNN.
uGraph = reshape(u,(3,3))
# Apply the GNN and flatten the result back to 9 elements.
dGraph = reshape(chain(fullGraph,uGraph),(3*3))
du .= dGraph
end
# No parameter object is passed to the problem (the stack trace shows NullParameters).
# NOTE(review): the adjoint presumably differentiates only with respect to parameters
# stored in the ODEProblem -- consider passing `p` here and using it inside `nn!`; confirm.
# NOTE(review): the time span and `saveat` come from the integer range `time`, and the
# BoundsError is raised while indexing a UnitRange{Int64} in the reverse pass -- it may
# be worth trying floating-point times; to be confirmed.
prob = ODEProblem(nn!,x0,(time[1],time[end]),saveat=time)
# No solver algorithm is specified, so the default choice is used.
sol = solve(prob)
return Array(sol)
end
# Sum of squared differences between the trajectory predicted for `p` and the
# global observation matrix `obs`.
function loss_function(p)
    residual = predict(p) .- obs
    return sum(abs2, residual)
end
# Reverse-mode gradient of the loss with respect to `pinit`; this is the call that
# raises the BoundsError shown below.
Zygote.gradient(loss_function,pinit)
I get the following error:
ERROR: BoundsError: attempt to access 10-element UnitRange{Int64} at index [0]
Stacktrace:
[1] throw_boundserror(A::UnitRange{Int64}, I::Int64)
@ Base .\abstractarray.jl:737
[2] getindex
@ .\range.jl:930 [inlined]
[3] (::SciMLSensitivity.ReverseLossCallback{…})(integrator::OrdinaryDiffEq.ODEIntegrator{…})
@ SciMLSensitivity C:\Users\JArroyo-Esquivel\.julia\packages\SciMLSensitivity\waEMv\src\adjoint_common.jl:530
[4] #111
@ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqCallbacks\9fKPq\src\preset_time.jl:58 [inlined]
[5] apply_discrete_callback!
@ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\callbacks.jl:613 [inlined]
[6] apply_discrete_callback!
@ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\callbacks.jl:628 [inlined]
[7] handle_callbacks!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
@ OrdinaryDiffEq C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\integrators\integrator_utils.jl:349
[8] _loopfooter!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
@ OrdinaryDiffEq C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\integrators\integrator_utils.jl:254
[9] loopfooter!
@ C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\integrators\integrator_utils.jl:207 [inlined]
[10] solve!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
@ OrdinaryDiffEq C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\solve.jl:558
[11] #__solve#670
@ C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\solve.jl:7 [inlined]
[12] __solve
@ C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\solve.jl:1 [inlined]
[13] solve_call(_prob::ODEProblem{…}, args::CompositeAlgorithm{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
@ DiffEqBase C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:612
[14] solve_call
@ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:569 [inlined]
[15] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::SciMLBase.NullParameters, args::CompositeAlgorithm{…}; kwargs::@Kwargs{…})
@ DiffEqBase C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:1080
[16] solve_up
@ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:1066 [inlined]
[17] #solve#51
@ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:1003 [inlined]
[18] solve
@ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:993 [inlined]
[19] #__solve#675
@ C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\solve.jl:547 [inlined]
[20] __solve
@ C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\solve.jl:546 [inlined]
[21] solve_call(_prob::ODEProblem{…}, args::Nothing; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
@ DiffEqBase C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:612
[22] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::SciMLBase.NullParameters, args::Nothing; kwargs::@Kwargs{…})
@ DiffEqBase C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:1072
[23] solve_up
@ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:1066 [inlined]
[24] solve(prob::ODEProblem{…}, args::Nothing; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
@ DiffEqBase C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:1003
[25] _adjoint_sensitivities(sol::ODESolution{…}, sensealg::QuadratureAdjoint{…}, alg::Nothing; t::UnitRange{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, callback::Nothing, kwargs::@Kwargs{…})
@ SciMLSensitivity C:\Users\JArroyo-Esquivel\.julia\packages\SciMLSensitivity\waEMv\src\quadrature_adjoint.jl:340
[26] _adjoint_sensitivities
@ C:\Users\JArroyo-Esquivel\.julia\packages\SciMLSensitivity\waEMv\src\quadrature_adjoint.jl:328 [inlined]
[27] #adjoint_sensitivities#63
@ C:\Users\JArroyo-Esquivel\.julia\packages\SciMLSensitivity\waEMv\src\sensitivity_interface.jl:386 [inlined]
[28] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#314"{…})(Δ::ODESolution{…})
@ SciMLSensitivity C:\Users\JArroyo-Esquivel\.julia\packages\SciMLSensitivity\waEMv\src\concrete_solve.jl:582
[29] ZBack
@ C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\chainrules.jl:211 [inlined]
[30] (::Zygote.var"#291#292"{…})(Δ::ODESolution{…})
@ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\lib\lib.jl:206
[31] (::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…})
@ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
[32] #solve#51
@ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:1003 [inlined]
[33] (::Zygote.Pullback{…})(Δ::ODESolution{…})
@ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
[34] #291
@ C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\lib\lib.jl:206 [inlined]
[35] #2169#back
@ C:\Users\JArroyo-Esquivel\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
[36] solve
@ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:993 [inlined]
[37] (::Zygote.Pullback{…})(Δ::ODESolution{…})
@ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
[38] predict
@ .\REPL[11]:13 [inlined]
[39] (::Zygote.Pullback{Tuple{typeof(predict), ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}}}, Any})(Δ::Matrix{Float64})
@ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
[40] loss_function
@ .\REPL[12]:2 [inlined]
[41] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
@ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\interface.jl:91
[42] gradient(f::Function, args::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}})
@ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\interface.jl:148
[43] top-level scope
@ REPL[17]:1
Running `loss_function(pinit)` and `Zygote.pullback(loss_function,pinit)` both return without an issue, so something goes wrong only when the gradient itself is calculated.