SimpleChains.jl extension to Graph Neural Networks

Yesterday I attended the Copenhagen Julia meetup and saw a new neural network library called SimpleChains.jl, which aims to be fast for small neural networks. Graph Neural Networks are often quite small in number of parameters too, so I was wondering whether there are lessons from SimpleChains we could use to make our Graph Neural Network implementations faster. I'm thinking primarily of GeometricFlux.jl and GraphNeuralNetworks.jl. As far as I can tell, both are built directly on Flux and would therefore suffer from the same "small network" performance issues that SimpleChains.jl aims to solve. I'm not sure I'm right about this and am happy to be corrected.

Thoughts?


Open an issue. I think small graphs are a great use case for this: exploiting their sparsity makes the computation memory-bound, which suits CPUs well. Let's discuss with @Elrod.


Sure, PRs are welcome.
I can also add more documentation describing the interface, but it would help to flesh it out together, especially with guidance from people who have more experience doing ML work than I do.