NeuralOperators.jl Performance (compared with Python)

Hi, I am training a Neural Operator to learn a map between functions on a two-dimensional domain. For Julia, this is an MWE of the problem I am considering:

# Imports.
using NeuralOperators
using Flux
using FluxTraining
using MLUtils

# Generate data.
N_samples = 5_000
xdata = Float32.(rand(1, 16, 16, N_samples))
ydata = Float32.(rand(1, 16, 16, N_samples))

# Optimiser: Adam with weight decay.
optimiser = Flux.Optimiser(WeightDecay(1.0f-4), Flux.Adam(5.0f-4))

# DataLoaders.
data_train, data_test = splitobs((xdata, ydata), at = 0.9)
loader_train, loader_test = DataLoader(data_train, batchsize = 100), DataLoader(data_test, batchsize = 100)
data = collect.((loader_train, loader_test))

# Model and learner.
chs = (1, 32, 32, 32, 32, 32, 64, 1)
model = FourierNeuralOperator(ch = chs, modes = (8, 8), σ = gelu)
learner = Learner(model, data, optimiser, l₂loss)

# Train for 100 epochs.
for epoch in 1:100
  epoch!(learner, TrainingPhase(), learner.data.training)
end

It takes ~30–40 s to train one epoch (even when starting Julia with --threads=auto). In Python, the equivalent problem takes almost an order of magnitude less time per epoch while reaching the same error at each epoch.

How can I get closer to Python’s performance in Julia?

Any help will be appreciated. Thanks!

Try putting the code into a function.
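For reference, a minimal sketch of what that could look like, wrapping the loop from the MWE in a function so it doesn't run in global scope (the name train! is just illustrative):

# Hypothetical wrapper: keeps the training loop out of global scope.
function train!(learner, n_epochs)
  for epoch in 1:n_epochs
    epoch!(learner, TrainingPhase(), learner.data.training)
  end
  return learner
end

train!(learner, 100)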

In this case I think most of the time is spent inside the epoch! call, so I'm not sure it will help much.

Already tried it, and as @adienes said, there is not really any change.

I'd say this is a perfect use case for a profiler. Just put @profview in front of the loop and see where most of the time is spent. Have you checked how much of this time is compilation time?
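For example, a sketch assuming ProfileView.jl is installed (the VS Code Julia extension provides its own @profview as well):

using ProfileView

# Run once first so compilation is excluded from the profile.
epoch!(learner, TrainingPhase(), learner.data.training)

# Profile a single epoch and open the flame graph.
@profview epoch!(learner, TrainingPhase(), learner.data.training)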

Yes, I already did, thanks to a suggestion given on Slack a few hours ago (by Chris). It turned out to be Tullio, a package used to perform tensor operations with Einstein notation. I changed it to a standard matrix multiplication, and now a significant amount of time is spent doing that.

Here is the new profile.

PS: regarding compilation time, I ran the function once beforehand so it isn't included.

You could try increasing the number of OpenBLAS threads or using MKL. You have not shown us which Python code you are comparing to, so it is difficult to compare apples to apples.
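A minimal sketch of both suggestions, assuming the matrix multiplications in the modified model go through BLAS and that MKL.jl is added to the environment:

using LinearAlgebra

# Let OpenBLAS use all available cores for the matmuls.
BLAS.set_num_threads(Sys.CPU_THREADS)

# Or switch the BLAS backend to MKL (load it early in the session):
# using MKL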

The longer discussion is on Slack. Basically, it comes down to the fact that Tullio.jl's reverse passes are very slow and allocate a lot, so the model is much faster when they are replaced with matrix multiplications. However, this algorithm really shouldn't be using matmuls for this operation. Quoting from the Slack:

In theory it could be changed to conv; the problem is that the X I mentioned before is the truncated Fourier transform of the series. One would need some additional processing, like applying the FFT, truncating, applying the inverse FFT, and then the conv.

So using conv isn’t great either.

The best thing here would be to have a good einsum operation handle this. Since Tullio.jl is well optimized in the forward pass and @mcabbott works on automatic differentiation, I presume this unoptimized behavior was likely just overlooked and is fixable. Making Tullio.jl better instead of dumping it is probably the best option, IMO.
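To make the comparison concrete, here is a sketch of the kind of per-mode channel mixing under discussion, written both as a Tullio einsum and as a batched matmul. The shapes and names are illustrative only, and real arrays are used for simplicity even though the actual spectral weights are complex:

using Tullio, NNlib

in_ch, out_ch, m1, m2, batch = 4, 8, 8, 8, 16
W = randn(Float32, in_ch, out_ch, m1, m2)     # spectral weights
X = randn(Float32, in_ch, m1, m2, batch)      # truncated Fourier coefficients

# Einsum form: mix channels independently at each retained mode.
@tullio Y1[o, x, y, b] := W[i, o, x, y] * X[i, x, y, b]

# Equivalent batched matrix multiplication: one small matmul per mode.
Wb = reshape(permutedims(W, (2, 1, 3, 4)), out_ch, in_ch, m1 * m2)
Xb = reshape(permutedims(X, (1, 4, 2, 3)), in_ch, batch, m1 * m2)
Yb = batched_mul(Wb, Xb)                      # (out_ch, batch, m1*m2)
Y2 = permutedims(reshape(Yb, out_ch, batch, m1, m2), (1, 3, 4, 2))

Y1 ≈ Y2   # both forms give the same result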
