Local edge feature aggregation in GraphNeuralNetworks.jl

I’m looking at a broad class of graph neural network update operations defined as follows:

```
for each edge k connecting vertices s and r
    ē_k = ϕ_e(concat(v_s, v_r, e_k))

for each node i
    v_{i, e} = Pool({ē_k : k ∈ E(i)})
    v̄_i = ϕ_v(concat(v_{i, e}, v_i))
```

where e_k is the embedding of edge k and ē_k is its update, v_i is the embedding of node i and v̄_i is its update, v_{i, e} is the pooled edge information for node i, E(i) is the set of edges that lead into node i, and ϕ_e and ϕ_v are MLPs. This update is similar to the one used for materials modeling in https://doi.org/10.1021/acs.chemmater.9b01294.
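To make the formulas concrete, here is a loop-based reference written directly from them, independent of any GNN library. It is a sketch for checking shapes and semantics, not an efficient implementation: the helper `graph_update`, the sender/receiver vectors `s` and `r`, and the toy ϕ functions (fixed random affine maps followed by `relu`) are hypothetical, and `mean` stands in for Pool.

```julia
using Statistics

# Hypothetical loop-based reference for the update above.
# Column k of E is edge k, which goes from sender s[k] to receiver r[k];
# column i of V is node i.
function graph_update(ϕe, ϕv, V::AbstractMatrix, E::AbstractMatrix,
                      s::AbstractVector{Int}, r::AbstractVector{Int}; pool = mean)
    # Edge update: ē_k = ϕe([v_s; v_r; e_k])
    Ē = reduce(hcat, [ϕe(vcat(V[:, s[k]], V[:, r[k]], E[:, k])) for k in eachindex(s)])

    # Node update: v_{i,e} = Pool over E(i), then v̄_i = ϕv([v_{i,e}; v_i])
    cols = map(1:size(V, 2)) do i
        incoming = findall(==(i), r)    # E(i): edges whose receiver is i
        vie = if isempty(incoming)
            zeros(eltype(Ē), size(Ē, 1))          # no incoming edges: use zeros (a choice made here)
        else
            vec(pool(Ē[:, incoming]; dims = 2))   # reduce over the incoming edge columns
        end
        ϕv(vcat(vie, V[:, i]))
    end
    return reduce(hcat, cols), Ē
end

# Toy stand-ins for the MLPs: 3 node features, 2 edge features
W_e = randn(4, 3 + 3 + 2)    # ϕe: 8 inputs -> 4 outputs
W_v = randn(5, 4 + 3)        # ϕv: 7 inputs -> 5 outputs
relu_affine(W) = z -> max.(W * z, 0)
ϕe, ϕv = relu_affine(W_e), relu_affine(W_v)

V = randn(3, 4)              # 4 nodes
E = randn(2, 5)              # 5 edges
s = [1, 2, 3, 4, 1]
r = [2, 3, 4, 1, 3]
V̄, Ē = graph_update(ϕe, ϕv, V, E, s, r)
```

Passing `pool = sum` or `pool = maximum` swaps the Pool operator; any reducer that accepts a `dims` keyword works in this sketch.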

Is the functionality needed for an update like this currently available in GraphNeuralNetworks.jl? If not, are there any pointers on what functions would be needed to implement it?


All the ingredients should already be in place. You can define such a convolution as follows:

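```julia
using GraphNeuralNetworks, Flux, Statistics
using GraphNeuralNetworks: aggregate_neighbors

"""
    MEGNetConv(in => out; aggr=mean)

Convolution from [Graph Networks as a Universal Machine Learning Framework for Molecules and Crystals](https://arxiv.org/pdf/1812.05055.pdf)
"""
struct MEGNetConv{TE, TV, A} <: GNNLayer
    ϕe::TE    # edge MLP
    ϕv::TV    # node MLP
    aggr::A   # pooling over the edges entering each node
end

Flux.@functor MEGNetConv

function MEGNetConv(ch::Pair{Int,Int}; aggr=mean)
    nin, nout = ch
    # The edge MLP sees [x_i; x_j; e_k], i.e. 3nin inputs
    ϕe = Chain(Dense(3nin, nout, relu),
               Dense(nout, nout))

    # The node MLP sees [x_i; pooled ē], i.e. nin + nout inputs
    ϕv = Chain(Dense(nin + nout, nout, relu),
               Dense(nout, nout))

    MEGNetConv(ϕe, ϕv, aggr)
end

function (m::MEGNetConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
    # Edge update, computed for all edges in one batch
    ē = apply_edges(g, xi=x, xj=x, e=e) do xi, xj, e
        m.ϕe(vcat(xi, xj, e))
    end

    # Pool the updated edge features onto their target nodes
    xᵉ = aggregate_neighbors(g, m.aggr, ē)

    # Node update
    x̄ = m.ϕv(vcat(x, xᵉ))

    return x̄, ē
end

g = rand_graph(10, 40)          # 10 nodes, 40 edges
x = randn(Float32, 3, 10)       # node features, one column per node
e = randn(Float32, 3, 40)       # edge features, one column per edge
m = MEGNetConv(3 => 3)
x̄, ē = m(g, x, e)
```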
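What `apply_edges` and `aggregate_neighbors` compute can be checked against plain array indexing. This is a sketch assuming the GNN.jl conventions of this thread's era: `edge_index(g)` returns the source and target vectors, `xi` is materialized at each edge's target and `xj` at its source, and aggregation reduces onto target nodes. A single `Dense` layer stands in for ϕe, and `+` is used for the aggregation to keep the comparison simple.

```julia
using GraphNeuralNetworks, Flux
using GraphNeuralNetworks: aggregate_neighbors

g = rand_graph(10, 40)
x = randn(Float32, 3, 10)
e = randn(Float32, 3, 40)
ϕe = Dense(9, 4)             # stand-in edge MLP: 3 + 3 + 3 inputs

# s[k] is the source node of edge k, t[k] its target node
s, t = edge_index(g)

# apply_edges gathers xi at each edge's target and xj at its source,
# then calls the function once on the whole batch of edges
ē = apply_edges(g, xi = x, xj = x, e = e) do xi, xj, e
    ϕe(vcat(xi, xj, e))
end
ē_manual = ϕe(vcat(x[:, t], x[:, s], e))

# aggregate_neighbors reduces each edge column onto its target node
xᵉ = aggregate_neighbors(g, +, ē)
xᵉ_manual = zeros(Float32, size(ē, 1), g.num_nodes)
for k in eachindex(t)
    xᵉ_manual[:, t[k]] .+= ē[:, k]
end
```

Note that `vcat(xi, xj, e)` puts the receiver first and the sender second, the reverse of `concat(v_s, v_r, e_k)` in the question; this only changes how the weights of ϕe are laid out.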

Aggregation functions other than `+`, `max`, `min`, and `mean` are not supported yet.
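When the pooling you need is not among these, it can sometimes be assembled from the supported ones. As an illustration only (this is not a library feature), a per-node variance of the incoming edge features follows from two `mean` aggregations via Var = E[ē²] - E[ē]². The small graph and the random edge features below are made up for the example.

```julia
using GraphNeuralNetworks, Statistics
using GraphNeuralNetworks: aggregate_neighbors

# Small fixed graph in which every node has at least one incoming edge
s = [1, 1, 2, 3, 3, 4]
t = [2, 3, 3, 1, 4, 1]
g = GNNGraph(s, t)
ē = randn(2, length(s))      # updated edge features, one column per edge

# Variance from the supported mean aggregation only
m1 = aggregate_neighbors(g, mean, ē)
m2 = aggregate_neighbors(g, mean, ē .^ 2)
xvar = max.(m2 .- m1 .^ 2, 0)   # clamp tiny negative values caused by round-off
```

The E[ē²] - E[ē]² form can lose precision when the mean is large relative to the spread, so treat it as a quick workaround rather than a numerically robust estimator.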
To use the layer inside a larger model, you may want to refer to the
Explicit modeling section of the docs.
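Because the layer returns both node and edge features, a model that stacks it has to pass `(x, e)` from one layer to the next explicitly. A minimal sketch, assuming `MEGNetConv` is defined as above (or provided by a GNN.jl release that includes the layer from the PR); `MEGNetModel`, the hidden size, and the dense head are hypothetical choices for illustration.

```julia
using GraphNeuralNetworks, Flux, Statistics

# Hypothetical wrapper: two stacked MEGNet layers and a dense head on the nodes
struct MEGNetModel{L1, L2, D}
    conv1::L1
    conv2::L2
    head::D
end

Flux.@functor MEGNetModel

function MEGNetModel(nin::Int, nhidden::Int, nout::Int)
    MEGNetModel(MEGNetConv(nin => nhidden),
                MEGNetConv(nhidden => nhidden),
                Dense(nhidden, nout))
end

# Each layer returns (x̄, ē), so both are threaded through by hand
function (m::MEGNetModel)(g::GNNGraph, x, e)
    x, e = m.conv1(g, x, e)
    x, e = m.conv2(g, x, e)
    return m.head(x)
end

g = rand_graph(10, 40)
x = randn(Float32, 3, 10)
e = randn(Float32, 3, 40)
model = MEGNetModel(3, 8, 1)
y = model(g, x, e)           # one output per node
```

Since the first layer outputs edge features of size `nhidden`, the second layer's input size must be `nhidden` for both nodes and edges, which is why the hidden layers share one width here.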

Let me know if you need any features added to GNN.jl!
I created a PR with the MEGNet layer.


@CarloLucibello Thank you! I had seen the apply_edges function in the documentation but hadn’t fully internalized what it did.