Local edge feature aggregation in GraphNeuralNetworks.jl

All the ingredients should already be in place. You can define such a convolution as follows:

using GraphNeuralNetworks, Flux, Statistics
using GraphNeuralNetworks: aggregate_neighbors

"""
    MEGNetConv(in => out; aggr=mean)

Convolution from the paper [Graph Networks as a Universal Machine Learning Framework for
Molecules and Crystals](https://arxiv.org/pdf/1812.05055.pdf).

Each edge is updated by the network `ϕe`, applied to the concatenation of the features of its
two endpoint nodes and its own edge features. The updated edge features are aggregated onto
the nodes with `aggr`. Each node is then updated by the network `ϕv`, applied to the
concatenation of its input features and the aggregated edge features.

# Arguments
- `in`: number of input node features (edge features are expected to have the same width).
- `out`: number of output node and edge features.
- `aggr`: aggregation used to combine edge features onto nodes (e.g. `mean`, `+`, `max`, `min`).

# Usage
    x̄, ē = m(g, x, e)

returns the updated node features `x̄` and the updated edge features `ē`.
"""
struct MEGNetConv{TE, TV, A} <: GNNLayer
    # Concrete type parameters (instead of untyped `Any` fields) let Julia specialize the
    # forward pass on the actual network and aggregation types.
    ϕe::TE   # edge-update network
    ϕv::TV   # node-update network
    aggr::A  # edge-to-node aggregation function
end

# Expose the fields to Flux so that parameters inside `ϕe` and `ϕv` are collected/moved.
Flux.@functor MEGNetConv

function MEGNetConv(ch::Pair{Int,Int}; aggr=mean)
    # Convenience constructor: build both MLPs from the `in => out` feature sizes.
    nin, nout = first(ch), last(ch)

    # Edge network input is [x_i; x_j; e_ij], i.e. three blocks of `nin` features
    # (so edge features must have the same width as node features).
    edge_mlp = Chain(Dense(3 * nin, nout, relu), Dense(nout, nout))

    # Node network input is [x; aggregated edge features] = `nin + nout` features.
    node_mlp = Chain(Dense(nin + nout, nout, relu), Dense(nout, nout))

    return MEGNetConv(edge_mlp, node_mlp, aggr)
end

function (m::MEGNetConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
    # Forward pass: returns the updated node features and the updated edge features.

    # Edge update: run `ϕe` on the stacked features of both endpoints and of the edge itself.
    edge_update(xi, xj, eij) = m.ϕe(vcat(xi, xj, eij))
    e_new = apply_edges(edge_update, g; xi = x, xj = x, e = e)

    # Node update: combine the new edge features onto nodes, then run `ϕv` on [x; aggregated].
    x_agg = aggregate_neighbors(g, m.aggr, e_new)
    x_new = m.ϕv(vcat(x, x_agg))

    return x_new, e_new
end

# Smoke test: a random graph with 10 nodes and 40 edges, 3 features per node and per edge.
n_nodes, n_edges, n_feat = 10, 40, 3
g = rand_graph(n_nodes, n_edges)

# Flux `Dense` layers hold Float32 parameters; Float32 inputs avoid silent promotion
# to Float64 (and the associated slowdown / precision-mismatch warning in newer Flux).
x = randn(Float32, n_feat, n_nodes)
e = randn(Float32, n_feat, n_edges)

m = MEGNetConv(n_feat => n_feat)
x̄, ē = m(g, x, e)

Aggregation operations other than `+`, `max`, `min`, and `mean` are not supported yet.
To use the conv layer inside a larger model, you may want to refer to the
"Explicit modeling" section of the docs.

Let me know if you need any features to be added to GNN.jl!
I have also opened a PR that adds the MEGNet layer.

5 Likes