(Wouldn't have necro-posted, but this post shows up quite frequently when searching for neural operators in Julia.)
With our in-progress rewrite of NeuralOperators using Lux at GitHub - LuxDL/NeuralOperators.jl (part of GSoC via SciML), performance improves significantly. There are still some bottlenecks we are looking into, but currently DeepONets outperform the PyTorch version built on deepxde, and FNOs are much faster than the older Flux-based NeuralOperators version:
- CPU version: forward-pass time is down 3x; backward-pass time improves by ~2x
- CUDA version: time per gradient call goes down 2x
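Since the DeepONet numbers are mentioned but only FNO code is shown below, here is a minimal sketch of building a DeepONet with the Lux-based package. The `branch`/`trunk` keyword names and the `(u, y)` input-tuple convention are assumptions about the in-progress API and may change:

```julia
using NeuralOperators, Lux, Random

# Hedged sketch: a DeepONet with 4-layer branch and trunk nets.
# Keyword names and input convention assumed from the in-progress
# LuxDL/NeuralOperators.jl API.
deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))
ps, st = Lux.setup(Xoshiro(0), deeponet)

u = rand(Float32, 64, 128)     # 64 sensor readings per sample, 128 samples
y = rand(Float32, 1, 32, 128)  # 32 trunk evaluation points per sample
out, _ = deeponet((u, y), ps, st)
```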
CPU Code
using NeuralOperators, Lux, Random, Zygote
# using NeuralOperators: NeuralOperators, Flux  # <-- Flux version
using BenchmarkTools

N_samples = 128
xdata = rand(Float32, 1, 16, 16, N_samples);

lux_fno = FourierNeuralOperator(; σ=gelu, chs=(1, 32, 32, 32, 32, 32, 64, 1),
    modes=(8, 8))
ps, st = Lux.setup(Xoshiro(), lux_fno)
# flux_fno = NeuralOperators.FourierNeuralOperator(;
#     ch=(1, 32, 32, 32, 32, 32, 64, 1), modes=(8, 8), σ=gelu)
# @benchmark $flux_fno($xdata)
# BenchmarkTools.Trial: 14 samples with 1 evaluation.
#  Range (min … max):  309.223 ms … 577.339 ms  ┊ GC (min … max): 2.81% … 29.22%
#  Time  (median):     341.670 ms               ┊ GC (median):    2.04%
#  Time  (mean ± σ):   376.467 ms ±  85.218 ms  ┊ GC (mean ± σ):  4.56% ±  7.43%
#  309 ms           Histogram: frequency by time          577 ms <
#  Memory estimate: 216.32 MiB, allocs estimate: 972.
@benchmark Lux.apply($lux_fno, $xdata, $ps, $st)
# BenchmarkTools.Trial: 32 samples with 1 evaluation.
#  Range (min … max):  102.107 ms … 316.077 ms  ┊ GC (min … max):  1.12% … 52.78%
#  Time  (median):     129.961 ms               ┊ GC (median):     3.90%
#  Time  (mean ± σ):   156.749 ms ±  57.966 ms  ┊ GC (mean ± σ):  20.78% ± 19.24%
#  102 ms           Histogram: frequency by time          316 ms <
#  Memory estimate: 188.18 MiB, allocs estimate: 644.
loss(m, x) = sum(abs2, m(x))
loss(m, x, ps, st) = sum(abs2, first(m(x, ps, st)))
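The four-argument loss above follows the Lux convention of explicit parameters and state; in an actual training loop you would typically differentiate with respect to the parameters only, not all four arguments. A self-contained sketch of that pattern on a toy layer (the `Dense` model here is just a stand-in for the FNO):

```julia
using Lux, Random, Zygote

# Toy stand-in model to keep the sketch self-contained.
model = Dense(4 => 2, gelu)
ps, st = Lux.setup(Xoshiro(0), model)
x = rand(Float32, 4, 8)

# Same loss shape as in the benchmark: Lux.apply returns (output, new_state).
loss(m, x, ps, st) = sum(abs2, first(m(x, ps, st)))

# Differentiate w.r.t. the parameters only, holding model, data, state fixed.
gs = only(Zygote.gradient(p -> loss(model, x, p, st), ps))
```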
# @benchmark Zygote.gradient($loss, $flux_fno, $xdata)
# BenchmarkTools.Trial: 5 samples with 1 evaluation.
#  Range (min … max):  866.505 ms …    1.541 s  ┊ GC (min … max):  0.00% … 12.45%
#  Time  (median):        1.154 s               ┊ GC (median):    16.62%
#  Time  (mean ± σ):      1.207 s ± 253.181 ms  ┊ GC (mean ± σ):  13.15% ± 13.80%
#  867 ms           Histogram: frequency by time          1.54 s <
#  Memory estimate: 467.05 MiB, allocs estimate: 3982.
@benchmark Zygote.gradient($loss, $lux_fno, $xdata, $ps, $st)
# BenchmarkTools.Trial: 11 samples with 1 evaluation.
#  Range (min … max):  358.702 ms … 543.770 ms  ┊ GC (min … max):  1.55% … 29.35%
#  Time  (median):     468.138 ms               ┊ GC (median):    31.51%
#  Time  (mean ± σ):   457.426 ms ±  66.325 ms  ┊ GC (mean ± σ):  25.74% ± 13.01%
#  359 ms           Histogram: frequency by time          544 ms <
#  Memory estimate: 478.67 MiB, allocs estimate: 2637.
GPU (CUDA, GTX 1650 Ti) Code
using NeuralOperators, Lux, Random, Zygote
# using NeuralOperators: NeuralOperators, Flux
using LuxCUDA, BenchmarkTools

CUDA.allowscalar(false)
gdev = gpu_device()

N_samples = 128
xdata = rand(Float32, 1, 16, 16, N_samples) |> gdev;

# flux_fno = NeuralOperators.FourierNeuralOperator(;
#     ch=(1, 32, 32, 32, 32, 32, 64, 1), modes=(8, 8), σ=gelu) |> Flux.gpu

lux_fno = FourierNeuralOperator(; σ=gelu, chs=(1, 32, 32, 32, 32, 32, 64, 1),
    modes=(8, 8))
ps, st = Lux.setup(Xoshiro(), lux_fno) |> gdev
# @benchmark CUDA.@sync $flux_fno($xdata)
# BenchmarkTools.Trial: 348 samples with 1 evaluation.
#  Range (min … max):  10.444 ms … 191.160 ms  ┊ GC (min … max): 26.16% … 88.66%
#  Time  (median):     12.158 ms               ┊ GC (median):     0.00%
#  Time  (mean ± σ):   14.303 ms ±  12.171 ms  ┊ GC (mean ± σ):   6.34% ±  9.47%
#  10.4 ms           Histogram: frequency by time          33.3 ms <
#  Memory estimate: 99.48 KiB, allocs estimate: 3402.
@benchmark CUDA.@sync Lux.apply($lux_fno, $xdata, $ps, $st)
# BenchmarkTools.Trial: 604 samples with 1 evaluation.
#  Range (min … max):  6.271 ms … 24.575 ms  ┊ GC (min … max): 0.00% … 14.43%
#  Time  (median):     6.486 ms              ┊ GC (median):    0.00%
#  Time  (mean ± σ):   8.255 ms ±  3.474 ms  ┊ GC (mean ± σ):  2.85% ±  4.21%
#  6.27 ms        Histogram: log(frequency) by time        21.7 ms <
#  Memory estimate: 60.91 KiB, allocs estimate: 2539.
loss(m, x) = sum(abs2, m(x))
loss(m, x, ps, st) = sum(abs2, first(m(x, ps, st)))
# @benchmark CUDA.@sync Zygote.gradient($loss, $flux_fno, $xdata)
# BenchmarkTools.Trial: 109 samples with 1 evaluation.
#  Range (min … max):  37.156 ms … 174.007 ms  ┊ GC (min … max): 0.00% … 82.40%
#  Time  (median):     44.307 ms               ┊ GC (median):    0.00%
#  Time  (mean ± σ):   45.864 ms ±  12.516 ms  ┊ GC (mean ± σ):  5.23% ±  8.41%
#  37.2 ms           Histogram: frequency by time          49.2 ms <
#  Memory estimate: 640.98 KiB, allocs estimate: 11110.
@benchmark CUDA.@sync Zygote.gradient($loss, $lux_fno, $xdata, $ps, $st)
# BenchmarkTools.Trial: 303 samples with 1 evaluation.
#  Range (min … max):  15.836 ms …  24.553 ms  ┊ GC (min … max): 0.00% … 11.17%
#  Time  (median):     16.253 ms               ┊ GC (median):    0.00%
#  Time  (mean ± σ):   16.511 ms ± 792.804 μs  ┊ GC (mean ± σ):  4.25% ±  7.07%
#  15.8 ms           Histogram: frequency by time          19.7 ms <
#  Memory estimate: 302.09 KiB, allocs estimate: 8704.
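The benchmarks only time the gradient call itself; for completeness, here is a self-contained sketch of one full optimizer step with Optimisers.jl (the toy `Dense` model and the `3f-4` learning rate are placeholders, not anything from the benchmarks above; swap in the FNO and GPU arrays for the real thing):

```julia
using Lux, Optimisers, Random, Zygote

# Toy stand-in model so the sketch runs on its own.
model = Dense(4 => 2, gelu)
ps, st = Lux.setup(Xoshiro(0), model)
x = rand(Float32, 4, 8)

loss(m, x, ps, st) = sum(abs2, first(m(x, ps, st)))

# One training step: gradient w.r.t. parameters, then an Adam update.
opt_state = Optimisers.setup(Adam(3.0f-4), ps)
gs = only(Zygote.gradient(p -> loss(model, x, p, st), ps))
opt_state, ps = Optimisers.update(opt_state, ps, gs)
```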