Minibatches of graphs in GeometricFlux.jl

In supervised graph-level learning, you train a neural network that maps entire graph instances to target values. The GeometricFlux.jl package has great functionality for training graph neural networks; however, all of its examples (GeometricFlux.jl/examples in FluxML/GeometricFlux.jl on GitHub) focus on node-level predictions rather than graph-level predictions.

Is there a preferred syntax for using GeometricFlux.jl models to make predictions over minibatches of graphs? The MWE below shows how I've been handling this, but I'm not sure whether I'm making a naive or inefficient design decision. The key design detail is probably the definition of a `(f::Chain)` method that accepts a vector of `FeaturedGraph`s.
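One detail worth checking in such a signature: Julia's parametric types are invariant, so `AbstractVector{FeaturedGraph}` only matches vectors whose element type is exactly `FeaturedGraph` (as produced by `FeaturedGraph[graph, graph, graph]`). Assuming `FeaturedGraph` is a parametric type, a plain `[graph, graph, graph]` would have a concrete element type and would not match, whereas `AbstractVector{<:FeaturedGraph}` accepts both. The sketch below illustrates this with a hypothetical stand-in type `ToyGraph{T}`, not GeometricFlux's own `FeaturedGraph`.

```julia
# Hypothetical stand-in for a parametric graph type such as FeaturedGraph.
struct ToyGraph{T}
    nf::Matrix{T}   # node features, one column per node
end

# Invariant signature: matches only vectors with element type exactly `ToyGraph`.
count_strict(gs::AbstractVector{ToyGraph}) = length(gs)

# `<:` signature: matches any vector whose elements are some ToyGraph{T}.
count_any(gs::AbstractVector{<:ToyGraph}) = length(gs)

g = ToyGraph(ones(Float32, 2, 3))
untyped = ToyGraph[g, g, g]   # Vector{ToyGraph}
typed   = [g, g, g]           # Vector{ToyGraph{Float32}}

count_strict(untyped)   # 3
count_any(untyped)      # 3
count_any(typed)        # 3
# count_strict(typed) throws a MethodError, because
# Vector{ToyGraph{Float32}} is not a subtype of AbstractVector{ToyGraph}.
```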

For comparison, Python's DGL has a `dgl.batch` function that explicitly constructs a graph minibatch (see "Training a GNN for Graph Classification" in the DGL 0.8 documentation).
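As I understand it, `dgl.batch` merges the graphs of a minibatch into one larger graph whose components are disconnected: the adjacency matrices go on the diagonal of a block-diagonal matrix, the node-feature matrices are concatenated, and a per-node index records which graph each node came from. Here is a minimal plain-Julia sketch of that idea, assuming each graph is given as a sparse adjacency matrix plus a `features × nodes` matrix; `batch_graphs` is a hypothetical helper, not part of GeometricFlux.jl or DGL.

```julia
using SparseArrays

# Hypothetical helper: merge several graphs into one disconnected "batch" graph.
# adjs[k]  : n_k × n_k adjacency matrix of graph k
# feats[k] : d × n_k node-feature matrix of graph k (one column per node)
function batch_graphs(adjs::Vector{<:SparseMatrixCSC}, feats::Vector{<:AbstractMatrix})
    @assert length(adjs) == length(feats)
    A = blockdiag(adjs...)      # block-diagonal: no edges between different graphs
    X = reduce(hcat, feats)     # d × (n_1 + ... + n_K)
    # graph_indicator[i] = index of the graph that node i belongs to
    graph_indicator = reduce(vcat, [fill(k, size(a, 1)) for (k, a) in enumerate(adjs)])
    return A, X, graph_indicator
end

# The same 3-node path graph (1-2-3) used in the MWE below, with 2 features per node.
path3 = sparse(Float32[0 1 0; 1 0 1; 0 1 0])
X1 = ones(Float32, 2, 3)

A, X, gi = batch_graphs([path3, path3, path3], [X1, X1, X1])
size(A)   # (9, 9)
size(X)   # (2, 9)
gi        # [1, 1, 1, 2, 2, 2, 3, 3, 3]
```

Because the off-diagonal blocks are empty, one message-passing step on `A` never mixes information between graphs, which is what makes a single large forward pass equivalent to one pass per graph.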


```julia
using Flux, GeometricFlux, LightGraphs

# generate a 3-node path graph (1-2-3) with two features per node
function get_graph()
    mg = SimpleGraph()
    add_vertices!(mg, 3)
    add_edge!(mg, 1, 2)
    add_edge!(mg, 2, 3)
    adj_mat = Float32.(adjacency_matrix(mg))
    FeaturedGraph(adj_mat; nf=ones(Float32, 2, nv(mg)))
end

# a method for making predictions over a minibatch of `FeaturedGraph` datapoints:
# run the model on each graph separately and collect one prediction per graph
# (`<:` also accepts concretely typed vectors such as `[graph, graph]`)
(f::Chain)(x::AbstractVector{<:FeaturedGraph}) = vec(reduce(hcat, map(f, x)))

# define a simple graph neural network
function get_model(input_dim=2, hidden_dim=3, output_dim=1)
    Chain(GraphConv(input_dim => hidden_dim, tanh),  # x_i = tanh( Θ_1 x_i + ∑_{j in N(i)} Θ_2 x_j )
          x -> sum(x, dims=2),  # sum aggregation over nodes
          Dense(hidden_dim, output_dim))
end

model = get_model()
# standard loss function for graph regression
loss(x, y) = Flux.Losses.mse(model(x), y)

# a single data point
graph = get_graph()
target = 0.5f0

# minibatch of data points
graphs = FeaturedGraph[graph, graph, graph]
targets = [target, target, target]

# evaluate model on minibatch
loss(graphs, targets)
```

GeometricFlux doesn't support graph batching or batched graph-level predictions yet (see "implement graph concatenation", issue #218 in FluxML/GeometricFlux.jl on GitHub).
You can find these features in GraphNeuralNetworks.jl; see, for example,
its graph classification example,
graph_classification_tudataset.jl in CarloLucibello/GraphNeuralNetworks.jl on GitHub.
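To see why batching pays off, here is a plain-Julia sketch that uses no GeometricFlux or GraphNeuralNetworks.jl calls. The weights `W1`, `W2`, `b`, `wout` and the helpers `graphconv` and `sum_pool` are hypothetical. It implements a `GraphConv`-style layer, x_i ↦ tanh(W1 x_i + W2 ∑_{j∈N(i)} x_j + b), followed by per-graph sum pooling and a linear output. Applied once to the block-diagonal batch graph, it reproduces the per-graph loop from the question. As far as I know, GraphNeuralNetworks.jl's batched graphs follow the same idea: they keep a graph indicator for pooling.

```julia
using SparseArrays, Random

# GraphConv-style layer written with matrices (hypothetical, for illustration):
# column i of the output is tanh.(W1*x_i + W2*sum_{j in N(i)} x_j + b).
# For a symmetric adjacency A, X*A sums each node's neighbour features.
graphconv(A, X, W1, W2, b) = tanh.(W1 * X .+ W2 * (X * A) .+ b)

# Sum node embeddings per graph: column k = sum of columns i with gi[i] == k.
function sum_pool(H, gi, ngraphs)
    out = zeros(eltype(H), size(H, 1), ngraphs)
    for (i, k) in enumerate(gi)
        out[:, k] .+= H[:, i]
    end
    return out
end

Random.seed!(0)
din, dhid = 2, 3
W1, W2, b = randn(Float32, dhid, din), randn(Float32, dhid, din), randn(Float32, dhid)
wout = randn(Float32, 1, dhid)   # final linear layer with one output

# Three small undirected graphs (3-node path, single edge, 4-cycle) with random features.
adjs  = [sparse(Float32[0 1 0; 1 0 1; 0 1 0]),
         sparse(Float32[0 1; 1 0]),
         sparse(Float32[0 1 0 1; 1 0 1 0; 0 1 0 1; 1 0 1 0])]
feats = [randn(Float32, din, size(a, 1)) for a in adjs]

# (1) Loop over graphs, like the Chain method in the question.
loop_preds = [only(wout * sum(graphconv(a, x, W1, W2, b), dims=2)) for (a, x) in zip(adjs, feats)]

# (2) One pass over the batched (block-diagonal) graph, then per-graph pooling.
A  = blockdiag(adjs...)
X  = reduce(hcat, feats)
gi = reduce(vcat, [fill(k, size(a, 1)) for (k, a) in enumerate(adjs)])
batch_preds = vec(wout * sum_pool(graphconv(A, X, W1, W2, b), gi, length(adjs)))

loop_preds ≈ batch_preds   # true
```

The batched version replaces many small matrix products with a few larger ones, which is typically where the efficiency gain comes from, especially on a GPU.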


Thank you! I had not seen GraphNeuralNetworks.jl before, so I’ll take a look at it.