Hello everyone,
I have recently been working with graph neural ordinary differential equations (GDEs) for protein structure prediction. This works well as long as the graph structure is fixed and passed to each `GCNConv` layer via `WithGraph`, as in this simple example:
# Node features: 3×4, one row per input feature (matching `GCNConv(3 => ...)`)
# and one column per node of the 4-node graph below.
mat = [1.0 1.0 0.0 0.0; 0.0 1.0 0.0 1.0; 1.0 0.0 1.0 0.0]

# Symmetric 4×4 adjacency matrix (edges 1–2, 1–3, 3–4) of the fixed graph.
adj_mat = [0 1 1 0; 1 0 0 0; 1 0 0 1; 0 0 1 0]
fg = FeaturedGraph(adj_mat)

# Bind the same graph to every layer, so the chain maps a plain feature
# matrix to a plain feature matrix that NeuralODE can integrate.
layers = DiffEqFlux.Chain(
    WithGraph(fg, GCNConv(3 => 9, relu)),
    WithGraph(fg, GCNConv(9 => 3)),
)

# Evolve the node features over t ∈ [0, 10] with the Tsit5 solver.
prob = NeuralODE(layers, (0.0, 10.0), Tsit5())
sol = prob(mat)
However, I would like to build a GDE in which the graph structure varies and is instead part of the network's input. `GCNConv` layers support this directly: pass them a `FeaturedGraph` that contains both the adjacency matrix and the node-feature matrix. What I can't figure out is how to combine this with a `NeuralODE`.
Passing a `FeaturedGraph` directly to the `NeuralODE` does not work:
# Attempt: carry the graph inside the input instead of binding it per layer.
fg = FeaturedGraph(adj_mat, nf=mat)
layers = DiffEqFlux.Chain(GCNConv(3 => 9, relu), GCNConv(9 => 3))
prob = NeuralODE(layers, (0.0, 10.0), Tsit5())

# NOTE: this call fails with the error shown below. NeuralODE treats its input
# as the ODE state, and a FeaturedGraph is not a numeric array — presumably why
# the solver ends up calling `oneunit(::Type{Any})`.
sol = prob(fg)
> ERROR: LoadError: MethodError: no method matching oneunit(::Type{Any})
Closest candidates are:
oneunit(::Type{Union{Missing, T}}) where T at missing.jl:105
oneunit(::Type{T}) where T at number.jl:319
oneunit(::T) where T at number.jl:318
I have also tried to sidestep this issue by applying the graph convolution before the `NeuralODE`, like so:
# Workaround: apply the graph convolution once, outside the ODE, then evolve
# the resulting node features with a NeuralODE whose dynamics (a Dense layer)
# do not see the graph.
n_ode = NeuralODE(Dense(9, 9, relu), (0.0, 10.0), Tsit5())
prob = DiffEqFlux.Chain(
    GCNConv(3 => 9, relu),  # consumes the FeaturedGraph
    g -> g.nf,              # keep only the node-feature matrix
    n_ode,                  # integrate the features over t ∈ [0, 10]
    Array,                  # ODE solution -> plain array
    Dense(9, 3),
)
sol = prob(fg)
This works for prediction. However, I cannot train this network using `sciml_train`:
# Dummy loss function.
#
# The loss must be *computed from `p`*. A loss that returns `sum(abs2, sol)`
# for a precomputed global `sol` does not depend on `p`, so Zygote returns
# `nothing` as the gradient and the optimizer fails in
# `fill!(dest::Vector{Float32}, x::Nothing)`.
#
# `destructure` flattens the model's trainable arrays into one vector `p0` and
# returns `re`, which rebuilds the model from such a vector.
# NOTE(review): confirm that the NeuralODE's own parameters (`n_ode.p`) end up
# in `p0` for your DiffEqFlux version; if they don't, they will not be trained.
p0, re = DiffEqFlux.Flux.destructure(prob)

loss_n_ode = function (p::AbstractVector)
    pred = re(p)(fg)          # rebuild the model from `p` and run it on the graph
    return sum(abs2, pred)
end

# The loss eltype follows the inputs/parameters (`mat` is Float64, the layer
# parameters are Float32), so don't pin `l` to a single float type.
callback = function (p, l::Real)
    display(l)
    return false
end

DiffEqFlux.sciml_train(loss_n_ode, p0, cb = callback, maxiters = 300)
> LoadError: MethodError: Cannot `convert` an object of type Nothing to an object of type Float32
Closest candidates are:
convert(::Type{T}, !Matched::Union{InitialValues.SpecificInitialValue{typeof(*)}, InitialValues.SpecificInitialValue{typeof(Base.mul_prod)}}) where T<:Union{AbstractString, Number} at C:\Users\Marc\.julia\packages\InitialValues\OWP8V\src\InitialValues.jl:258
convert(::Type{T}, !Matched::VectorizationBase.AbstractSIMD) where T<:Union{Bool, Float16, Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8, SIMDTypes.Bit} at C:\Users\Marc\.julia\packages\VectorizationBase\89gBr\src\base_defs.jl:153
convert(::Type{T}, !Matched::DualNumbers.Dual) where T<:Union{Real, Complex} at C:\Users\Marc\.julia\packages\DualNumbers\5knFX\src\dual.jl:24
...
in expression starting at C:\Users\Marc\Proteinet\src\GDE.jl:48
fill!(dest::Vector{Float32}, x::Nothing) at array.jl:333
copyto! at broadcast.jl:944 [inlined]
materialize! at broadcast.jl:894 [inlined]
materialize!(dest::Vector{Float32}, bc::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(identity), Tuple{Base.RefValue{Nothing}}}) at broadcast.jl:891
(::GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{var"#3457#3458"}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}})(::Vector{Float32}, ::Vector{Float32}) at zygote.jl:8
macro expansion at flux.jl:27 [inlined]
macro expansion at utils.jl:35 [inlined]
__solve(prob::OptimizationProblem{false, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{var"#3457#3458"}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{var"#3457#3458"}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{var"#3457#3458"}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}, Vector{Float32}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Base.Iterators.Pairs{Symbol, var"#3459#3460", Tuple{Symbol}, NamedTuple{(:cb,), Tuple{var"#3459#3460"}}}}, opt::ADAM, data::Base.Iterators.Cycle{Tuple{GalacticOptim.NullData}}; maxiters::Int64, cb::Function, progress::Bool, save_best::Bool, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}) at flux.jl:25
__solve at flux.jl:5 [inlined]
__solve at flux.jl:5 [inlined]
#solve#476 at solve.jl:3 [inlined]
(::CommonSolve.var"#solve##kw")(::NamedTuple{(:maxiters, :cb), Tuple{Int64, var"#3459#3460"}}, ::typeof(solve), ::OptimizationProblem{false, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{var"#3457#3458"}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{var"#3457#3458"}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{var"#3457#3458"}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}, Vector{Float32}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Base.Iterators.Pairs{Symbol, var"#3459#3460", Tuple{Symbol}, NamedTuple{(:cb,), Tuple{var"#3459#3460"}}}}, ::ADAM) at solve.jl:3
sciml_train(::var"#3457#3458", ::Vector{Float32}, ::Nothing, ::Nothing; lower_bounds::Nothing, upper_bounds::Nothing, maxiters::Int64, kwargs::Base.Iterators.Pairs{Symbol, var"#3459#3460", Tuple{Symbol}, NamedTuple{(:cb,), Tuple{var"#3459#3460"}}}) at train.jl:106
(::DiffEqFlux.var"#sciml_train##kw")(::NamedTuple{(:cb, :maxiters), Tuple{var"#3459#3460", Int64}}, ::typeof(DiffEqFlux.sciml_train), ::Function, ::Vector{Float32}, ::Nothing, ::Nothing) at train.jl:57
(::DiffEqFlux.var"#sciml_train##kw")(::NamedTuple{(:cb, :maxiters), Tuple{var"#3459#3460", Int64}}, ::typeof(DiffEqFlux.sciml_train), loss::Function, θ::Vector{Float32}) at train.jl:57
top-level scope at GDE.jl:48
eval at boot.jl:360 [inlined]
include_string(mapexpr::typeof(identity), mod::Module, code::String, filename::String) at loading.jl:1116
I’ve tried looking at the GDE examples that come with DiffEqFlux, but they all seem to use a fixed graph structure, so I am unsure how to proceed.