Significant compile time latency in Flux with a GAN

I am trying to implement an SRGAN (super-resolution generative adversarial network) in Flux (I am happy to share it once it works… :-)).
The generator network is composed of 5 residual blocks (function resblock below).
Unfortunately, the time of the first call of Zygote.pullback increases very quickly
with the number of residual blocks, so I am not able to use the 5
residual blocks as intended. I already commented out all element-wise
broadcasting (activation functions), but the issue is still present.

| Number of res. blocks | pullback time (1st call) | pullback time (2nd call) |
|---|---|---|
| 0 | 58.2 s | 10.5 s |
| 1 | 73.33 s | 21.2 s |
| 2 | 94.13 s | 39.08 s |
| 3 | 3264 s (54.4 min) | 3240 s (54.0 min) |

I am using Julia 1.7.0-rc1, Flux 0.12.6 and Zygote 0.6.21 on Linux (Intel i9-10900X, NVIDIA GeForce RTX 3080).
The issue is also present in Julia 1.6.1, although Julia 1.7.0-rc1 is a bit faster.
If I start julia with the option -O1, the compile times become manageable (e.g. 20 seconds for 3 residual blocks), but runtime performance will probably suffer in general.
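When comparing timings between `-O1` and the default optimization level, it helps to confirm which level a running session actually uses. This is a minimal sketch. `opt_level` is a hypothetical helper name. `Base.JLOptions()` is an internal, undocumented structure, so its fields may change between Julia versions.

```julia
# Report the optimization level of the running Julia session.
# Base.JLOptions() is internal; its `opt_level` field mirrors the -O
# command-line flag (2 when no flag is given).
opt_level() = Int(Base.JLOptions().opt_level)

println("Julia ", VERSION, " running with -O", opt_level())
```

If the script is started as `julia -O1 script.jl`, this should print `-O1`.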

Any help would be greatly appreciated.

using Flux
using Dates

channels = 64
in_channels = 3
num_channels = [3,64,64,128,128,256,256,512,512]

# discriminator

function block((in_channels,out_channels); stride = 1, use_batch_norm = true)
    layers = []
    push!(layers, Conv((3,3),in_channels => out_channels,pad = 1,stride = stride))
    if use_batch_norm
        push!(layers, BatchNorm(out_channels))
    end
    #push!(layers, x -> leakyrelu.(x,0.2f0))
    return layers
end

discriminator = Chain(
    reduce(vcat,[block(num_channels[i] => num_channels[i+1];
            stride = 1 + (i+1) % 2,
            use_batch_norm = i!=1) for i = 1:length(num_channels)-1 ])...,
    Conv((1,1), num_channels[end] => 1024),
    #x -> leakyrelu.(x,0.2f0),
    Conv((1,1), 1024 => 1),
) |> gpu

function resblock(channels)
    return SkipConnection(Chain(
            Conv((3,3),channels => channels, pad=1),
            Conv((3,3),channels => channels, pad=1),
        ), +)
end

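function upsample(in_channels, up_scale)
    return [
        # expand the channels so that PixelShuffle returns in_channels channels
        Conv((3,3),in_channels => in_channels*up_scale^2, pad=1),
        PixelShuffle(up_scale),
    ]
end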

generator = Chain(
    Conv((9,9),3 => channels, pad = 4),

    # test with different number of residual blocks
    resblock(channels),
    resblock(channels),
    #resblock(channels),
    #resblock(channels),
    #resblock(channels),
    Conv((3,3),channels => channels, pad=1),
    upsample(channels, 2)...,
    upsample(channels, 2)...,
    Conv((9,9),channels => 3,σ, pad=4),
) |> gpu;

hr_images = randn(Float32,88,88,3,32)
lr_images = randn(Float32,22,22,3,32)

hr_images = gpu(hr_images)
lr_images = gpu(lr_images)

# check forward
@show sum(discriminator(generator(lr_images)))

params_g = Flux.params(generator)

@info "generator $(now())"

# Taking gradient of generator
@time Flux.pullback(params_g) do
    sum(discriminator(generator(lr_images)))
end
Those are some downright nasty times! Can you also post the output of @time so we can see what the reported compilation time is? I’d also be curious to see the numbers for running just the forward pass.

For two residual blocks, the very first call to pullback gives this output of @time:

94.131946 seconds (62.50 M allocations: 3.273 GiB, 1.09% gc time, 11.29% compilation time)

The following call gives:

39.082354 seconds (55.50 k allocations: 4.078 MiB)

In my table above, I reported the runtime of the second calls (except for 3 residual blocks, where it did not work). I have updated the table to distinguish between the 1st and 2nd call.
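The 1st-call/2nd-call distinction can be captured with a small helper. The difference between the two timings gives a rough estimate of the one-off cost, which is mostly compilation. This is a sketch that uses only Base. `first_and_second_call` and `toy_workload` are hypothetical names, and the workload is a stand-in, not the GAN from this thread. `@nospecialize` is meant to keep the helper from compiling `f` before the first timed call.

```julia
# Time two consecutive calls of the zero-argument function `f`.
# `first_minus_second` roughly estimates one-off costs (mostly compilation).
function first_and_second_call(@nospecialize(f))
    t1 = @elapsed f()
    t2 = @elapsed f()
    return (first = t1, second = t2, first_minus_second = t1 - t2)
end

# Stand-in workload; in the thread, this would be the pullback call.
toy_workload() = sum(abs2, rand(Float32, 1000, 1000))

r = first_and_second_call(toy_workload)
println("1st call: ", r.first, " s, 2nd call: ", r.second, " s")
```

A small difference between the two calls does not by itself mean the code is fast. As the timings in this thread show, both calls can be slow.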

This is the output of the forward call @time discriminator(generator(lr_images)) for 2 residual blocks:

# first call
 21.147667 seconds (54.87 M allocations: 2.882 GiB, 4.68% gc time, 64.02% compilation time)
# second call
  0.011111 seconds (34.36 k allocations: 1.796 MiB)

Interestingly, the compile time of the forward pass does not change with the number of residual layers.

The latencies of the gradients of the discriminator and the generator individually are both reasonable. It is the compile time of the combined function discriminator(generator(lr_images)) that grows so quickly.

I let it run overnight, and here are the timings for 3 residual blocks (approximately 1 hour):

# forward pass of discriminator(generator(lr_images)) (1st call)
21.046987 seconds (54.87 M allocations: 2.882 GiB, 4.79% gc time, 63.64% compilation time)
# forward pass of discriminator(generator(lr_images)) (2nd call)
  0.011780 seconds (34.65 k allocations: 1.810 MiB)
# backward pass of discriminator(generator(lr_images)) (1st call)
3264.246108 seconds (62.58 M allocations: 3.277 GiB, 0.04% gc time, 0.33% compilation time)
# backward pass of discriminator(generator(lr_images)) (2nd call)
3240.438707 seconds (59.96 k allocations: 4.199 MiB)

Interestingly, the runtimes of the 1st and 2nd calls of the backward pass are almost identical.

I am wondering whether, in this code, Zygote is also computing the gradients of all parameters in the discriminator (even though they are not needed when updating the generator):

params_g = Flux.params(generator)
Flux.pullback(params_g) do
    sum(discriminator(generator(lr_images)))
end

The machine has 32 GB of RAM, and the RAM utilization of julia was about 5%. I did not observe any swapping to disk.

On a different machine, I tried julia 1.5.2 and got very good timings with 3 residual blocks:

julia> include("/home/ulg/gher/abarth/Julia/share/test_zygote_perf.jl")
# forward pass (1st and 2nd call)
19.787238 seconds (51.32 M allocations: 2.547 GiB, 4.63% gc time)
  0.492617 seconds (439.07 k allocations: 21.413 MiB)
# backward pass (1st and 2nd call)
 28.111849 seconds (51.57 M allocations: 2.611 GiB, 3.74% gc time)
  0.069456 seconds (50.17 k allocations: 2.577 MiB)

That would be a factor of ~50 000 (between julia 1.7.0-rc1/1.6.1 and julia 1.5.2) for the 2nd call of the backward pass!
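As a quick check of the "~50 000" figure, here is the ratio of the two 2nd-call backward timings quoted above (3 residual blocks, Julia 1.7.0-rc1 versus Julia 1.5.2). The two runs were on different machines, so this is only a rough comparison.

```julia
# 2nd-call backward-pass timings quoted above (3 residual blocks)
t_slow = 3240.438707   # seconds, Julia 1.7.0-rc1
t_fast = 0.069456      # seconds, Julia 1.5.2 (different machine)

ratio = t_slow / t_fast
println("slowdown factor ≈ ", round(Int, ratio))
```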

I’m not seeing these compile times. With 5 blocks:

generator = Chain(
    Conv((9,9),3 => channels, pad = 4),

    # test with different number of residual blocks
    resblock(channels),
    resblock(channels),
    resblock(channels),
    resblock(channels),
    resblock(channels),
    Conv((3,3),channels => channels, pad=1),
    upsample(channels, 2)...,
    upsample(channels, 2)...,
    Conv((9,9),channels => 3,σ, pad=4),
) |> gpu;

# ...

@time sum(discriminator(generator(lr_images)))
@time sum(discriminator(generator(lr_images)))

# ...

@time Flux.pullback(params_g) do
    sum(discriminator(generator(lr_images)))
end
@time Flux.pullback(params_g) do
    sum(discriminator(generator(lr_images)))
end
35.056557 seconds (63.13 M allocations: 3.313 GiB, 2.61% gc time, 31.06% compilation time)
  0.019049 seconds (48.12 k allocations: 1.993 MiB)

 83.980706 seconds (66.88 M allocations: 3.488 GiB, 1.04% gc time)
 23.423331 seconds (83.27 k allocations: 4.966 MiB)

This is on Julia 1.7-rc1 with Flux.jl from master.

I’m seeing numbers similar to maleadt’s on both 1.7.0-rc1 and 1.6.2 using Flux 0.12.6.

Thank you all for looking into this issue!
Indeed, with the master version of Flux, I get run times on Julia 1.7.0-rc1 similar to yours (with 5 residual blocks):

# forward pass
 21.221844 seconds (54.54 M allocations: 2.865 GiB, 5.28% gc time, 62.09% compilation time)
  0.011881 seconds (35.20 k allocations: 1.833 MiB)
# backward pass
116.064788 seconds (62.76 M allocations: 3.287 GiB, 0.87% gc time, 9.09% compilation time)
 53.833062 seconds (59.98 k allocations: 4.688 MiB)

With Julia 1.5.3 installed on the same hardware and the same script, I get:

# forward pass
 18.232473 seconds (54.21 M allocations: 2.702 GiB, 4.11% gc time)
  0.357820 seconds (492.87 k allocations: 24.360 MiB, 5.85% gc time)
# backward pass
 22.321700 seconds (51.42 M allocations: 2.603 GiB, 3.46% gc time)
  0.048057 seconds (52.73 k allocations: 2.716 MiB)

Calling pullback (2nd time) in Julia 1.7.0-rc1 takes 53.8 s for me, while it takes only 0.048 s on julia 1.5.3.
However, many packages are installed at different versions in julia 1.5.3, e.g. CUDA v2.4.3, Flux v0.12.1, GPUCompiler v0.8.3.
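The Julia 1.5.3 setup differs in more than the Julia version. A side-by-side list of the relevant package versions in each environment may help narrow things down. This is a sketch that assumes Pkg's `Pkg.dependencies()`, which the Pkg documentation marks as experimental. `package_versions` is a hypothetical helper, and the package list simply contains the packages mentioned in this thread.

```julia
using Pkg

# Collect the installed versions of the requested packages in the active
# environment; packages that are not installed are simply absent.
function package_versions(names)
    wanted = Set(names)
    versions = Dict{String,Union{VersionNumber,Nothing}}()
    for (_, info) in Pkg.dependencies()
        if info.name in wanted
            versions[info.name] = info.version
        end
    end
    return versions
end

pkgs = ["Flux", "Zygote", "CUDA", "GPUCompiler"]
for (name, version) in sort(collect(package_versions(pkgs)); by = first)
    println(rpad(name, 12), something(version, "(no version)"))
end
```

Running the same snippet under each Julia installation gives two lists that can be diffed, similar to comparing the Manifest files directly.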

IIRC Flux 0.12.6 and Flux master should have the same compat. Can you run a quick ] up while using the former and see if it makes a difference? Another thing to test would be comparing CPU vs GPU speeds, because compilation times may differ (if only slightly) between the two.

Try using the profiler to figure out what regressed.

It now seems that I do not get these high compile times every time. Even with Flux 0.12.6, I sometimes (but not always) get compile times of the order of minutes. I still have a julia 1.6.2 session open where I got these times (2 forward calls and 2 calls to pullback with 5 residual blocks):

$ /opt/julia-1.6.2/bin/julia /home/abarth/projects/Julia/share/test_zygote_perf5.jl
 18.040173 seconds (44.21 M allocations: 2.507 GiB, 5.14% gc time)
  0.009322 seconds (30.69 k allocations: 1.739 MiB, 73.32% compilation time)
sum(discriminator(generator(lr_images))) = 0.16050659f0
[ Info: generator 2021-09-23T21:50:35.220
2589.196584 seconds (61.49 M allocations: 3.500 GiB, 0.05% gc time, 0.00% compilation time)
2890.574050 seconds (67.31 k allocations: 5.385 MiB, 100.00% compilation time)

I thought about sharing a screenshot, but that would be silly :slight_smile:. I guess that the issue is not in Flux itself, but maybe in one of its dependencies.
With Flux.jl#master on julia 1.6.2, I get compile times of at least 17 minutes (still running), where the only change in the manifest file is the version of Flux:

$ diff ~/.julia/environments/v1.6/Manifest.toml ~/.julia/environments/v1.6/Manifest.toml.long-compile-times-julia1.6.2
< git-tree-sha1 = "ecc78fd4d97c11af01c1b73a895529a4f2292045"
< repo-rev = "master"
< repo-url = ""
> git-tree-sha1 = "1286e5dd0b4c306108747356a7a5d39a11dc4080"

I ran the profiler for the pullback call:

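using Profile

# compile
Flux.pullback(params_g) do
    sum(discriminator(generator(lr_images)))
end

# run pre-compiled
Flux.pullback(params_g) do
    sum(discriminator(generator(lr_images)))
end

# run and profile pre-compiled
@profile Flux.pullback(params_g) do
    sum(discriminator(generator(lr_images)))
end

open("profile-julia-$(VERSION)-resblock5.log","w") do f
    Profile.print(f)
end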

Most of the time is spent in pullback / _pullback and the line test_zygote_perf5.jl:131 corresponding to line sum(discriminator(generator(lr_images))) within the pullback do-block.

Any advice on how to get more detailed information would be very helpful.

[profile log: julia 1.5.3]

[profile log: julia 1.7.0-rc1]
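One way to get more detail is to profile only a warmed-up call and print a flat report that includes C frames. This can show whether the time goes into Julia code or into the runtime and compiler (for example LLVM). This is a sketch that uses the Profile stdlib. `busy_loop` is a stand-in workload, not the pullback from this thread.

```julia
using Profile

# Stand-in workload; in the thread, this would be the Flux.pullback call.
function busy_loop()
    s = 0.0
    for i in 1:30_000_000
        s += sin(i)
    end
    return s
end

busy_loop()            # warm up, so that compilation is not profiled
Profile.clear()        # drop samples from earlier runs
@profile busy_loop()   # sample only the warmed-up call

# Flat report sorted by sample count; C = true also lists C/C++ frames
# (runtime, LLVM), which are hidden in the default Julia-only report.
io = IOBuffer()
Profile.print(io; format = :flat, sortedby = :count, C = true)
report = String(take!(io))
print(report)
```

For the pullback, `busy_loop()` would be replaced by the do-block call. Writing `report` to a file then gives a log like the ones above, with C-level frames included.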

I have now also tried julia 1.7.0-rc2, and I also see very long run times (second call of pullback): 116 minutes (6995 seconds) with the default optimization option, but only 0.631694 seconds if I lower the optimization to -O1.

Each of these 3 timings are a call to pullback (for a SRGAN with 5 residual blocks).

$ CUDA_VISIBLE_DEVICES=0 ~/opt/julia-1.7.0-rc2/bin/julia /home/abarth/projects/Julia/share/test_zygote_perf5.jl
7486.239744 seconds (65.13 M allocations: 3.403 GiB, 0.02% gc time, 100.00% compilation time)
6995.621327 seconds (47.47 k allocations: 3.992 MiB, 100.00% compilation time)
7307.956745 seconds (47.47 k allocations: 3.992 MiB, 100.00% compilation time)

$ CUDA_VISIBLE_DEVICES=0 ~/opt/julia-1.7.0-rc2/bin/julia -O1 /home/abarth/projects/Julia/share/test_zygote_perf5.jl
 20.914206 seconds (65.13 M allocations: 3.403 GiB, 6.77% gc time, 99.65% compilation time)
  0.631694 seconds (52.85 k allocations: 4.082 MiB, 98.47% compilation time)
  0.632313 seconds (53.18 k allocations: 4.087 MiB, 98.45% compilation time)

I’m running your MWE on Julia 1.6.4 with optimization levels -O1/-O2 and --trace-compile. I can confirm the high compilation time for -O2 and report that

precompile(Tuple{Type{Zygote.Pullback{Tuple{Main.var"#10#11"}, T} where T}, 
[200 kB type signature cut]

seems to be the culprit. I’m wondering about the type parameter T, but I do not know enough about Flux to check for type instability.
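To locate oversized signatures like this in a `--trace-compile` log, one can rank the logged `precompile(...)` statements by length. This sketch assumes that each statement sits on its own line starting with `precompile(`. `longest_signatures` is a hypothetical helper, and the demo lines are made up only to show the output format.

```julia
# Given the lines of a --trace-compile log, return the `n` longest
# `precompile(...)` statements as (length in characters, statement) tuples.
function longest_signatures(lines; n = 3)
    stmts = filter(l -> startswith(l, "precompile("), lines)
    sorted = sort(stmts; by = length, rev = true)
    return [(length(s), s) for s in first(sorted, n)]
end

# Real usage (after `julia --trace-compile=trace.jl script.jl`):
#     longest_signatures(readlines("trace.jl"))
# Made-up demo lines:
demo = [
    "precompile(Tuple{typeof(Base.sum), Array{Float64, 1}})",
    "precompile(Tuple{typeof(Main.f), Int64})",
    "# a non-precompile line, which is ignored",
]
for (len, s) in longest_signatures(demo; n = 2)
    println(len, " chars: ", first(s, 60))
end
```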

It just occurred to me that the Julia 1.5 tests are running with CUDA 2, while the latest round are running with CUDA 3 (because it has a lower bound on Julia 1.6).

To have a better chance of isolating the culprit, I’d recommend also testing your model on the CPU with the latest 1.6/1.7 and seeing whether the pathological performance persists.
Edit: have you looked at a flamegraph (e.g. from VSCode’s profiler) for hotspots?

Hi, I tested the code from the original posting. Is this OK or do I have to adapt something?

I wouldn’t expect to see precompile show up (did you ensure the package was already precompiled?), but I don’t know enough about compilation to comment. Ideally any tracing would only start capturing from the point when pullback is called (i.e. after all packages are imported and the forward pass is warmed up), but again I’m not sure if that’s possible.

The reason I asked about CPU-only performance is that the cross-post in "Extremely high first call latency Julia 1.6 versus 1.5 with multiphysics PDE solver - #25 by Alexander-Barth" might point to the same cause: lots of time spent in the LLVM middle end during GPU compilation.

Fair point. But:

  • I didn’t see any reference to GPU/CUDA or similar.
  • I have quite a simple system with built-in Intel graphics
Julia Version 1.6.4
Commit 35f0c911f4 (2021-11-19 03:54 UTC)
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: Intel(R) Core(TM) i7-10710U CPU @ 1.10GHz
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, skylake)
Thank you for looking into the issue! The code in the original post is still fine. One variation is to progressively uncomment the lines with #resblock(channels). This makes the neural network deeper (and the type signature of the structure representing the model more complex), and the compile time increases rapidly.

I will run some tests on the CPU.
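The effect of depth on the model's type can be illustrated without Flux. The sketch below uses hypothetical toy types `Seq`, `Skip` and `Scale`. Like Flux's `Chain` and `SkipConnection`, they store their layers in type parameters. The sketch only shows the model type growing with each residual block. It does not model Zygote's `Pullback` types, which presumably also record the pullbacks of the inner calls and can therefore become much larger.

```julia
# Toy stand-ins for Flux's Chain and SkipConnection (NOT Flux itself):
# the layers are stored in a Tuple, so they become part of the type.
struct Seq{T<:Tuple}
    layers::T
end
struct Skip{L}
    layer::L
end
struct Scale
    a::Float32
end

(l::Scale)(x) = l.a .* x
(s::Skip)(x) = s.layer(x) .+ x                      # residual connection
(s::Seq)(x) = foldl((acc, layer) -> layer(acc), s.layers; init = x)

resblock() = Skip(Seq((Scale(1f0), Scale(2f0))))
model(nblocks) = Seq(ntuple(_ -> resblock(), nblocks))

# Length of the printed model type for 0 to 5 residual blocks
type_lengths = [length(string(typeof(model(n)))) for n in 0:5]
for (n, len) in zip(0:5, type_lengths)
    println(n, " blocks: ", len, " characters in typeof(model)")
end
```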

Thank you. I activated all of the commented lines to test with depth 5.

