[Flux / Zygote] Circumvent long compilation times

I have implemented a custom physics model in Flux / Zygote to benefit from the gradient-based optimization these great packages offer. It contains custom layers, each shipping its own ChainRulesCore.rrule. A single model contains many of these otherwise simple layers, and Zygote calculates the correct gradients out of the box.

Unfortunately, Zygote takes ages to calculate the first gradient. Compilation can easily take hours, which is longer than the optimization itself.

Is there any way to avoid this painfully long compilation time? Why is compilation of the layers built into Flux significantly faster?

Here is an MWE, which already takes more than 20s to compile:

import Flux
import ChainRulesCore

# custom layer
mutable struct Drift
    len::Float64
end

function driftLinear(p::AbstractVecOrMat, len::Float64)
    pnew = copy(p)
      
    pnew[1,:] .+= len .* p[2,:]
    pnew[3,:] .+= len .* p[4,:]

    return pnew
end

function ChainRulesCore.rrule(::typeof(driftLinear), pold::AbstractVecOrMat, len::Float64)
    p = driftLinear(pold, len)
    function driftLinear_pullback(Δ)
        newΔ = copy(Δ)

        newΔ[2,:] .+= len .* Δ[1,:]
        newΔ[4,:] .+= len .* Δ[3,:]
        return ChainRulesCore.NoTangent(), newΔ, ChainRulesCore.NoTangent()
    end
    return p, driftLinear_pullback
end

function (e::Drift)(p::AbstractVecOrMat)
    driftLinear(p, e.len)
end

# create model
function track(model::Flux.Chain, batch::AbstractVecOrMat)::AbstractArray
    out = reduce(hcat, Flux.activations(model, batch))
    out = reshape(out, 7, :, length(model))
    return PermutedDimsArray(out, (1,3,2))
end

model = Flux.Chain([Drift(1.) for _ in 1:10]...)

# take gradient
inp = ones(7,100)

Flux.trainable(e::Drift) = (e.len,)
parameters = Flux.params(model)

@time grads = Flux.gradient(parameters) do
    out = track(model, inp)
    Flux.Losses.mse(out, ones(size(out)))
end
21.495620 seconds (63.22 M allocations: 3.324 GiB, 4.14% gc time, 99.99% compilation time)

Are there any tricks to make things faster with Zygote?


I was able to cut the time-to-first-gradient (TTFG) from 24s to 16s with these 2 changes:

struct Drift # more efficient if immutable (and this doesn't need to be mutable)
    len::Float64
end
Flux.@functor Drift # use this instead of overriding Flux.trainable for 99% of cases

function track(model::Flux.Chain, batch::AbstractVecOrMat)
    acts = Flux.activations(model, batch)
    out = hcat(acts...) # faster TTFG than reduce(hcat, ...) for some reason
    out = reshape(out, 7, :, length(model))
    return PermutedDimsArray(out, (1,3,2))
end

This is with Flux 0.13, which also makes calling Chains more compiler friendly (but not activations, unfortunately). It’s not altogether clear to me why hcat(...) would be easier to compile than reduce(hcat, ...) here since the opposite often seems to be true, but that made the largest difference (24s → 17s).
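For reference, the two spellings build the same matrix, so swapping one for the other should only change what the compiler has to work through, not the result. A minimal sketch in plain Julia, where `acts` is a hypothetical stand-in for the output of Flux.activations (assumed here to be a tuple with one 7×100 matrix per layer):

```julia
# Hypothetical stand-in for Flux.activations(model, batch):
# one 7×100 matrix per layer, 10 layers.
acts = ntuple(i -> fill(Float64(i), 7, 100), 10)

a = hcat(acts...)        # splatting: a single hcat call with 10 arguments
b = reduce(hcat, acts)   # reduction over the collection

# Both give the same 7×1000 matrix.
same = (a == b)

# Same post-processing as in track: (coordinate, layer, particle).
out = PermutedDimsArray(reshape(a, 7, :, length(acts)), (1, 3, 2))
```

Whether the splatted form also compiles faster in other models is something to measure case by case, not something this sketch shows.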

On a macro scale, long compilation times have been a perennial problem with Zygote. My last post on the second issue has some ideas for how we might start addressing this, but as it stands it’s unclear where we’ll get the resources (i.e. experienced people hours) for it given the uncertainty of the problem difficulty.

A couple of lesser notes:

  1. Zygote doesn’t consider scalar numbers parameters, so the gradients you’re getting back are empty. You can confirm this yourself by looking at the contents of parameters.
  2. There is a fixed overhead for Zygote.gradient, which you can see by running a trivial function first:
julia> @time Flux.gradient(() -> sum(inp), parameters)
 11.998717 seconds (30.36 M allocations: 1.621 GiB, 5.38% gc time, 100.00% compilation time)
Grads(...)

julia> @time grads = Flux.gradient(parameters) do
           out = track(model, inp)
           Flux.Losses.mse(out, ones(size(out)))
       end
 12.592630 seconds (33.18 M allocations: 1.721 GiB, 3.39% gc time, 99.97% compilation time)
Grads(...)
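On the first note: one way to make the lengths show up as parameters would be to store each one in an array instead of a bare Float64. The rrule would then also have to return a tangent for len rather than NoTangent(). The sketch below is plain Julia without Flux or ChainRulesCore, and the names `drift` and `drift_pullback` are hypothetical. It only checks a hand-derived gradient with respect to len against a finite difference:

```julia
# Forward map, same arithmetic as driftLinear in the question.
function drift(p::AbstractMatrix, len::Real)
    pnew = copy(p)
    pnew[1, :] .+= len .* p[2, :]
    pnew[3, :] .+= len .* p[4, :]
    return pnew
end

# Pullback for a cotangent Δ: returns (gradient w.r.t. p, gradient w.r.t. len).
function drift_pullback(p::AbstractMatrix, len::Real, Δ::AbstractMatrix)
    dp = copy(Δ)
    dp[2, :] .+= len .* Δ[1, :]
    dp[4, :] .+= len .* Δ[3, :]
    # Only rows 1 and 3 of the output depend on len.
    dlen = sum(Δ[1, :] .* p[2, :]) + sum(Δ[3, :] .* p[4, :])
    return dp, dlen
end

p = reshape(collect(1.0:28.0), 7, 4) ./ 10   # arbitrary 7×4 test input
len = 0.5
loss(l) = sum(abs2, drift(p, l))             # scalar test loss

Δ = 2 .* drift(p, len)                       # d(loss)/d(output)
_, dlen = drift_pullback(p, len, Δ)

h = 1e-6
fd = (loss(len + h) - loss(len - h)) / (2h)  # central finite difference
matches = isapprox(dlen, fd; rtol = 1e-6)
```

In an actual rrule the third return value would carry `dlen` (wrapped to match however len is stored), which is an adaptation left to the model at hand.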