All the ingredients should be in place. You can define such a convolution as follows:
"""
MEGNetConv(in => out; aggr=mean)
Convolution from [Graph Networks as a Universal Machine Learning Framework for Molecules and Crystals](https://arxiv.org/pdf/1812.05055.pdf)
paper.
"""
using GraphNeuralNetworks, Flux, Statistics
using GraphNeuralNetworks: aggregate_neighbors
struct MEGNetConv <: GNNLayer
ϕe
ϕv
aggr
end
Flux.@functor MEGNetConv
function MEGNetConv(ch::Pair{Int,Int}; aggr=mean)
nin, nout = ch
ϕe = Chain(Dense(3nin, nout, relu),
Dense(nout, nout))
ϕv = Chain(Dense(nin + nout, nout, relu),
Dense(nout, nout))
MEGNetConv(ϕe, ϕv, aggr)
end
function (m::MEGNetConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
ē = apply_edges(g, xi=x, xj=x, e=e) do xi, xj, e
m.ϕe(vcat(xi, xj, e))
end
xᵉ = aggregate_neighbors(g, m.aggr, ē)
x̄ = m.ϕv(vcat(x, xᵉ))
return x̄, ē
end
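To see what `apply_edges` and `aggregate_neighbors` are doing here, a quick sanity check on a hand-built toy graph may help. This example is my own, not part of the layer; the convention, as I read the docs, is that `xi` holds the features of each edge's target node, `xj` those of the source node, and `aggregate_neighbors` reduces the per-edge outputs onto the target nodes:

```julia
# Toy check (assumed semantics, not from the original post): edge k goes
# from s[k] to t[k]; inside apply_edges, xi = x[:, t[k]] (target node)
# and xj = x[:, s[k]] (source node).
s, t = [1, 2, 2], [2, 1, 3]
gtoy = GNNGraph(s, t)
xtoy = Float32[1 2 3; 4 5 6]    # 2 features × 3 nodes
etoy = Float32[10 20 30]        # 1 feature  × 3 edges
msgs = apply_edges((xi, xj, e) -> vcat(xj, e), gtoy, xi=xtoy, xj=xtoy, e=etoy)
# Column i of the result reduces the messages flowing into node i.
agg = aggregate_neighbors(gtoy, mean, msgs)
```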
```julia
# Feature dimensions must match the constructor: 3 input features per node
# and per edge here. Float32 matches the default precision of Flux layers.
g = rand_graph(10, 40)        # 10 nodes, 40 edges
x = randn(Float32, 3, 10)     # node features
e = randn(Float32, 3, 40)     # edge features

m = MEGNetConv(3 => 3)
x̄, ē = m(g, x, e)
```
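Alternatively, features can travel with the graph object itself. A small sketch using the `ndata`/`edata` fields of `GNNGraph`:

```julia
# Unnamed arrays passed as ndata/edata are stored under the default
# names :x and :e.
g2 = GNNGraph(g, ndata = x, edata = e)
x̄2, ē2 = m(g2, g2.ndata.x, g2.edata.e)
```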
Aggregation operations other than `+`, `max`, `min`, and `mean` are not supported yet.
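Any of the supported aggregations can be passed at construction time, for example max instead of mean:

```julia
# Max-pooling over incoming edge messages instead of the mean.
m_max = MEGNetConv(3 => 3, aggr = max)
x̄m, ēm = m_max(g, x, e)
```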
To use the conv layer inside a larger model, you may want to refer to the Explicit modeling section of the docs; a minimal sketch along those lines is shown below.
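For reference, here is what that could look like. The `MyModel` struct and its hyperparameters are made up for illustration:

```julia
# A hypothetical two-layer model in the explicit (struct-based) style.
struct MyModel
    conv1::MEGNetConv
    conv2::MEGNetConv
    head::Dense
end

Flux.@functor MyModel

MyModel(nin::Int, nh::Int, nout::Int) =
    MyModel(MEGNetConv(nin => nh), MEGNetConv(nh => nh), Dense(nh, nout))

function (model::MyModel)(g::GNNGraph, x, e)
    x, e = model.conv1(g, x, e)   # both node and edge features are updated
    x, e = model.conv2(g, x, e)
    return model.head(x)          # per-node predictions
end

model = MyModel(3, 16, 1)
ŷ = model(g, x, e)                # reusing g, x, e from above
```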
Let me know if you need any other features added to GNN.jl! I've created a PR adding the MEGNet layer.