I have implemented a custom physics model in Flux / Zygote to benefit from the gradient-based optimization these great packages offer. It contains custom layers, each shipping its own `ChainRulesCore.rrule`. A single model contains many of these otherwise simple layers, and Zygote calculates the correct gradients out of the box.
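A minimal sketch of the "custom layer plus its own `rrule`" pattern, assuming recent Zygote and ChainRulesCore releases. `scalerows` is a hypothetical stand-in, not one of the original layers. Zygote calls the hand-written pullback instead of tracing the mutating body (it rejects array mutation on its own), and a central finite difference confirms the result:

```julia
import ChainRulesCore
import Zygote

# Hypothetical layer function (not from the original post):
# multiplies rows 1 and 3 of `p` by the scalar `s`, mutating a copy.
function scalerows(p::AbstractMatrix, s::Real)
    q = copy(p)
    q[1, :] .*= s
    q[3, :] .*= s
    return q
end

# Hand-written reverse rule. Zygote uses this pullback instead of tracing the
# mutating body of `scalerows`, which it cannot differentiate through.
function ChainRulesCore.rrule(::typeof(scalerows), p::AbstractMatrix, s::Real)
    q = scalerows(p, s)
    function scalerows_pullback(Δ)
        Δ = ChainRulesCore.unthunk(Δ)
        p̄ = copy(Δ)              # tangent for p: the same row scaling applied to Δ
        p̄[1, :] .*= s
        p̄[3, :] .*= s
        s̄ = sum(Δ[1, :] .* p[1, :]) + sum(Δ[3, :] .* p[3, :])  # tangent for s
        return ChainRulesCore.NoTangent(), p̄, s̄
    end
    return q, scalerows_pullback
end

p = [1.0 2.0; 3.0 4.0; 5.0 6.0; 7.0 8.0]
loss(s) = sum(abs2, scalerows(p, s))

g = Zygote.gradient(loss, 2.0)[1]                    # goes through the rrule above
fd = (loss(2.0 + 1e-6) - loss(2.0 - 1e-6)) / 2e-6    # central finite difference
@assert isapprox(g, fd; rtol = 1e-6)
println("Zygote: $g, finite difference: $fd")
```

Here the loss is quadratic in `s`, so both values should agree with the analytic derivative 2 · s · (1² + 2² + 5² + 6²) = 264 at `s = 2`.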
Unfortunately, Zygote takes ages to calculate the first gradient: compilation can easily exceed hours and take longer than the optimization itself.
Is there any way to avoid this painfully long compilation time? And why do layers built into Flux compile so much faster?
Here is an MWE, which already takes more than 20 s to compile:
```julia
import Flux
import ChainRulesCore

# custom layer: a linear drift of length `len`
mutable struct Drift
    len::Float64
end

function driftLinear(p::AbstractVecOrMat, len::Float64)
    pnew = copy(p)
    pnew[1,:] .+= len .* p[2,:]
    pnew[3,:] .+= len .* p[4,:]
    return pnew
end

# hand-written reverse rule, so Zygote does not trace the mutating body above
function ChainRulesCore.rrule(::typeof(driftLinear), pold::AbstractVecOrMat, len::Float64)
    p = driftLinear(pold, len)
    function driftLinear_pullback(Δ)
        newΔ = copy(Δ)
        newΔ[2,:] .+= len .* Δ[1,:]
        newΔ[4,:] .+= len .* Δ[3,:]
        return ChainRulesCore.NoTangent(), newΔ, ChainRulesCore.NoTangent()
    end
    return p, driftLinear_pullback
end

function (e::Drift)(p::AbstractVecOrMat)
    driftLinear(p, e.len)
end

# create model
function track(model::Flux.Chain, batch::AbstractVecOrMat)::AbstractArray
    out = reduce(hcat, Flux.activations(model, batch))
    out = reshape(out, 7, :, length(model))
    return PermutedDimsArray(out, (1,3,2))
end

model = Flux.Chain([Drift(1.) for _ in 1:10]...)

# take gradient
inp = ones(7,100)
Flux.trainable(e::Drift) = (e.len,)
parameters = Flux.params(model)
@time grads = Flux.gradient(parameters) do
    out = track(model, inp)
    Flux.Losses.mse(out, ones(size(out)))
end
```

which prints:

```
21.495620 seconds (63.22 M allocations: 3.324 GiB, 4.14% gc time, 99.99% compilation time)
```
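One way to separate compilation from run time is to take the same gradient twice: the first call pays for compiling the pullbacks, the second reuses them. The following is a Flux-free sketch, assuming explicit-mode `Zygote.gradient`. The two-argument `loss` and its loop over `lens` are hypothetical stand-ins for the 10-layer `Chain`, and here the pullback also returns a tangent for `len`, which the MWE leaves as `NoTangent()`.

```julia
import ChainRulesCore
import Zygote

# Same drift map as in the MWE.
function driftLinear(p::AbstractVecOrMat, len::Float64)
    pnew = copy(p)
    pnew[1, :] .+= len .* p[2, :]
    pnew[3, :] .+= len .* p[4, :]
    return pnew
end

# Reverse rule; unlike the MWE it also returns a tangent for `len`
# (an assumption about what is wanted, since `len` is declared trainable there).
function ChainRulesCore.rrule(::typeof(driftLinear), pold::AbstractVecOrMat, len::Float64)
    p = driftLinear(pold, len)
    function driftLinear_pullback(Δ)
        Δ = ChainRulesCore.unthunk(Δ)
        newΔ = copy(Δ)
        newΔ[2, :] .+= len .* Δ[1, :]
        newΔ[4, :] .+= len .* Δ[3, :]
        ∂len = sum(Δ[1, :] .* pold[2, :]) + sum(Δ[3, :] .* pold[4, :])
        return ChainRulesCore.NoTangent(), newΔ, ∂len
    end
    return p, driftLinear_pullback
end

# Hypothetical stand-in for the 10-layer Chain: apply one drift per entry of `lens`,
# then take the mean squared deviation of the final output from 1.
function loss(lens::Vector{Float64}, inp::Matrix{Float64})
    p = inp
    for k in eachindex(lens)
        p = driftLinear(p, lens[k])
    end
    return sum(abs2, p .- 1) / length(p)
end

lens = ones(10)
inp = ones(7, 100)

t_first = @elapsed Zygote.gradient(loss, lens, inp)    # includes compilation
t_second = @elapsed Zygote.gradient(loss, lens, inp)   # reuses the compiled pullbacks
println("first call: $(round(t_first; digits = 3)) s, second call: $(round(t_second; digits = 3)) s")

g_lens, g_inp = Zygote.gradient(loss, lens, inp)
```

The gap between the two timings is the compile cost, which Zygote pays once per new combination of function and argument types.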
Are there any tricks to make things faster with Zygote?