NeuralOperators.jl Performance (compared with Python)

Hi, I am training a Neural Operator to learn a map between functions on a two-dimensional domain. Here is an MWE of the problem I am considering in Julia:

# Imports.
using NeuralOperators
using Flux
using FluxTraining
using MLUtils

# Generate data.
N_samples = 5_000
xdata = Float32.(rand(1, 16, 16, N_samples))
ydata = Float32.(rand(1, 16, 16, N_samples))

# Optimiser.
optimiser = Flux.Optimiser(WeightDecay(1.0f-4), Flux.Adam(5.0f-4))

# DataLoaders.
data_train, data_test = splitobs((xdata, ydata), at = 0.9)
loader_train, loader_test = DataLoader(data_train, batchsize = 100), DataLoader(data_test, batchsize = 100)
data = collect.((loader_train, loader_test))

chs = (1, 32, 32, 32, 32, 32, 64, 1)
model = FourierNeuralOperator(ch = chs, modes = (8, 8), σ = gelu)
learner = Learner(model, data, optimiser, l₂loss)

for _ in 1:100
  epoch!(learner, TrainingPhase(), learner.data.training)
end

It takes ~30–40 s to train one epoch (even when using julia --threads=auto). In Python, the equivalent problem takes almost an order of magnitude less time while reaching the same error per epoch.

How can I get closer to Python’s performance in Julia?

Any help will be appreciated. Thanks!

Try putting the code into a function.
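
For reference, a minimal sketch of that suggestion applied to the MWE above (train! is just an illustrative name):

function train!(learner, n_epochs)
  # Wrapping the loop in a function lets Julia compile the whole body once.
  for _ in 1:n_epochs
    epoch!(learner, TrainingPhase(), learner.data.training)
  end
end

train!(learner, 100)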

In this case I think most of the time is spent inside the epoch! call, so I'm not sure it will help much.

Already tried that, and as @adienes said, there is not really any change.

I'd say this is a perfect use case for a profiler. Just put @profview in front of the loop and see where most time is spent. Have you checked how much of this time is compilation time?
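
For example, assuming ProfileView.jl is installed (the VS Code Julia extension ships its own @profview as well):

using ProfileView

# Warm up once so compilation does not dominate the profile.
epoch!(learner, TrainingPhase(), learner.data.training)

# Profile a single epoch and open the flame graph.
@profview epoch!(learner, TrainingPhase(), learner.data.training)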

Yes, I already did, thanks to a suggestion Chris gave on Slack a few hours ago. It turned out to be Tullio, a package used to perform tensor operations in Einstein notation. I changed it to a standard matrix multiplication, and now a significant amount of time is spent there instead.

Here is the new profiling result.

PS: regarding compilation time, I ran the function once beforehand to exclude it.

You could try increasing the number of OpenBLAS threads or using MKL. You have not shown us which Python code you are comparing against, so it is difficult to compare apples to apples.
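
As a sketch of both options (MKL.jl swaps out the OpenBLAS backend when it is loaded):

using LinearAlgebra

BLAS.get_num_threads()                 # check the current BLAS thread count
BLAS.set_num_threads(Sys.CPU_THREADS)  # let BLAS use all cores

# Alternatively, replace OpenBLAS with MKL (load it before the heavy work):
# using MKL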


The longer discussion is on Slack. Basically, it comes down to the fact that Tullio.jl's reverse passes are very slow and allocate a lot, so it's much faster when replaced with matrix multiplications. However, this algorithm really shouldn't be using matmuls for this operation. Quoting from Slack:

In theory it could be changed to a conv; the problem is that the X I mentioned before is the truncated Fourier transform of the series. One would need some additional processing, like applying the FFT, truncating, applying the inverse FFT, and then the conv.

So using conv isn’t great either.

The best thing here would be to have a good einsum operation handle this. Since Tullio.jl is well optimized in the forward pass and @mcabbott works on automatic differentiation, I presume this unoptimized reverse pass was simply overlooked and is fixable. Making Tullio.jl better instead of dumping it is probably the best option IMO.
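
To make that concrete, here is an illustrative sketch of the kind of per-mode contraction involved. The shapes and names are assumptions for illustration, not the package's actual internals (the real layer also works on complex Fourier coefficients):

using Tullio, NNlib

W = randn(Float32, 32, 32, 64)   # (out_ch, in_ch, modes): one weight matrix per kept mode
X = randn(Float32, 32, 64, 100)  # (in_ch, modes, batch): truncated Fourier coefficients

# Einsum form of the spectral-layer contraction, as Tullio.jl expresses it:
@tullio Y1[o, m, b] := W[o, i, m] * X[i, m, b]

# Equivalent batched matmul (batching over the mode dimension), which is
# what the matmul replacement discussed above amounts to:
Y2 = permutedims(batched_mul(W, permutedims(X, (1, 3, 2))), (1, 3, 2))

Y1 ≈ Y2  # true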


(I wouldn't have necro-posted, but this post shows up quite frequently when searching for neural operators in Julia.)

With our in-progress rewrite of NeuralOperators on top of Lux (GitHub - LuxDL/NeuralOperators.jl, part of a GSoC project via SciML), performance improves significantly. There are still some bottlenecks we are looking into, but currently DeepONets outperform the PyTorch version using deepxde, and FNOs are much faster than the older Flux-based NeuralOperators:

  1. CPU version: the forward pass is ~3x faster; the backward pass improves by ~2x.
  2. CUDA version: time per gradient call goes down by 2x.

CPU Code

using NeuralOperators, Lux, Random, Zygote
# using NeuralOperators: NeuralOperators, Flux # <-- Flux version
using BenchmarkTools

N_samples = 128
xdata = rand(Float32, 1, 16, 16, N_samples);

lux_fno = FourierNeuralOperator(; σ=gelu, chs=(1, 32, 32, 32, 32, 32, 64, 1),
    modes=(8, 8))
ps, st = Lux.setup(Xoshiro(), lux_fno)

# flux_fno = NeuralOperators.FourierNeuralOperator(;
#     ch=(1, 32, 32, 32, 32, 32, 64, 1), modes=(8, 8), σ=gelu)

# @benchmark $flux_fno($xdata)
# BenchmarkTools.Trial: 14 samples with 1 evaluation.
#  Range (min … max):  309.223 ms … 577.339 ms  ┊ GC (min … max): 2.81% … 29.22%
#  Time  (median):     341.670 ms               ┊ GC (median):    2.04%
#  Time  (mean ± σ):   376.467 ms ±  85.218 ms  ┊ GC (mean ± σ):  4.56% ±  7.43%
#  Memory estimate: 216.32 MiB, allocs estimate: 972.

@benchmark Lux.apply($lux_fno, $xdata, $ps, $st)
# BenchmarkTools.Trial: 32 samples with 1 evaluation.
#  Range (min … max):  102.107 ms … 316.077 ms  ┊ GC (min … max):  1.12% … 52.78%
#  Time  (median):     129.961 ms               ┊ GC (median):     3.90%
#  Time  (mean ± σ):   156.749 ms ±  57.966 ms  ┊ GC (mean ± σ):  20.78% ± 19.24%
#  Memory estimate: 188.18 MiB, allocs estimate: 644.

loss(m, x) = sum(abs2, m(x))
loss(m, x, ps, st) = sum(abs2, first(m(x, ps, st)))

# @benchmark Zygote.gradient($loss, $flux_fno, $xdata)
# BenchmarkTools.Trial: 5 samples with 1 evaluation.
#  Range (min … max):  866.505 ms …    1.541 s  ┊ GC (min … max):  0.00% … 12.45%
#  Time  (median):        1.154 s               ┊ GC (median):    16.62%
#  Time  (mean ± σ):      1.207 s ± 253.181 ms  ┊ GC (mean ± σ):  13.15% ± 13.80%
#  Memory estimate: 467.05 MiB, allocs estimate: 3982.

@benchmark Zygote.gradient($loss, $lux_fno, $xdata, $ps, $st)
# BenchmarkTools.Trial: 11 samples with 1 evaluation.
#  Range (min … max):  358.702 ms … 543.770 ms  ┊ GC (min … max):  1.55% … 29.35%
#  Time  (median):     468.138 ms               ┊ GC (median):    31.51%
#  Time  (mean ± σ):   457.426 ms ±  66.325 ms  ┊ GC (mean ± σ):  25.74% ± 13.01%
#  Memory estimate: 478.67 MiB, allocs estimate: 2637.

GPU Code (CUDA, GTX 1650 Ti)

using NeuralOperators, Lux, Random, Zygote
# using NeuralOperators: NeuralOperators, Flux # <-- Flux version
using LuxCUDA, BenchmarkTools

CUDA.allowscalar(false)

gdev = gpu_device()

N_samples = 128
xdata = rand(Float32, 1, 16, 16, N_samples) |> gdev;

# flux_fno = NeuralOperators.FourierNeuralOperator(;
#     ch=(1, 32, 32, 32, 32, 32, 64, 1), modes=(8, 8), σ=gelu) |> Flux.gpu

lux_fno = FourierNeuralOperator(; σ=gelu, chs=(1, 32, 32, 32, 32, 32, 64, 1),
    modes=(8, 8))
ps, st = Lux.setup(Xoshiro(), lux_fno) |> gdev

# @benchmark CUDA.@sync $flux_fno($xdata)
# BenchmarkTools.Trial: 348 samples with 1 evaluation.
#  Range (min … max):  10.444 ms … 191.160 ms  ┊ GC (min … max): 26.16% … 88.66%
#  Time  (median):     12.158 ms               ┊ GC (median):     0.00%
#  Time  (mean ± σ):   14.303 ms ±  12.171 ms  ┊ GC (mean ± σ):   6.34% ±  9.47%
#  Memory estimate: 99.48 KiB, allocs estimate: 3402.

@benchmark CUDA.@sync Lux.apply($lux_fno, $xdata, $ps, $st)
# BenchmarkTools.Trial: 604 samples with 1 evaluation.
#  Range (min … max):  6.271 ms … 24.575 ms  ┊ GC (min … max): 0.00% … 14.43%
#  Time  (median):     6.486 ms              ┊ GC (median):    0.00%
#  Time  (mean ± σ):   8.255 ms ±  3.474 ms  ┊ GC (mean ± σ):  2.85% ±  4.21%
#  Memory estimate: 60.91 KiB, allocs estimate: 2539.

loss(m, x) = sum(abs2, m(x))
loss(m, x, ps, st) = sum(abs2, first(m(x, ps, st)))

# @benchmark CUDA.@sync Zygote.gradient($loss, $flux_fno, $xdata)
# BenchmarkTools.Trial: 109 samples with 1 evaluation.
#  Range (min … max):  37.156 ms … 174.007 ms  ┊ GC (min … max): 0.00% … 82.40%
#  Time  (median):     44.307 ms               ┊ GC (median):    0.00%
#  Time  (mean ± σ):   45.864 ms ±  12.516 ms  ┊ GC (mean ± σ):  5.23% ±  8.41%
#  Memory estimate: 640.98 KiB, allocs estimate: 11110.

@benchmark CUDA.@sync Zygote.gradient($loss, $lux_fno, $xdata, $ps, $st)
# BenchmarkTools.Trial: 303 samples with 1 evaluation.
#  Range (min … max):  15.836 ms …  24.553 ms  ┊ GC (min … max): 0.00% … 11.17%
#  Time  (median):     16.253 ms               ┊ GC (median):    0.00%
#  Time  (mean ± σ):   16.511 ms ± 792.804 μs  ┊ GC (mean ± σ):  4.25% ±  7.07%
#  Memory estimate: 302.09 KiB, allocs estimate: 8704.