Significant compile-time latency in Flux with a GAN

Indeed, on the CPU I also see this high compile-time latency on the first call (more than 10 minutes, still running) with Julia 1.6.3.

2 Likes

A question from someone uninitiated: does the Julia compiler somehow dispatch compilation of specializations to the GPU packages? Ah, OK: I'd imagine GPU packages use, for example, macros inside functions, whose compilation is then driven by the compiler depending on the optimization (specialization) level?

Here is a screenshot of the callgrind profile (covering only the second call to pullback) on Julia 1.6.3, which shows that a lot of time is still spent on compilation (in fact, almost all of it):

Could it be that LLVM tries to do some complex optimizations which fail due to the deeply nested types used here? Then, on the second call, the compiler again wants to optimize the code (and again takes a long time).

1 Like

I think it has more to do with a combinatorial explosion of specializations induced by deeply nested, varying tuple types (because using either -O1 or @nospecialize seems to reduce or eliminate the problem). The only question left to discuss on Extremely high first call latency Julia 1.6 versus 1.5 with multiphysics PDE solver · Issue #43206 · JuliaLang/julia · GitHub is whether this optimization works as intended or runs out of control in these situations…

Edit: I just checked and have to correct myself: it doesn't seem to be the number of specializations but their complexity that is the problem. In both cases I analyzed, the critical signature was a couple of hundred kB long, so yes, Alexander, you are probably right!
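To illustrate what I mean by signature complexity, here is a minimal, Flux-free sketch (my own toy example, not the actual model types): each level of tuple nesting doubles the concrete type, so the printed signature grows exponentially with depth, and that is the kind of type the compiler has to specialize on.

```julia
# Hypothetical illustration: each nesting level doubles the tuple type,
# so the printed signature grows exponentially with depth.
nest(x, n) = n == 0 ? x : nest((x, x), n - 1)

shallow = nest(1, 3)   # 8 leaves
deep    = nest(1, 10)  # 1024 leaves

len_shallow = length(string(typeof(shallow)))
len_deep    = length(string(typeof(deep)))

# The deep signature is over a hundred times longer than the shallow one.
println((len_shallow, len_deep))
```

A Chain of 15 resblocks, each itself a SkipConnection wrapping a Chain of Convs, produces a concrete type in the same spirit, just far larger.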

2 Likes

I managed to reduce this to the following:

using Flux

channels = 4

function resblock(channels)
    return SkipConnection(Chain(
        Conv((3, 3), channels => channels, pad=1),
        Conv((3, 3), channels => channels, pad=1),
    ), +)
end

model = Chain(
    SkipConnection(
        Chain(
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),

            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),

            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
            resblock(channels),
        ),
    +),
    AdaptiveMeanPool((1, 1))
)

display(model)
println()
@show typeof(model) 

loss(x) = sum(model(x))

lr_images = randn(Float32, 2, 2, channels, 1)
@time loss(lr_images)
@time loss(lr_images) 
    
loss_grad(x, ps) = gradient(() -> loss(x), ps)
    
ps = Flux.params(model)
@time loss_grad(lr_images, ps)
@time loss_grad(lr_images, ps)

I'm not sure about the second gradient timing, but the first one has been running for over 10 minutes with no result. Importantly, removing just one of the resblocks results in a far more tractable compilation time (115 s).

Do you mind filing an issue about this? I wonder if Very slow first-time gradient calculation · Issue #1119 · FluxML/Zygote.jl · GitHub might be related as well.

3 Likes

I took the “varying tuple types” hypothesis for a spin and tried the same model using NaiveNASflux, which uses a consciously type-unstable representation of the model to allow for mutation.

Using the same code as @ToucheSir above but converting it to a CompGraph using ONNXNaiveNASflux (as there is no convenient built-in for direct Chain => CompGraph translation):

using ONNXNaiveNASflux

mio = PipeBuffer();

# One-off definition, as ONNXNaiveNASflux currently does not support AdaptiveMeanPool
(::AdaptiveMeanPool)(pp::ONNXNaiveNASflux.AbstractProbe) = ONNXNaiveNASflux.globalpool(pp, identity, "GlobalAveragePool");

save(mio, model);

model_graph = load(mio);

model_graph(lr_images) == model(lr_images) # true

loss(x) = sum(model_graph(x));

@time loss_grad(lr_images, Flux.params(model_graph))
24.340328 seconds (49.10 M allocations: 2.556 GiB, 5.01% gc time, 99.83% compilation time)
Grads(...)

However, the Chain example also finishes on my computer, albeit quite a bit slower, so maybe I'm not comparing apples to apples here:

loss(x) = sum(model(x))

@time loss_grad(lr_images, Flux.params(model))
 98.507128 seconds (9.56 M allocations: 499.581 MiB, 0.18% gc time, 99.93% compilation time)
Grads(...)

versioninfo()
Julia Version 1.7.0-rc3
Commit 3348de4ea6 (2021-11-15 08:22 UTC)
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: Intel(R) Core(TM) i7-5820K CPU @ 3.30GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, haswell)
Environment:
  JULIA_DEPOT_PATH = E:/Programs/julia/.julia
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 6
2 Likes

Thanks a lot for reducing the code!

Good idea, I just filed it here:

2 Likes

This looks very interesting! Thank you for pointing me in this direction.

I wanted to compare the gradients. It seems that grad_CompGraph.grads has 61 elements while grad_Chain.grads has 60.
But all gradients of the 60 parameters (returned by Flux.params) are the same, so I guess the additional element should not matter.

I am wondering about this line:

(::AdaptiveMeanPool)(pp::ONNXNaiveNASflux.AbstractProbe) = ONNXNaiveNASflux.globalpool(pp, identity, "GlobalAveragePool");

It seems that this line actually implements AdaptiveMeanPool support in ONNXNaiveNASflux (or is it rather a stub)?

(I used ONNXNaiveNASflux in the past to convert an ONNX model to Flux, and it worked really great! :smiley: )

using Flux
using ONNXNaiveNASflux
using Test

mio = PipeBuffer();

# One-off definition, as ONNXNaiveNASflux currently does not support AdaptiveMeanPool
(::AdaptiveMeanPool)(pp::ONNXNaiveNASflux.AbstractProbe) = ONNXNaiveNASflux.globalpool(pp, identity, "GlobalAveragePool");

save(mio, model);

model_graph = load(mio);

model_graph(lr_images) == model(lr_images) # true                                                                                                                     

loss(x) = sum(model_graph(x));

grad_CompGraph = @time loss_grad(lr_images, Flux.params(model_graph))


loss(x) = sum(model(x))

grad_Chain = @time loss_grad(lr_images, Flux.params(model))

@show length(grad_CompGraph.grads)
@show length(grad_Chain.grads)

for i = 1:length(Flux.params(model))
    @show i
    @show @test     grad_Chain[Flux.params(model)[i]]  ≈  grad_CompGraph[Flux.params(model_graph)[i]]
end
1 Like

That’s odd. I could swear it was the same as the Chain when I tested. I’m not at the computer now, but I will check again when I am.

It exports it as a global pool, which happens to be equivalent in this case (I think; I didn't bother to check how the arguments are used), but not in general. If you need proper support, it should not be hard to add, assuming the ONNX definition is somewhat sane.
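To spell out why the substitution happens to be equivalent here, here is a small plain-Julia sketch (my own, using only Statistics, not the actual ONNXNaiveNASflux internals): AdaptiveMeanPool((1, 1)) reduces each channel's spatial extent to 1×1 by averaging, which is exactly a global average pool. For any other output size, the two would differ.

```julia
using Statistics

# WHCN-style input: width × height × channels × batch
x = randn(Float32, 4, 4, 3, 1)

# A global average pool: mean over the spatial dimensions.
global_avg = mean(x; dims=(1, 2))

# AdaptiveMeanPool((1, 1)) produces the same (1, 1, C, N) shape,
# averaging each channel's full spatial extent.
@assert size(global_avg) == (1, 1, 3, 1)
@assert global_avg[1, 1, 1, 1] ≈ mean(x[:, :, 1, 1])
```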

1 Like