Using a variable graph structure with NeuralODE and GCNConv

Hello everyone,

I have recently been working with Graph Neural Ordinary Differential Equations (GDEs) for protein structure prediction. This works quite well as long as the graph structure is fixed and given to the GCNConv layer using WithGraph, as shown in this simple example:

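# Assumed imports for the snippets in this post
using Flux, GeometricFlux, DiffEqFlux, OrdinaryDiffEq

mat = [1. 1. 0. 0.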
       0. 1. 0. 1.
       1. 0. 1. 0.]

adj_mat = [0 1 1 0
           1 0 0 0
           1 0 0 1
           0 0 1 0]

fg = FeaturedGraph(adj_mat)
layers = DiffEqFlux.Chain(
            WithGraph(fg, GCNConv(3 => 9, relu)),
            WithGraph(fg, GCNConv(9 => 3))
         )
prob = NeuralODE(layers, (0.0, 10.0), Tsit5())
sol = prob(mat)

I would, however, like to create a GDE where the graph structure is variable and is part of the network input instead. GCNConv layers support this: simply pass them a FeaturedGraph that contains the adjacency matrix and the feature matrix. But I can’t figure out how to combine this with a NeuralODE.
Directly passing the NeuralODE a FeaturedGraph does not work:

fg = FeaturedGraph(adj_mat, nf=mat)
layers = DiffEqFlux.Chain(
            GCNConv(3 => 9, relu),
            GCNConv(9 => 3)
         )
prob = NeuralODE(layers, (0.0, 10.0), Tsit5())
sol = prob(fg)

> ERROR: LoadError: MethodError: no method matching oneunit(::Type{Any})
Closest candidates are:
  oneunit(::Type{Union{Missing, T}}) where T at missing.jl:105
  oneunit(::Type{T}) where T at number.jl:319
  oneunit(::T) where T at number.jl:318
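
One way to read this error: the solver appears to expect a plain numeric array as its state, and a FeaturedGraph does not provide a numeric element type. As a rough sketch of the variable-graph idea without any GeometricFlux or DiffEqFlux calls, the features can be kept as the ODE state while the adjacency matrix is passed alongside them for each sample. Everything below (normalized_adjacency, gcn_rhs, euler_gde, the tanh dynamics, and the fixed-step Euler loop) is a hypothetical stand-in written for illustration, not the API of either package:

```julia
using LinearAlgebra

# Symmetrically normalized adjacency with self-loops: D^(-1/2) (A + I) D^(-1/2)
function normalized_adjacency(A::AbstractMatrix)
    A_loops = Float32.(A) + I                 # add self-loops
    d = vec(sum(A_loops, dims=2))             # node degrees (always >= 1)
    D_inv_sqrt = Diagonal(1 ./ sqrt.(d))
    return D_inv_sqrt * A_loops * D_inv_sqrt
end

# Hypothetical GCN-style right-hand side: features are (channels x nodes)
gcn_rhs(X, W, A_norm) = tanh.(W * X * A_norm)

# Fixed-step explicit Euler integration; the graph is an argument, not part of the state
function euler_gde(X0, W, A; tspan=(0f0, 1f0), steps=100)
    A_norm = normalized_adjacency(A)
    h = (tspan[2] - tspan[1]) / steps
    X = copy(X0)
    for _ in 1:steps
        X = X .+ h .* gcn_rhs(X, W, A_norm)
    end
    return X
end

# Same node features and weights, two different graph structures
X0 = Float32[1 1 0 0; 0 1 0 1; 1 0 1 0]
A1 = [0 1 1 0; 1 0 0 0; 1 0 0 1; 0 0 1 0]
A2 = [0 1 0 0; 1 0 1 0; 0 1 0 1; 0 0 1 0]
W  = Float32[0.1 0 0; 0 0.1 0; 0 0 0.1]     # illustrative fixed weights

X1 = euler_gde(X0, W, A1)
X2 = euler_gde(X0, W, A2)
```

Here the same weights W are reused with two different graphs. A real model would replace the Euler loop with a proper solver and learn W.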

I have also attempted to sidestep this issue by applying the graph convolution prior to the NeuralODE, like so:

n_ode = NeuralODE(Dense(9, 9, relu), (0.0, 10.0), Tsit5())
prob = DiffEqFlux.Chain(
            GCNConv(3 => 9, relu),
            x -> x.nf,
            n_ode,
            Array,
            Dense(9, 3)
         )
sol = prob(fg)

This works for prediction. However, I cannot train this network using sciml_train:

# Dummy loss function
loss_n_ode = function(p::Vector{Float32})
   return sum(abs2, sol)
end

callback = function(p::Vector{Float32}, l::Float64)
   display(l)
   return false
end

DiffEqFlux.sciml_train(loss_n_ode, initial_params(prob), cb = callback, maxiters = 300)

> LoadError: MethodError: Cannot `convert` an object of type Nothing to an object of type Float32
Closest candidates are:
  convert(::Type{T}, !Matched::Union{InitialValues.SpecificInitialValue{typeof(*)}, InitialValues.SpecificInitialValue{typeof(Base.mul_prod)}}) where T<:Union{AbstractString, Number} at C:\Users\Marc\.julia\packages\InitialValues\OWP8V\src\InitialValues.jl:258
  convert(::Type{T}, !Matched::VectorizationBase.AbstractSIMD) where T<:Union{Bool, Float16, Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8, SIMDTypes.Bit} at C:\Users\Marc\.julia\packages\VectorizationBase\89gBr\src\base_defs.jl:153
  convert(::Type{T}, !Matched::DualNumbers.Dual) where T<:Union{Real, Complex} at C:\Users\Marc\.julia\packages\DualNumbers\5knFX\src\dual.jl:24
  ...
in expression starting at C:\Users\Marc\Proteinet\src\GDE.jl:48
fill!(dest::Vector{Float32}, x::Nothing) at array.jl:333
copyto! at broadcast.jl:944 [inlined]
materialize! at broadcast.jl:894 [inlined]
materialize!(dest::Vector{Float32}, bc::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(identity), Tuple{Base.RefValue{Nothing}}}) at broadcast.jl:891
(::GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{var"#3457#3458"}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}})(::Vector{Float32}, ::Vector{Float32}) at zygote.jl:8
macro expansion at flux.jl:27 [inlined]
macro expansion at utils.jl:35 [inlined]
__solve(prob::OptimizationProblem{false, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{var"#3457#3458"}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{var"#3457#3458"}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{var"#3457#3458"}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}, Vector{Float32}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Base.Iterators.Pairs{Symbol, var"#3459#3460", Tuple{Symbol}, NamedTuple{(:cb,), Tuple{var"#3459#3460"}}}}, opt::ADAM, data::Base.Iterators.Cycle{Tuple{GalacticOptim.NullData}}; maxiters::Int64, cb::Function, progress::Bool, save_best::Bool, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}) at flux.jl:25
__solve at flux.jl:5 [inlined]
__solve at flux.jl:5 [inlined]
#solve#476 at solve.jl:3 [inlined]
(::CommonSolve.var"#solve##kw")(::NamedTuple{(:maxiters, :cb), Tuple{Int64, var"#3459#3460"}}, ::typeof(solve), ::OptimizationProblem{false, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{var"#3457#3458"}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{var"#3457#3458"}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{var"#3457#3458"}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}, Vector{Float32}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Base.Iterators.Pairs{Symbol, var"#3459#3460", Tuple{Symbol}, NamedTuple{(:cb,), Tuple{var"#3459#3460"}}}}, ::ADAM) at solve.jl:3
sciml_train(::var"#3457#3458", ::Vector{Float32}, ::Nothing, ::Nothing; lower_bounds::Nothing, upper_bounds::Nothing, maxiters::Int64, kwargs::Base.Iterators.Pairs{Symbol, var"#3459#3460", Tuple{Symbol}, NamedTuple{(:cb,), Tuple{var"#3459#3460"}}}) at train.jl:106
(::DiffEqFlux.var"#sciml_train##kw")(::NamedTuple{(:cb, :maxiters), Tuple{var"#3459#3460", Int64}}, ::typeof(DiffEqFlux.sciml_train), ::Function, ::Vector{Float32}, ::Nothing, ::Nothing) at train.jl:57
(::DiffEqFlux.var"#sciml_train##kw")(::NamedTuple{(:cb, :maxiters), Tuple{var"#3459#3460", Int64}}, ::typeof(DiffEqFlux.sciml_train), loss::Function, θ::Vector{Float32}) at train.jl:57
top-level scope at GDE.jl:48
eval at boot.jl:360 [inlined]
include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) at loading.jl:1116
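
A possible reading of this trace: the frame `fill!(dest::Vector{Float32}, x::Nothing)` suggests that the gradient handed to the optimizer was `nothing`. Zygote returns `nothing` when the loss does not depend on its argument, and the dummy loss above sums over a `sol` that was computed once outside the function, so `p` never enters the computation. The sketch below reproduces that behaviour with Zygote alone; the names x, y, loss_const, and loss_dep are made up for illustration:

```julia
using Zygote

x = Float32[1, 2, 3]
p = Float32[0.5, 0.5, 0.5]

# Prediction computed once, outside the loss (analogous to the global `sol`)
y = p .* x

# This loss ignores its argument, so there is nothing to differentiate
loss_const(p) = sum(abs2, y)

# This loss recomputes the prediction from `p` on every call
loss_dep(p) = sum(abs2, p .* x)

g_const = Zygote.gradient(loss_const, p)[1]   # no dependence on p
g_dep   = Zygote.gradient(loss_dep, p)[1]     # a Vector{Float32}
```

If that is what happens here, recomputing the prediction from p inside the loss (for example through the reconstructor returned by Flux.destructure) may give a usable gradient. Whether that differentiates cleanly through the GCNConv and NeuralODE layers is a separate question.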

I’ve tried looking at the GDE examples that come with DiffEqFlux, but they all seem to use a fixed graph structure, so I am unsure how to proceed.

Hi! I am glad to see feedback from you.
Several graph convolution layers can already be used with the variable graph strategy.
Unfortunately, the variable graph strategy is currently not supported for GNODE.
If you’re interested in this approach, you’re welcome to file an issue on GeometricFlux.jl.

Thank you very much for the response. It’s unfortunate that this is not yet implemented.