Going beyond MNIST with Flux/Zygote and CUDA and Graph Neural Networks

Hi,
I normally try to be very specific in titles, but, as is good practice in some ML communities, I went for a catchy one this time.

I want to learn more about how ML works in Julia, and I have worked through some of the examples in Flux’s model zoo. Now I want to implement a graph NN that takes in a (small) graph with parameters on each vertex and predicts another graph with the same (let’s call it) topology but different parameters.

BACKGROUND: I want to teach a graph NN to predict the motion of classical physical entities with forces between them, e.g. planets under gravity, classically charged particles under Coulomb forces, etc. The idea comes from the paper “Discovering Symbolic Models from Deep Learning with Inductive Biases” (arXiv:2006.11287). The NN’s input is a state of the system (e.g. position, velocity, and mass/charge of each particle), and the desired output is the next state of the system after a small time step.
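To produce (state, next state) pairs like the ones described above, a minimal sketch for 2D gravity could look like this. The per-particle layout [x, y, vx, vy, m] (matching the `nvp = 5` comment in the MWE below), G = 1, the softening length ε, the time step and all helper names are assumptions for illustration. The step is a semi-implicit (symplectic) Euler step, and the output has the same shape as `data` in the MWE.

```julia
# Hypothetical helper: 2D gravitational accelerations with softening.
# G = 1 and ε are illustrative assumptions, not values from the paper.
function accelerations(vps::Vector{Vector{Float32}}; G = 1f0, ε = 1f-2)
    N = length(vps)
    acc = [zeros(Float32, 2) for _ in 1:N]
    for i in 1:N, j in 1:N
        i == j && continue
        r = vps[j][1:2] .- vps[i][1:2]        # vector from particle i to j
        d2 = sum(abs2, r) + ε^2               # softened squared distance
        acc[i] .+= G * vps[j][5] .* r ./ d2^1.5f0
    end
    acc
end

# One semi-implicit Euler step: update the velocity first, then the position.
function euler_step(vps::Vector{Vector{Float32}}, dt::Float32)
    acc = accelerations(vps)
    map(eachindex(vps)) do i
        x, v, m = vps[i][1:2], vps[i][3:4], vps[i][5]
        v_new = v .+ dt .* acc[i]     # kick
        x_new = x .+ dt .* v_new      # drift with the new velocity
        vcat(x_new, v_new, m)         # mass is carried over unchanged
    end
end

# Random initial states and their successors, as (vps, vps_next) tuples.
function make_data(n_samples::Int, N::Int; dt::Float32 = 1f-2)
    map(1:n_samples) do _
        vps = [vcat(randn(Float32, 4), rand(Float32) + 0.5f0) for _ in 1:N]
        (vps, euler_step(vps, dt))
    end
end
```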

The graph NN consists of two MLPs (multi-layer perceptrons): one predicts the messages/edges between vertices, and the other predicts the new state/parameters of a vertex given its previous state and its incoming messages/connecting edges. Please refer to the paper for details; I assume graph NNs are (broadly?) known in the ML community.
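Written out in my own notation (not the paper’s), with x_v the parameter vector of vertex v, φ_e = e_nn, φ_v = v_nn and [a; b] = vcat(a, b), the MWE below computes one message per undirected edge (s, d) with s < d and sums the messages over each vertex’s neighbors:

```latex
\begin{aligned}
m_{sd}    &= \phi_e\big([\,x_s;\, x_d\,]\big), \qquad (s, d) \in E,\ s < d, \\
\bar m_v  &= \sum_{u \in \mathcal{N}(v)} m_{\min(u,v)\,\max(u,v)}, \\
x'_v      &= \phi_v\big([\,x_v;\, \bar m_v\,]\big).
\end{aligned}
```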

I am aware of the Julia package GeometricFlux.jl and got some inspiration there. The reason I do not use it (yet) is that I want to learn how ML in Julia works.

A minimal working example with random data looks like this:


using Flux, CUDA, LightGraphs, LinearAlgebra  # LinearAlgebra provides norm

N = 5 # number of vertices
g = complete_graph(N) # returns a SimpleGraph from LightGraphs.jl
nvp = 5 # number of vertex parameters (e.g. 2d position + velocity + mass)
nep = 100 # number of edge parameters

# MLP predicting vertices    
v_nn = Chain(
        Dense(nvp + nep, 300, Flux.relu),
        Dense(300, 300, Flux.relu),
        Dense(300, nvp)
    ) |> gpu

# MLP predicting edges
e_nn = Chain(
        Dense(2 * nvp, 300, Flux.relu),
        Dense(300, 300, Flux.relu),
        Dense(300, nep)
    ) |> gpu
    

# input
# s_vp  source vertex parameters
# d_vp  destination vertex parameters
function message(s_vp, d_vp)
        nn_input = vcat(s_vp, d_vp)
        e_nn(nn_input)
end
 
# input
# vps  vertex parameters, outer array are vertices, inner array are the corresponding parameters
function predict_edges(vps::AbstractVector{<:AbstractVector})
        
        new_edge_parameters = Dict()
       
        for e in edges(g)
            e = Tuple(e)
            s = e[1]  # source
            d = e[2]  # destination 
            push!(new_edge_parameters, e => message(vps[s], vps[d]))
        end
        
        new_edge_parameters
end
    
    
aggregate_edge_parameters(eps) = sum(eps)
            
# input
# v predict the parameters of vertex v
# vps vertex parameters
# eps messages/edge parameters    
function predict_vertex(v, vps::AbstractVector{<:AbstractVector}, eps::Dict)
        neighbor_edge_parameters = [eps[vn < v ? (vn, v) : (v,vn)] for vn in neighbors(g, v)]
            
        aggr_ep = aggregate_edge_parameters(neighbor_edge_parameters)
        nn_input = vcat(vps[v], aggr_ep) 
        v_nn(nn_input)
end

function loss(vps::AbstractVector{<:AbstractVector}, vps_next::AbstractVector{<:AbstractVector})

        eps = predict_edges(vps)
        
        l_v = zero(Float32)
        for v in vertices(g)
            vp_nn = predict_vertex(v, vps, eps)
            l_v += norm(vps_next[v] .- vp_nn, 1)
        end
        l_v /= nv(g)
        
        #l_e = norm(values(eps), 1)
        
        l_v  #+ l_e
end

data = [([rand(Float32, nvp) for _ in 1:nv(g)], [rand(Float32, nvp) for _ in 1:nv(g)])] |> gpu

Flux.train!(loss, params(v_nn, e_nn), data, Flux.ADAM(0.05))
    

I got this working, including on the GPU, which I had difficulties with. The main difficulty was creating the buffer structure for the edges: Zygote.jl does not support mutating arrays (except through Zygote.Buffer), and CUDA.jl wants CuArrays.

Training with ~10 000 data points, i.e. length(data) = 10 000 in the MWE, takes a long time (of course I did not run it in global scope). Memory fills up quickly, I guess because I create a lot of CuArrays every time I calculate the messages, and the GPU load is only ~10-15% as measured with GPU-Z (I do not know how accurate that is). Further problems with this code:

  • not fast
  • the dictionary is created in main memory, not on the GPU
  • only works when scalar indexing is allowed (CUDA.allowscalar), because of the norm calculation in the loss function
  • errors when the l_e… lines are uncommented. I already filed an issue at Zygote.jl about that, see https://github.com/FluxML/Zygote.jl/issues/760.
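One way around the scalar-indexing problem, assuming the `norm` call is what triggers it, is to write the L1 distance as a broadcast followed by a reduction, which CUDA.jl can run as GPU kernels and Zygote can differentiate. A minimal sketch; the function names are hypothetical:

```julia
# L1 distance as broadcast + reduction. Mathematically equal to norm(a .- b, 1),
# but on a CuArray both steps run as GPU kernels instead of a generic fallback
# that may iterate element by element.
l1_dist(a, b) = sum(abs, a .- b)

# Mean L1 error over vertices, for vectors of per-vertex parameter vectors
# (the vps / vps_next layout used in the MWE).
vertex_loss(pred, target) = sum(l1_dist.(pred, target)) / length(pred)
```

Inside `loss`, `norm(vps_next[v] .- vp_nn, 1)` would then become `l1_dist(vps_next[v], vp_nn)`.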

I would like to address the above issues, especially how to correctly create buffer structures that work with both CUDA and Zygote, and I’d like to get some feedback/input on how to efficiently write loss functions more complicated than mse(model(x), y).

I know that this is a very broad question. If there is any good resource, where I can read about this, please point me to it. If something is not clear, I will further explain.

This is just a pretty expensive way of doing this, because it doesn’t use the physical structure of the equations. The systems you’re describing live on a symplectic manifold, so describing them as a second-order ODE and solving it with a symplectic integrator will give much better extrapolation properties. If you want to use a graph to represent the dynamics, you could just mix https://diffeqflux.sciml.ai/dev/examples/second_order_neural/ with https://diffeqflux.sciml.ai/dev/examples/neural_gde/ and you’d have something that is constrained to be physically realizable, unlike that network.
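To illustrate the extrapolation argument, here is a minimal hand-written comparison (not DiffEqFlux code) on the Kepler problem in reduced units, H = |p|²/2 - 1/|q|. Velocity Verlet, a second-order symplectic integrator, keeps the energy of a circular orbit close to its initial value, while explicit Euler steadily drifts. The step size and step count are arbitrary choices; in practice you would use a library solver, e.g. OrdinaryDiffEq.jl provides symplectic methods such as `VelocityVerlet`.

```julia
# Kepler problem in reduced units: H(q, p) = |p|^2 / 2 - 1 / |q|.
accel(q) = -q ./ sqrt(sum(abs2, q))^3
energy(q, p) = sum(abs2, p) / 2 - 1 / sqrt(sum(abs2, q))

# Velocity Verlet (kick-drift-kick leapfrog): second order and symplectic.
function velocity_verlet(q, p, dt, nsteps)
    a = accel(q)
    for _ in 1:nsteps
        p_half = p .+ (dt / 2) .* a   # half kick
        q = q .+ dt .* p_half         # full drift
        a = accel(q)
        p = p_half .+ (dt / 2) .* a   # second half kick
    end
    q, p
end

# Explicit Euler for comparison: first order and not symplectic.
function explicit_euler(q, p, dt, nsteps)
    for _ in 1:nsteps
        q, p = q .+ dt .* p, p .+ dt .* accel(q)   # both updates use the old state
    end
    q, p
end

q0, p0 = [1.0, 0.0], [0.0, 1.0]                  # circular orbit, energy -0.5
qv, pv = velocity_verlet(q0, p0, 0.01, 2000)
qe, pe = explicit_euler(q0, p0, 0.01, 2000)
println(energy(qv, pv) - energy(q0, p0), "  ", energy(qe, pe) - energy(q0, p0))
```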

You can then mix it with https://datadriven.sciml.ai/dev/sparse_identification/sindy/ as is done in https://arxiv.org/abs/2001.04385 to transform the result back into a symbolic description of the Hamiltonian. See this workshop for a full walk-through of that kind of methodology.
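The sparse-regression step behind SINDy is, in its original form, sequentially thresholded least squares (STLSQ): fit all candidate terms, zero out coefficients below a threshold, and refit on the survivors. The hand-written sketch below only illustrates that idea; the linked DataDriven documentation describes the library’s own implementation and API. Θ holds the candidate library functions evaluated at the samples (one column per function), dX the observed derivatives, and λ is the threshold.

```julia
using LinearAlgebra

# Sequentially thresholded least squares.
# Θ: samples × library functions, dX: samples × state dimensions.
function stlsq(Θ::AbstractMatrix, dX::AbstractMatrix, λ::Real; iters::Int = 10)
    Ξ = Θ \ dX                              # dense least-squares starting point
    for _ in 1:iters
        small = abs.(Ξ) .< λ                # coefficients to prune
        Ξ[small] .= 0
        for k in axes(dX, 2)                # refit each state dimension
            keep = .!small[:, k]
            any(keep) || continue
            Ξ[keep, k] = Θ[:, keep] \ dX[:, k]
        end
    end
    Ξ
end
```

For example, with the library [1, x, x², x³] and data generated from dx/dt = 2x - 0.5x³, `stlsq` keeps only the x and x³ columns.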

In this way you’d get higher order and enforce conservation of energy, which would keep the solution from blowing up over long time intervals.


Hi @ChrisRackauckas!
Thanks for the quick response! I am aware that this approach is not state of the art; this is a toy case. Not only is the NN unaware of energy conservation, but worse, constants such as mass and charge are parameters just like position and velocity/momentum, so a poorly trained network won’t conserve them either. That said, some constraints are built into the design, e.g. forces add up to an effective force, and the particle number is conserved.

I am just curious how far one can get. Adding a message/edge regularization term to the loss function enforces sparsity in the aggregated messages/edges and, in principle, should make the edge NN (e_nn) “learn” the forces. One can indeed show (it’s in the paper) that the predicted edge parameters coincide with the forces, up to rotations.
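With the messages stored as columns of a matrix instead of a Dict, such a regularization term reduces to an elementwise reduction that works the same on CPU arrays and CuArrays. A minimal sketch, assuming an L1 penalty with a hypothetical weight λ and matrix-valued predictions (one column per vertex); check the paper for the exact form of its regularizer:

```julia
# Vertex prediction error plus an L1 penalty on the messages.
# pred, target: nvp × N (one column per vertex); M: nep × nedges (one column
# per edge message); λ is a hypothetical regularization weight.
function regularized_loss(pred::AbstractMatrix, target::AbstractMatrix,
                          M::AbstractMatrix; λ = 1f-2)
    l_v = sum(abs, pred .- target) / size(pred, 2)   # mean L1 error per vertex
    l_e = sum(abs, M) / size(M, 2)                   # mean L1 norm per message
    l_v + λ * l_e
end
```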

Moving away from the toy case, one could see the edge regularization as a dimensionality reduction of the interactions between vertices (in the paper they applied it to dark matter simulations and extracted formulas using symbolic regression).

My main goal was to get feedback on implementing this in Julia.

I watched your talk at JuliaCon live and am reading your paper as well. Looking forward to doing more in this direction.

Have you tried replicating the Python implementation at https://github.com/MilesCranmer/symbolic_deep_learning first? On first impression, it appears far more AD (and ML framework) friendly: no manual loops, limited mutation, etc. Zygote might be more flexible than PyTorch’s AD for your model and loss formulation above, but any practical reverse-mode system is unfortunately going to struggle with this kind of approach.
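A minimal sketch of what “no manual loops, limited mutation” could look like for this model, assuming the vertex parameters are stored as columns of one matrix. The messages for all edges come out of one batched call, and the aggregation is a product with a sparse incidence matrix, so there is no Dict and no push!. The single-layer weights are hypothetical stand-ins for e_nn and v_nn; Flux’s Dense layers also accept a matrix with one sample per column, so the same gather/aggregate structure should carry over. Whether the sparse product is supported on the GPU and by Zygote in your package versions is an assumption to check; for a graph this small, a dense A works too.

```julia
using SparseArrays

# Hypothetical stand-ins for e_nn and v_nn: one affine layer each (relu on the
# edge model). X is nvp × N (one column per vertex); src/dst list the endpoints
# of each undirected edge.
function gnn_step(X, src, dst, We, be, Wv, bv)
    N, nedges = size(X, 2), length(src)
    # Gather: build all edge inputs [x_src; x_dst] at once (2nvp × nedges) and
    # run the edge model as a single batched matrix product.
    M = max.(We * vcat(X[:, src], X[:, dst]) .+ be, zero(eltype(X)))  # nep × nedges
    # Aggregate: A[e, v] = 1 if v is an endpoint of edge e, so M * A sums the
    # messages of all edges incident to each vertex.
    A = sparse(vcat(1:nedges, 1:nedges), vcat(src, dst),
               ones(eltype(X), 2 * nedges), nedges, N)
    agg = M * A                                  # nep × N aggregated messages
    Wv * vcat(X, agg) .+ bv                      # nvp × N predicted states
end
```

For the complete graph in the MWE, the edge lists are `src = [i for i in 1:N for j in i+1:N]` and `dst = [j for i in 1:N for j in i+1:N]`.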

I did not reimplement their Python code. They mainly use PyTorch Geometric (torch_geometric), which hides all the AD, loops, allocations, etc. The Julia equivalent would be GeometricFlux.jl, I guess. But I want to understand how this works in a bit more detail.

Why is my approach not AD friendly and how could I make it more AD friendly?