Local edge feature aggregation in GraphNeuralNetworks.jl

I’m looking at a broad class of graph neural network update operations defined as follows:

for each edge `k` connecting vertices `s` and `r`
    ē_k = ϕ_e(concat(v_s, v_r, e_k))

for each node `i`
    v_{i, e} = Pool({ē_k : k ∈ E(i)})
    v̄_i = ϕ_v(concat(v_{i, e}, v_i))

where e_k is the embedding of edge k and ē_k is its update, v_i is the embedding of node i and v̄_i is its update, E(i) is the set of edges that lead into node i, and ϕ_e and ϕ_v are MLPs. This update is similar to the one used for materials modeling here: https://doi.org/10.1021/acs.chemmater.9b01294
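To pin down exactly what the update computes, here is a minimal plain-Julia sketch with explicit loops and no graph library. The function name `edge_node_update`, the one-layer `dense` stand-in for the MLPs, and the choice to pool an empty edge set to a zero vector are all assumptions made for illustration; edge `k` is taken to go from `src[k]` (s) to `dst[k]` (r).

```julia
using Statistics: mean

# Hypothetical stand-in for an MLP: one dense layer followed by ReLU.
dense(W, b) = z -> max.(W * z .+ b, 0)

# One step of the edge-then-node update.
# V: d_v × n node embeddings, E: d_e × m edge embeddings,
# edge k goes from src[k] (s) to dst[k] (r).
function edge_node_update(src, dst, V, E, ϕe, ϕv; pool = mean)
    # Edge step: ē_k = ϕ_e(concat(v_s, v_r, e_k))
    Ē = reduce(hcat, [ϕe(vcat(V[:, src[k]], V[:, dst[k]], E[:, k])) for k in eachindex(src)])
    # Node step: v_{i,e} = Pool({ē_k : k ∈ E(i)}), then v̄_i = ϕ_v(concat(v_{i,e}, v_i))
    cols = map(1:size(V, 2)) do i
        incoming = findall(==(i), dst)   # E(i): edges leading into node i
        # Assumption: a node with no incoming edges pools to a zero vector.
        vie = isempty(incoming) ? zeros(eltype(Ē), size(Ē, 1)) :
              vec(pool(Ē[:, incoming]; dims = 2))
        ϕv(vcat(vie, V[:, i]))
    end
    return reduce(hcat, cols), Ē
end

# Small example: 3 nodes with 4-dim embeddings, 4 edges with 2-dim embeddings.
src, dst = [1, 2, 3, 3], [2, 3, 1, 2]
V, E = randn(4, 3), randn(2, 4)
d = 5                                    # output width of both MLPs
ϕe = dense(randn(d, 4 + 4 + 2), zeros(d))
ϕv = dense(randn(d, d + 4), zeros(d))
V̄, Ē = edge_node_update(src, dst, V, E, ϕe, ϕv)
```

Any reducer that accepts a `dims` keyword (`sum`, `maximum`, `mean`) can be passed as `pool`.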

Is the functionality needed for an update like this currently available in GraphNeuralNetworks.jl? If not, are there any pointers on what functions would be needed to implement it?

Thanks!

All the ingredients should be in place. You can define such a convolution as follows:

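```julia
using GraphNeuralNetworks, Flux, Statistics
using GraphNeuralNetworks: aggregate_neighbors

"""
    MEGNetConv(in => out; aggr=mean)

Convolution from [Graph Networks as a Universal Machine Learning Framework for Molecules and Crystals](https://arxiv.org/pdf/1812.05055.pdf)
paper.
"""
struct MEGNetConv <: GNNLayer
    ϕe    # edge MLP: [xi; xj; e] -> ē
    ϕv    # node MLP: [x; xᵉ] -> x̄
    aggr  # pooling of incoming edge features (+, max, min or mean)
end

Flux.@functor MEGNetConv

function MEGNetConv(ch::Pair{Int,Int}; aggr=mean)
    nin, nout = ch
    # Edge features are assumed to have the same width nin as the node features.
    ϕe = Chain(Dense(3nin, nout, relu),
               Dense(nout, nout))

    ϕv = Chain(Dense(nin + nout, nout, relu),
               Dense(nout, nout))

    MEGNetConv(ϕe, ϕv, aggr)
end

function (m::MEGNetConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
    # Edge update: xi holds the target-node features, xj the source-node features of each edge
    ē = apply_edges(g, xi=x, xj=x, e=e) do xi, xj, e
        m.ϕe(vcat(xi, xj, e))
    end

    # Pool the updated edge features into their target nodes
    xᵉ = aggregate_neighbors(g, m.aggr, ē)

    # Node update
    x̄ = m.ϕv(vcat(x, xᵉ))

    return x̄, ē
end

g = rand_graph(10, 40)        # 10 nodes, 40 edges
x = randn(Float32, 3, 10)     # node features: 3 × num_nodes
e = randn(Float32, 3, 40)     # edge features: 3 × num_edges
m = MEGNetConv(3=>3)
x̄, ē = m(g, x, e)
```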
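For readers who want to see what `apply_edges` and `aggregate_neighbors` compute in the forward pass above, here is a rough, dependency-free mental model. It is not GNN.jl's implementation: the helper names `gather_apply` and `scatter_aggregate` are hypothetical, the graph is given as plain source/target vectors `s` and `t`, and it assumes GNN.jl's convention that `xi` refers to the target node of each edge and `xj` to its source. The edge function here is just `vcat` (no MLP) so the result can be inspected by hand.

```julia
using Statistics: mean

# Apply f once over all edges: column k of xi is the target node t[k],
# column k of xj is the source node s[k], column k of e is edge k.
function gather_apply(f, s, t, x, e)
    xi = x[:, t]          # gather target (receiving) node features, one column per edge
    xj = x[:, s]          # gather source (sending) node features, one column per edge
    return f(xi, xj, e)   # f works on whole matrices, one column per edge
end

# Pool the columns of ē into their target nodes with aggr (e.g. sum, mean).
# In this sketch, nodes with no incoming edge get a zero column.
function scatter_aggregate(aggr, t, ē, num_nodes)
    out = zeros(eltype(ē), size(ē, 1), num_nodes)
    for i in 1:num_nodes
        incoming = findall(==(i), t)
        isempty(incoming) || (out[:, i] = vec(aggr(ē[:, incoming]; dims = 2)))
    end
    return out
end

# Toy graph: edge k goes from s[k] to t[k].
s = [1, 2, 3, 3]
t = [2, 3, 1, 2]
x = reshape(collect(1.0:6.0), 2, 3)   # 2 features × 3 nodes
e = ones(2, 4)                        # 2 features × 4 edges

ē = gather_apply((xi, xj, e) -> vcat(xi, xj, e), s, t, x, e)   # 6 × 4
xᵉ = scatter_aggregate(mean, t, ē, 3)                           # 6 × 3
```

Node 2 receives edges 1 and 4, so its pooled column is the mean of `[3, 4, 1, 2, 1, 1]` and `[3, 4, 5, 6, 1, 1]`, namely `[3, 4, 3, 4, 1, 1]`.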

Aggregation operations other than +, max, min, and mean are not supported yet.
To use the conv layer inside a larger model, you may want to refer to the
Explicit modeling section of the docs.
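Since the layer returns a `(node, edge)` tuple, it does not slot directly into a chain that passes only node features, which is why an explicit model struct is convenient. A minimal sketch, assuming a GNN.jl version that exports a `MEGNetConv` with the same `in => out` constructor and `(g, x, e)` call as the code above, plus `reduce_nodes` for graph-level pooling. The model name `TwoBlockMEGNet` and its layout (two blocks, mean readout, one output) are hypothetical choices for illustration.

```julia
using GraphNeuralNetworks, Flux, Statistics

# Hypothetical model: two MEGNet blocks followed by a graph-level readout.
struct TwoBlockMEGNet
    conv1
    conv2
    head
end

Flux.@functor TwoBlockMEGNet   # newer Flux versions prefer Flux.@layer

function TwoBlockMEGNet(nin::Int, nhidden::Int)
    # MEGNetConv(in => out) expects edge features as wide as the node features
    TwoBlockMEGNet(MEGNetConv(nin => nhidden),
                   MEGNetConv(nhidden => nhidden),
                   Dense(nhidden, 1))
end

function (m::TwoBlockMEGNet)(g::GNNGraph, x, e)
    x, e = m.conv1(g, x, e)         # each block updates both node and edge features
    x, e = m.conv2(g, x, e)
    xg = reduce_nodes(mean, g, x)   # mean over nodes, one column per graph
    return m.head(xg)
end

g = rand_graph(10, 40)
model = TwoBlockMEGNet(3, 16)
y = model(g, randn(Float32, 3, 10), randn(Float32, 3, 40))
```

Because both blocks return updated edge features, the second block sees edge embeddings already transformed by the first, as in the stacked MEGNet blocks of the paper.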

Let me know if you need any features added to GNN.jl!
I created a PR adding the MEGNet layer.

@CarloLucibello Thank you! I had seen the apply_edges function in the documentation but hadn’t fully internalized what it did.