# Minibatches of graphs in GeometricFlux.jl

In supervised graph-level learning, you train a neural network that maps entire graph instances to target values. The GeometricFlux.jl package has great functionality for training graph neural networks; however, all of the examples (the `examples` directory of FluxML/GeometricFlux.jl on GitHub) focus on node-level predictions rather than graph-level predictions.

Is there a preferred syntax for using GeometricFlux.jl models to make predictions over minibatches of graphs? The MWE below shows how I've been handling this, but I'm not sure whether I'm making some dumb or inefficient design decision. The key design detail is probably the definition of a `(f::Chain)(x::AbstractVector{FeaturedGraph})` method.

For comparison, Python's DGL has a `dgl.batch` function (see "Training a GNN for Graph Classification" in the DGL 0.8 documentation) that explicitly constructs a graph minibatch.
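To make the comparison concrete: batching in the DGL sense roughly amounts to taking the disjoint union of the graphs, i.e. a block-diagonal adjacency matrix, the node features concatenated column-wise, and a vector recording which graph each node came from. The sketch below shows that idea in plain Julia with `SparseArrays`; `batch_graphs` is a hypothetical helper, not a function from GeometricFlux.jl or DGL, and it assumes node features are stored as features × nodes, as in the MWE's `nf`.

```julia
using SparseArrays

# Hypothetical helper (not a GeometricFlux.jl or DGL API): merge several graphs
# into one disconnected "batch" graph.
function batch_graphs(adjs::Vector{<:AbstractMatrix}, feats::Vector{<:AbstractMatrix})
    A = blockdiag(sparse.(adjs)...)   # block-diagonal adjacency: no edges between graphs
    X = reduce(hcat, feats)           # features × total_nodes
    # graph_indicator[i] is the index of the graph that node i came from
    graph_indicator = reduce(vcat, [fill(k, size(a, 1)) for (k, a) in enumerate(adjs)])
    return A, X, graph_indicator
end

# Two toy graphs: a single edge (2 nodes) and a path (3 nodes), 2 features per node.
a1 = [0 1; 1 0]
a2 = [0 1 0; 1 0 1; 0 1 0]
A, X, gi = batch_graphs([a1, a2], [ones(Float32, 2, 2), ones(Float32, 2, 3)])
size(A)   # (5, 5)
gi        # [1, 1, 2, 2, 2]
```

A GNN layer applied to `A` and `X` then processes all graphs in one call, since message passing never crosses the zero off-diagonal blocks.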

Thanks!

```julia
using Flux, GeometricFlux, LightGraphs

# generate a graph with node features
# (a 4-node cycle is used for illustration; any non-empty graph works)
function get_graph()
    mg = cycle_graph(4)
    FeaturedGraph(adjacency_matrix(mg); nf=ones(Float32, 2, nv(mg)))
end

# a method for making predictions over a minibatch of `FeaturedGraph` datapoints
(f::Chain)(x::AbstractVector{FeaturedGraph}) = Flux.stack(f.(x), 1) |> vec

# define a simple graph neural network
function get_model(input_dim=2, hidden_dim=3, output_dim=1)
    Chain(GraphConv(input_dim => hidden_dim, tanh),  # x_i = tanh( Θ_1 x_i + ∑_{j in N(i)} Θ_2 x_j )
          node_feature,
          x -> sum(x, dims=2),  # sum aggregation
          Dense(hidden_dim, output_dim))
end

model = get_model()
# standard loss function for graph regression
loss(x, y) = Flux.Losses.mse(model(x), y)

# a single data point
graph = get_graph()
target = 0.5f0

# minibatch of data points
graphs = FeaturedGraph[graph, graph, graph]
targets = [target, target, target]

# evaluate model on minibatch
loss(graphs, targets)
``````
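One possible variation on the MWE's design, offered as a sketch rather than a recommended GeometricFlux.jl idiom: instead of adding a method to Flux's `Chain` when the script owns neither `Chain` nor `FeaturedGraph` (type piracy, which can clash with other packages' methods), use a plain helper function. `predict_batch` and `toy_model` are hypothetical names; `toy_model` only stands in for the `Chain` and returns a 1×1 matrix, like the model's `Dense(hidden_dim, 1)` output.

```julia
# Hypothetical helper (not a GeometricFlux.jl API): run `model` on each graph
# separately and concatenate the per-graph outputs into one vector.
predict_batch(model, graphs::AbstractVector) = reduce(vcat, vec.(model.(graphs)))

# Stand-in "model" for illustration: one 1×1 output matrix per input.
toy_model(g) = reshape([sum(g)], 1, 1)

preds = predict_batch(toy_model, [[1, 2], [3, 4, 5]])   # [3, 12]

# In the MWE, the loss would then become:
# loss(x, y) = Flux.Losses.mse(predict_batch(model, x), y)
```

This still evaluates the graphs one at a time, so it is no faster than the `Chain` method; it only avoids the piracy.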

GeometricFlux doesn't support graph batching or batched graph-level predictions yet (see issue #218, "implement graph concatenation", in FluxML/GeometricFlux.jl on GitHub).
You can find these features in GraphNeuralNetworks.jl; see, e.g.,
this graph classification example:
`graph_classification_tudataset.jl` in CarloLucibello/GraphNeuralNetworks.jl on GitHub (master branch).
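One detail that matters once graphs are batched: the MWE's `x -> sum(x, dims=2)` readout sums over every node it receives, so applied to a merged batch graph it would collapse all graphs into a single vector. Graph-level readout has to pool per graph instead, which GraphNeuralNetworks.jl handles with its own pooling layers. The minimal sketch below shows the idea on its own; `sum_pool_by_graph` is a hypothetical name, not that package's API, and the graph-indicator vector is assumed to come from whatever step built the batch.

```julia
# Hypothetical helper (not a GraphNeuralNetworks.jl or GeometricFlux.jl API):
# sum the node features of each graph in a batch separately.
# X: features × total_nodes; graph_indicator[i] is the graph that node i belongs to.
function sum_pool_by_graph(X::AbstractMatrix, graph_indicator::AbstractVector{<:Integer})
    ngraphs = maximum(graph_indicator)
    # M[i, g] = 1 if node i belongs to graph g, else 0 (total_nodes × ngraphs)
    M = eltype(X).(graph_indicator .== (1:ngraphs)')
    # a matrix product instead of in-place accumulation, so no array mutation
    return X * M   # features × ngraphs
end

# Nodes 1 and 2 belong to graph 1; nodes 3 to 5 belong to graph 2.
X = Float32[1 2 3 4 5; 10 20 30 40 50]
gi = [1, 1, 2, 2, 2]
pooled = sum_pool_by_graph(X, gi)   # Float32[3 12; 30 120]
```

Expressing the pooling as a matrix product avoids array mutation, which Zygote (Flux's default AD) does not support, so the readout stays differentiable.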


Thank you! I had not seen GraphNeuralNetworks.jl before, so I’ll take a look at it.