NODE Training - Performance w/ Lux vs torchdiffeq

So I’ve been training a neural ODE in Python using PyTorch and torchdiffeq. I wanted to start with Julia and Lux, but I first needed to train some upstream pieces (pretrained CNN backbones and custom PyTorch layers) that wouldn’t be trivial to port over. The good news is that I actually managed to fit a model that solved my problem, and I deployed it with Lux and OrdinaryDiffEq.jl and it’s working great. The bad news is that it took nearly a week to fit, which makes it extremely difficult to iterate on improvements.

One of the biggest problems I am running into on the PyTorch side is an apparent CPU bottleneck during integration. GPU utilization hovers around 20% on an NVIDIA A6000, and I’m training on four of them and it’s still really slow. It’s bad enough that I looked at running multiple replicas on the same GPU, but that wasn’t supported by the framework I built everything in.

Before I start porting the Python parts to Julia, I’m curious whether anyone has experience training models with both torchdiffeq and Lux with Reactant (or Lux in general) and could comment on whether I could expect significant speedups. The ANODE I’m fitting is pretty standard, approximately Dense(512, 256) -> Tanh() -> Dense(256, 512). I’m not using the adjoint method, just standard differentiation through the solver. The underlying family of ODEs I’m trying to learn isn’t stiff at all. If I go ahead and port all of my code over, there would also be a fairly lightweight CNN backbone, but I’d expect its cost to be insignificant compared to the time spent integrating the ODE.
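
For reference, here is a minimal sketch of roughly that architecture as a Lux + OrdinaryDiffEq.jl neural ODE (forward solve only; the names, time span, and tolerances are placeholders, not my actual setup):

```julia
using Lux, OrdinaryDiffEq, Random

# Vector field of the neural ODE: roughly Dense(512 => 256, tanh) -> Dense(256 => 512)
vector_field = Chain(Dense(512 => 256, tanh), Dense(256 => 512))
rng = Random.default_rng()
ps, st = Lux.setup(rng, vector_field)

# Out-of-place RHS; a Lux layer call returns (output, new_state)
rhs(u, p, t) = first(vector_field(u, p, st))

u0 = randn(rng, Float32, 512)               # placeholder initial state
prob = ODEProblem(rhs, u0, (0.0f0, 1.0f0), ps)
sol = solve(prob, Tsit5(); saveat = 0.1f0)  # explicit non-stiff solver
```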


Not sure about the perf, but you will have much better control over the choice of adjoint/sensitivity method (see the sketch below).
As for the DL part, you can expect performance similar to JAX, which itself is known to be faster than PyTorch in a lot of cases. The hardest thing to get into, I think, is the SciMLStructures interface, if you need it (to avoid global non-const variables and to be able to tell SciMLSensitivity which parts of the parameter struct you want to differentiate). I really hope the interface changes one day.
You will like the control you have over, well, everything compared to PyTorch. It does mean a bigger learning curve, but it’s worth it.
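
To make the adjoint point concrete, here is a hedged sketch of choosing the sensitivity method with SciMLSensitivity; it is just a `sensealg` keyword on `solve`, and the particular choice below is only an example, not a recommendation for this specific model:

```julia
using OrdinaryDiffEq, SciMLSensitivity

# Assuming `prob` is the ODEProblem from the sketch in the first post.
# The sensealg only matters when the solve is differentiated (e.g. inside a Zygote loss).
sol = solve(prob, Tsit5();
            saveat = 0.1f0,
            sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()))

# Other options include BacksolveAdjoint, QuadratureAdjoint, GaussAdjoint, or
# "differentiate through the solver" choices such as ReverseDiffAdjoint().
```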


cc @avikpal

Assuming the CPU time is not the core computation (but rather setup, Julia runtime calls, JIT, inference, etc.), Reactant should be able to significantly reduce it, while also providing improvements on the core-computation side. For example, see Lux.jl/perf at main · LuxDL/Lux.jl · GitHub (which shows roughly an order of magnitude boost using Reactant vs just vanilla CUDA.jl).
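
As a rough illustration of the Reactant path with Lux (hedged: the device/compile helpers and the batch size below are my assumptions about the current API, not taken from the benchmark above):

```julia
using Lux, Reactant, Random

dev = reactant_device()                     # XLA-backed Reactant device
model = Chain(Dense(512 => 256, tanh), Dense(256 => 512))
ps, st = Lux.setup(Random.default_rng(), model) |> dev

x = randn(Float32, 512, 32) |> dev          # arbitrary batch of 32

# Trace and compile the whole forward pass via StableHLO/XLA, then run it
compiled_model = @compile model(x, ps, st)
y, _ = compiled_model(x, ps, st)
```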

Other fun things of note (trying to make your life easier): we can import existing JAX code and call it from Julia, fast, via Reactant (and vice versa). This should be extensible to PyTorch as well, but we haven’t done that work yet. Let me know if you’re interested and want to help!


Being able to use vision backbones from timm and all of the custom layers I have developed over the years would significantly reduce the barrier to entry. That’s pretty much the main thing stopping me from moving over to working in Julia. The other thing would be a way to deploy the model using something like NVIDIA Triton Inference Server, but if I’m reading correctly we should be able to go from Reactant → StableHLO → JAX → TensorFlow SavedModel?

Back to being able to use my PyTorch code: I assume the end goal is producing something like a TorchScript trace, but in StableHLO? I’ve worked with TorchScript, ONNX, and TensorRT, so I understand the general principle and how to do things like replacing unsupported ops with equivalents from the opset I’m targeting. This sounds like something I could help with.

Reactant can already go directly to a TensorFlow SavedModel! See Serialization | Reactant.jl.

And yeah, essentially all we need is a StableHLO module for whatever foreign code you want to call. So if we can get that from the PyTorch code, it can just be hooked in via Reactant’s hlo_call FFI. See Reactant.jl/ext/ReactantPythonCallExt/pycall.jl at 13771b008e590945717f986b9e55ab2f466959a6 · EnzymeAD/Reactant.jl · GitHub for how it’s used for JAX interop.
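
To give a flavor of what that looks like, here is a toy, hand-written StableHLO module spliced in via `Reactant.Ops.hlo_call`; treat the exact calling convention as approximate, and in practice the StableHLO text would come from exporting the foreign model rather than being written by hand:

```julia
using Reactant

# Toy StableHLO module with the `main` entry point; computes f(x) = x + x.
hlo = """
module {
  func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
    %0 = stablehlo.add %arg0, %arg0 : tensor<4xf32>
    return %0 : tensor<4xf32>
  }
}
"""

# hlo_call returns a tuple of results, so unwrap the single output.
double(x) = only(Reactant.Ops.hlo_call(hlo, x))

x = Reactant.to_rarray(Float32[1, 2, 3, 4])
y = @jit double(x)   # compiles the wrapper, splicing in the foreign StableHLO
```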


There’s also Serialization | Lux.jl Docs, which is a thin wrapper over the TF serialization in Reactant, specialized for Lux models.


If anyone is wondering how to use a PyTorch model with Reactant.jl, I managed to load ResNet50 from timm as shown in the example here: PyTorch StableHLO Support · Issue #2065 · EnzymeAD/Reactant.jl · GitHub

The main thing remaining now is to do real-world testing and work out some interface issues.

I’m going ahead and porting my training loop over to Julia now (rough sketch below). Hopefully I’ll get a large speedup for ODE integration from not spending so much time in the Python interpreter on every single solver step.
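
For anyone following along, this is roughly the shape I’m aiming for: a hedged sketch with placeholder data, arbitrary hyperparameters, and the sensitivity algorithm mentioned earlier in the thread, taking gradients through the solve with Zygote and updating with Optimisers.jl:

```julia
using Lux, OrdinaryDiffEq, SciMLSensitivity, Optimisers, Zygote,
      ComponentArrays, Random

rng = Random.default_rng()
vector_field = Chain(Dense(512 => 256, tanh), Dense(256 => 512))
ps, st = Lux.setup(rng, vector_field)
ps = ComponentArray(ps)                 # flat parameter vector for the solver / AD

rhs(u, p, t) = first(vector_field(u, p, st))

function predict(p, u0)
    prob = ODEProblem(rhs, u0, (0.0f0, 1.0f0), p)
    solve(prob, Tsit5(); save_everystep = false, save_start = false,
          sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()))
end

# Squared error against the final state only (placeholder objective)
loss(p, u0, target) = sum(abs2, Array(predict(p, u0))[:, end] .- target)

function train(ps, opt_state; nsteps = 100)
    for _ in 1:nsteps
        u0, target = randn(rng, Float32, 512), randn(rng, Float32, 512)  # placeholder batch
        grads = only(Zygote.gradient(p -> loss(p, u0, target), ps))
        opt_state, ps = Optimisers.update(opt_state, ps, grads)
    end
    return ps, opt_state
end

ps, opt_state = train(ps, Optimisers.setup(Adam(1.0f-3), ps))
```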
