In the supervised graph problem, you train a neural network that maps entire graph instances to target values. The GeometricFlux.jl package has great functionality for training graph neural networks; however, all of the examples (https://github.com/FluxML/GeometricFlux.jl/tree/master/examples) focus on the problem of making node-level predictions (rather than graph-level predictions).
Is there a preferred syntax for using GeometricFlux.jl models to make predictions over minibatches of graphs? A MWE below gives how I’ve been handling this, but I’m not sure if I’m making some dumb or inefficient design decision. My key design detail is probably the definition of a (f::Chain)(x::AbstractVector{FeaturedGraph})
method.
I compare this to the dgl.batch
function in Python’s DGL (Training a GNN for Graph Classification — DGL 0.9 documentation), which explicitly constructs a graph minibatch.
Thanks!
using Flux, GeometricFlux, LightGraphs
# generate a graph with node features
function get_graph()
mg = SimpleGraph()
add_vertices!(mg, 3)
add_edge!(mg, 1, 2)
add_edge!(mg, 2, 3)
adj_mat = Float32.(adjacency_matrix(mg))
FeaturedGraph(adj_mat; nf=ones(Float32, 2, nv(mg)))
end
# a method for making predictions over a minibatch of `FeaturedGraph` datapoints
(f::Chain)(x::AbstractVector{FeaturedGraph}) = Flux.stack(f.(x), 1) |> vec
# define a simple graph neural network
function get_model(input_dim=2, hidden_dim=3, output_dim=1)
Chain(GraphConv(input_dim => hidden_dim, tanh), # x_i = tanh( Θ_1 x_i + ∑_{j in N(i)} Θ_2 x_j )
node_feature,
x -> sum(x, dims=2), # sum aggregation
Dense(hidden_dim, output_dim))
end
model = get_model()
# standard loss function for graph regression
loss(x, y) = Flux.Losses.mse(model(x), y)
# a single data point
graph = get_graph()
target = 0.5f0
# minibatch of data points
graphs = FeaturedGraph[graph, graph, graph]
targets = [target, target, target]
# evaluate model on minibatch
loss(graphs, targets)