Using trained neural networks inside GPU computations

Hi all,
We run simulations that rely on iterative solvers on the GPU, using tools like ParallelStencil or Chmy.
Part of the computations involves evaluating heavy non-linear functions, which can be shortcut by (1) training a neural network and (2) using it to predict the non-linearity during the computations.
What we need is the fastest/lightest neural-network model that can then be called on either CPU or GPU.
Which is the most suitable Julia package for this kind of purpose?
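For context, here is a minimal sketch of the workflow we have in mind, using Flux. The target function, network sizes, and training loop are purely illustrative placeholders, not our actual setup:

```julia
using Flux, CUDA

# Placeholder for the expensive non-linearity we want to shortcut
f(x) = exp(-x^2) * sin(5x)

# (1) Train a small surrogate network (architecture is illustrative)
model = Chain(Dense(1 => 16, tanh), Dense(16 => 16, tanh), Dense(16 => 1))
xs = reshape(collect(Float32, range(-2, 2, length = 1024)), 1, :)
ys = f.(xs)
opt = Flux.setup(Adam(1f-3), model)
for epoch in 1:500
    grads = Flux.gradient(m -> Flux.mse(m(xs), ys), model)
    Flux.update!(opt, model, grads[1])
end

# (2) Evaluate the trained surrogate on GPU data inside the solver loop
gpu_model = model |> gpu
u = CUDA.rand(Float32, 1, 10_000)  # field values from the iterative solver
pred = gpu_model(u)                # batched evaluation on the GPU
```

The open question for us is whether Flux (or another package) is the right tool for the *inference* side of this, where the evaluation happens repeatedly inside the solver iterations.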
Currently, we have played around with Flux (for the training only).
Thanks in advance for any recommendations!