Significant compile-time latency in Flux with a GAN

I am trying to implement an SRGAN (super-resolution generative adversarial network) in Flux (I am happy to share once it works… :-)).
The generator network is composed of 5 residual blocks (function resblock below).
Unfortunately, the time of the first call to Zygote.pullback increases very quickly
with the number of residual blocks, so I am not able to use the 5
residual blocks as intended. I already commented out all element-wise
broadcasting (the activation functions), but the issue is still present.

number of residual blocks   pullback time (1st call)   pullback time (2nd call)
0                           58.2 s                     10.5 s
1                           73.33 s                    21.2 s
2                           94.13 s                    39.08 s
3                           3264 s (54.4 min)          3240 s (54.0 min)

I am using Julia 1.7.0-rc1 with Flux 0.12.6 and Zygote 0.6.21 on Linux (Intel i9-10900X, NVIDIA GeForce RTX 3080).
The issue is also present in Julia 1.6.1, although Julia 1.7.0-rc1 is a bit faster.
If I start Julia with the option -O1, the compile times become manageable (e.g. 20 seconds for 3 residual blocks), but runtime performance will probably suffer in general.
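
For reference, a minimal sketch of the -O1 workaround (the script name below is a placeholder; the flag applies to the whole session):

# start Julia with a lower optimization level for the entire session
$ julia -O1 test_zygote_perf.jl

# the active level can be checked from within the session
julia> Base.JLOptions().opt_level
1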

Any help would be greatly appreciated.

using Flux
using Dates

channels = 64
in_channels = 3
num_channels = [3,64,64,128,128,256,256,512,512]

# discriminator

function block((in_channels,out_channels); stride = 1, use_batch_norm = true)
    layers = []
    push!(layers, Conv((3,3),in_channels => out_channels,pad = 1,stride = stride))
    if use_batch_norm
        push!(layers, BatchNorm(out_channels))
    end
    #push!(layers, x -> leakyrelu.(x,0.2f0))
    return layers
end

discriminator = Chain(
    reduce(vcat,[block(num_channels[i] => num_channels[i+1];
            stride = 1 + (i+1) % 2,
            use_batch_norm = i!=1) for i = 1:length(num_channels)-1 ])...,
    AdaptiveMeanPool((1,1)),
    Conv((1,1), num_channels[end] => 1024),
    #x -> leakyrelu.(x,0.2f0),
    Conv((1,1), 1024 => 1),
) |> gpu


function resblock(channels)
    return SkipConnection(Chain(
            Conv((3,3),channels => channels, pad=1),
            BatchNorm(channels),
            #Prelu(),
            Conv((3,3),channels => channels, pad=1),
            BatchNorm(channels),
        )
        , +)
end


function upsample(in_channels, up_scale)
    return [
        Conv((3,3),in_channels => in_channels*up_scale^2,pad=1),
        PixelShuffle(up_scale),
        #Prelu(),
    ]
end

generator = Chain(
    Conv((9,9),3 => channels, pad = 4),
    #Prelu(),

    SkipConnection(Chain(
        # test with different number of residual blocks
        resblock(channels),
        #resblock(channels),
        #resblock(channels),
        #resblock(channels),
        #resblock(channels),
        Conv((3,3),channels => channels, pad=1),
        BatchNorm(channels)),+),
    upsample(channels, 2)...,
    upsample(channels, 2)...,
    Conv((9,9),channels => 3,σ, pad=4),
) |> gpu;

hr_images = randn(Float32,88,88,3,32)
lr_images = randn(Float32,22,22,3,32)

hr_images = gpu(hr_images)
lr_images = gpu(lr_images)

# check forward pass
@show sum(discriminator(generator(lr_images)))

params_g = Flux.params(generator)

@info "generator $(Dates.now())"

# Taking gradient of generator
loss_g, back = @time Flux.pullback(params_g) do
    sum(discriminator(generator(lr_images)))
end

Those are some downright nasty times! Can you also post the output of @time so we can see the reported compilation time? I'd also be curious about the numbers for running just the forward pass.

For two residual blocks, the very first call to pullback gives this @time output:

94.131946 seconds (62.50 M allocations: 3.273 GiB, 1.09% gc time, 11.29% compilation time)

The second call gives:

39.082354 seconds (55.50 k allocations: 4.078 MiB)

In my table above I reported the runtime of the second call (except for 3 residual blocks, where it did not work). I have updated the table to distinguish between the 1st and 2nd calls.

This is the output of the forward call @time discriminator(generator(lr_images)) for 2 residual blocks:

# first call
 21.147667 seconds (54.87 M allocations: 2.882 GiB, 4.68% gc time, 64.02% compilation time)
# second call
  0.011111 seconds (34.36 k allocations: 1.796 MiB)

Interestingly, the compile time of the forward pass does not change with the number of residual blocks.

The latencies of the gradients of the discriminator and the generator taken individually are both reasonable. It is the compile time of the combined function discriminator(generator(lr_images)) that grows so quickly.
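
As a sketch of how that comparison can be reproduced (not taken from the original script; fake_images is a placeholder name):

fake_images = generator(lr_images)

# gradient of the generator alone
@time Flux.pullback(() -> sum(generator(lr_images)), Flux.params(generator))

# gradient of the discriminator alone
@time Flux.pullback(() -> sum(discriminator(fake_images)), Flux.params(discriminator))

# gradient of the combined function (the slow case discussed here)
@time Flux.pullback(() -> sum(discriminator(generator(lr_images))), Flux.params(generator))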

I just let it run overnight and here are the times for 3 residual blocks (the backward pass takes approximately 1 hour):

# forward pass of discriminator(generator(lr_images)) (1st call)
21.046987 seconds (54.87 M allocations: 2.882 GiB, 4.79% gc time, 63.64% compilation time)
# forward pass of discriminator(generator(lr_images)) (2nd call)
  0.011780 seconds (34.65 k allocations: 1.810 MiB)
# backward pass of discriminator(generator(lr_images)) (1st call)
3264.246108 seconds (62.58 M allocations: 3.277 GiB, 0.04% gc time, 0.33% compilation time)
# backward pass of discriminator(generator(lr_images)) (2nd call)
3240.438707 seconds (59.96 k allocations: 4.199 MiB)

Interestingly, the runtime of the backward pass is almost identical for the 1st and 2nd calls.

I am wondering whether, in this code, Zygote also computes gradients with respect to all parameters of the discriminator (even though they are not needed when updating the generator):

params_g = Flux.params(generator)
loss_g, back = @time Flux.pullback(params_g) do
    sum(discriminator(generator(lr_images)))
end
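
One way to test this (a sketch, not part of the original script) is to compare the pullback with only the generator's parameters tracked against one with both parameter sets tracked; with implicit Params, Zygote differentiates through the discriminator's forward pass in both cases because the loss depends on its output, and the Params set mainly determines which gradients are accumulated:

params_g  = Flux.params(generator)
params_gd = Flux.params(generator, discriminator)

# only generator parameters tracked
@time Flux.pullback(() -> sum(discriminator(generator(lr_images))), params_g)

# generator and discriminator parameters tracked
@time Flux.pullback(() -> sum(discriminator(generator(lr_images))), params_gd)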

The machine has 32 GB of RAM and Julia's RAM utilization was about 5%. I did not observe any swapping to disk.

On a different machine I tried Julia 1.5.2 and got very good runtimes with 3 residual blocks:

julia> include("/home/ulg/gher/abarth/Julia/share/test_zygote_perf.jl")
# forward pass (1st and 2nd call)
19.787238 seconds (51.32 M allocations: 2.547 GiB, 4.63% gc time)
  0.492617 seconds (439.07 k allocations: 21.413 MiB)
# backward pass (1st and 2nd call)
 28.111849 seconds (51.57 M allocations: 2.611 GiB, 3.74% gc time)
  0.069456 seconds (50.17 k allocations: 2.577 MiB)

That would be a factor of ~50,000 (3240 s vs. 0.069 s, between Julia 1.7.0-rc1/1.6.1 and Julia 1.5.2) for the 2nd call of the backward pass!

(Flux-0.12) pkg> st --manifest
Status `/home/users/a/b/abarth/.julia/environments/Flux-0.12/Manifest.toml`
  [621f4979] AbstractFFTs v1.0.1
  [1520ce14] AbstractTrees v0.3.4
  [79e6a3ab] Adapt v3.3.1
  [56f22d72] Artifacts v1.3.0
  [ab4f0b2a] BFloat16s v0.1.0
  [fa961155] CEnum v0.4.1
  [052768ef] CUDA v2.4.3
  [082447d4] ChainRules v0.7.70
  [d360d2e6] ChainRulesCore v0.9.45
  [944b1d66] CodecZlib v0.7.0
  [3da002f7] ColorTypes v0.11.0
  [5ae59095] Colors v0.12.8
  [bbf7d656] CommonSubexpressions v0.3.0
  [34da2185] Compat v3.37.0
  [e66e0078] CompilerSupportLibraries_jll v0.3.4+0
  [9a962f9c] DataAPI v1.9.0
  [864edb3b] DataStructures v0.18.10
  [163ba53b] DiffResults v1.0.3
  [b552c78f] DiffRules v1.3.1
  [ffbed154] DocStringExtensions v0.8.5
  [e2ba6199] ExprTools v0.1.6
  [1a297f60] FillArrays v0.11.9
  [53c48c17] FixedPointNumbers v0.8.4
  [587475ba] Flux v0.12.1
  [f6369f11] ForwardDiff v0.10.19
  [d9f16b24] Functors v0.2.5
  [0c68f7d7] GPUArrays v6.4.1
  [61eb1bfa] GPUCompiler v0.8.3
  [7869d1d1] IRTools v0.4.3
  [92d709cd] IrrationalConstants v0.1.0
  [692b3bcd] JLLWrappers v1.3.0
  [e5e0dc1b] Juno v0.8.4
  [929cbde3] LLVM v3.9.0
  [2ab3a3ac] LogExpFunctions v0.3.0
  [1914dd2f] MacroTools v0.5.8
  [e89f7d12] Media v0.5.0
  [e1d29d7a] Missings v1.0.2
  [872c559c] NNlib v0.7.19
  [77ba4419] NaNMath v0.3.5
  [05823500] OpenLibm_jll v0.7.1+0
  [efe28fd5] OpenSpecFun_jll v0.5.3+4
  [bac558e1] OrderedCollections v1.4.1
  [21216c6a] Preferences v1.2.2
  [189a3867] Reexport v1.2.2
  [ae029012] Requires v1.1.3
  [6c6a2e73] Scratch v1.1.0
  [a2af1166] SortingAlgorithms v1.0.1
  [276daf66] SpecialFunctions v1.6.2
  [90137ffa] StaticArrays v1.2.12
  [82ae8749] StatsAPI v1.0.0
  [2913bbd2] StatsBase v0.33.10
  [fa267f1f] TOML v1.0.3
  [a759f4b9] TimerOutputs v0.5.12
  [3bb67fe8] TranscodingStreams v0.9.6
  [a5390f91] ZipFile v0.9.4
  [83775a58] Zlib_jll v1.2.11+18
  [e88e6eb3] Zygote v0.6.12
  [700de1a5] ZygoteRules v0.2.1
  [2a0f44e3] Base64
  [ade2ca70] Dates
  [8bb1440f] DelimitedFiles
  [8ba89e20] Distributed
  [b77e0a4c] InteractiveUtils
  [76f85450] LibGit2
  [8f399da3] Libdl
  [37e2e46d] LinearAlgebra
  [56ddb016] Logging
  [d6f4376e] Markdown
  [a63ad114] Mmap
  [44cfe95a] Pkg
  [de0858da] Printf
  [9abbd945] Profile
  [3fa0cd96] REPL
  [9a3f8284] Random
  [ea8e919c] SHA
  [9e88b42a] Serialization
  [1a1011a3] SharedArrays
  [6462fe0b] Sockets
  [2f01184e] SparseArrays
  [10745b16] Statistics
  [8dfed614] Test
  [cf7118a7] UUIDs
  [4ec0a83e] Unicode
julia> versioninfo()
Julia Version 1.5.2
Commit 539f3ce943 (2020-09-23 23:17 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Xeon(R) Gold 6126 CPU @ 2.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-9.0.1 (ORCJIT, skylake-avx512)
Environment:
  JULIA_REVISE_POLL = 1

I’m not seeing these compile times. With 5 blocks:

generator = Chain(
    Conv((9,9),3 => channels, pad = 4),
    #Prelu(),

    SkipConnection(Chain(
        # test with different number of residual blocks
        resblock(channels),
        resblock(channels),
        resblock(channels),
        resblock(channels),
        resblock(channels),
        Conv((3,3),channels => channels, pad=1),
        BatchNorm(channels)),+),
    upsample(channels, 2)...,
    upsample(channels, 2)...,
    Conv((9,9),channels => 3,σ, pad=4),
) |> gpu;

# ...

@time sum(discriminator(generator(lr_images)))
@time sum(discriminator(generator(lr_images)))

# ...

@time Flux.pullback(params_g) do
    sum(discriminator(generator(lr_images)))
end
@time Flux.pullback(params_g) do
    sum(discriminator(generator(lr_images)))
end
35.056557 seconds (63.13 M allocations: 3.313 GiB, 2.61% gc time, 31.06% compilation time)
  0.019049 seconds (48.12 k allocations: 1.993 MiB)

 83.980706 seconds (66.88 M allocations: 3.488 GiB, 1.04% gc time)
 23.423331 seconds (83.27 k allocations: 4.966 MiB)

This is on Julia 1.7-rc1 with Flux.jl from master.

Seeing similar numbers as maleadt on both 1.7.0-rc1 and 1.6.2 using Flux 0.12.6

Thank you all for looking into this issue!
Indeed, with the master version of Flux I get run times on Julia 1.7.0-rc1 similar to yours (with 5 residual blocks):

# forward pass
 21.221844 seconds (54.54 M allocations: 2.865 GiB, 5.28% gc time, 62.09% compilation time)                                                                           
  0.011881 seconds (35.20 k allocations: 1.833 MiB)                                
# backward pass
116.064788 seconds (62.76 M allocations: 3.287 GiB, 0.87% gc time, 9.09% compilation time)                                                                                   
 53.833062 seconds (59.98 k allocations: 4.688 MiB)   

With Julia 1.5.3 installed on the same hardware and the same script, I get:

# forward pass
 18.232473 seconds (54.21 M allocations: 2.702 GiB, 4.11% gc time)
  0.357820 seconds (492.87 k allocations: 24.360 MiB, 5.85% gc time)
# backward pass
 22.321700 seconds (51.42 M allocations: 2.603 GiB, 3.46% gc time)
  0.048057 seconds (52.73 k allocations: 2.716 MiB)

Calling pullback (2nd time) on Julia 1.7.0-rc1 takes 53.8 s for me, while it takes only 0.048 s on Julia 1.5.3.
However, many packages are installed at different versions under Julia 1.5.3, e.g. CUDA v2.4.3, Flux v0.12.1, GPUCompiler v0.8.3.

IIRC Flux 0.12.6 and Flux master should have the same compat. Can you run a quick ] up while using the former and see if it makes a difference? Another thing to test would be comparing CPU vs GPU speeds, because compilation times may differ (if only slightly) between the two.
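
For the CPU-vs-GPU comparison, a rough sketch (the *_cpu names are placeholders, not from the original script):

generator_cpu     = cpu(generator)
discriminator_cpu = cpu(discriminator)
lr_images_cpu     = cpu(lr_images)
params_cpu        = Flux.params(generator_cpu)

# first call includes compilation, second call is the steady-state runtime
@time Flux.pullback(() -> sum(discriminator_cpu(generator_cpu(lr_images_cpu))), params_cpu)
@time Flux.pullback(() -> sum(discriminator_cpu(generator_cpu(lr_images_cpu))), params_cpu)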

Try using the profiler to figure out what regressed.

It now seems that I do not get these high compile times every time. Even with Flux 0.12.6 I sometimes have compile times on the order of minutes, but not always. I am still in front of a Julia 1.6.2 session where I got these times (2 forward calls and 2 calls to pullback, with 5 residual blocks):

$ /opt/julia-1.6.2/bin/julia /home/abarth/projects/Julia/share/test_zygote_perf5.jl                                                                                   
 18.040173 seconds (44.21 M allocations: 2.507 GiB, 5.14% gc time)                                                                                                    
  0.009322 seconds (30.69 k allocations: 1.739 MiB, 73.32% compilation time)                                                                                          
sum(discriminator(generator(lr_images))) = 0.16050659f0                                                                                                               
[ Info: generator 2021-09-23T21:50:35.220                                                                                                                             
2589.196584 seconds (61.49 M allocations: 3.500 GiB, 0.05% gc time, 0.00% compilation time)                                                                           
2890.574050 seconds (67.31 k allocations: 5.385 MiB, 100.00% compilation time)                                                                                        

I thought about sharing a screenshot, but that would be silly :-). I guess the issue is not in Flux itself but maybe in one of its dependencies.
With Flux.jl#master on Julia 1.6.2 I have compile times of at least 17 minutes (still running); the only change in the manifest file is the version of Flux:

$ diff ~/.julia/environments/v1.6/Manifest.toml ~/.julia/environments/v1.6/Manifest.toml.long-compile-times-julia1.6.2
544,546c544
< git-tree-sha1 = "ecc78fd4d97c11af01c1b73a895529a4f2292045"
< repo-rev = "master"
< repo-url = "https://github.com/FluxML/Flux.jl.git"
---
> git-tree-sha1 = "1286e5dd0b4c306108747356a7a5d39a11dc4080"

I ran the profiler for the pullback call:

# compile
loss_g, back = @time Flux.pullback(params_g) do
    sum(discriminator(generator(lr_images)))
end

# run pre-compiled
loss_g, back = @time Flux.pullback(params_g) do
    sum(discriminator(generator(lr_images)))
end

# run and profile pre-compiled
loss_g, back = @profile Flux.pullback(params_g) do
    sum(discriminator(generator(lr_images)))
end

open("profile-julia-$(VERSION)-resblock5.log","w") do f
    Profile.print(f);
end

Most of the time is spent in pullback / _pullback and in line test_zygote_perf5.jl:131, which corresponds to sum(discriminator(generator(lr_images))) inside the pullback do-block.

Any advice on how to get more detailed information would be very helpful.
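
A few standard Profile options can give more detail than the default tree print; here is a sketch (the buffer size and cutoff values are arbitrary choices, not recommendations from this thread):

using Profile

Profile.init(n = 10^7, delay = 0.001)   # larger sample buffer, finer sampling interval

@profile Flux.pullback(params_g) do
    sum(discriminator(generator(lr_images)))
end

open("profile-flat.log", "w") do f
    # flat listing sorted by sample count, hiding frames with very few samples
    Profile.print(f; format = :flat, sortedby = :count, mincount = 10)
end

Viewers such as ProfileView.jl or StatProfilerHTML.jl can also make the deep Zygote call tree easier to navigate than the plain-text output.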

(Attached: profiler output for julia 1.5.3 and julia 1.7.0-rc1.)