I have implemented a custom physics model in Flux / Zygote to benefit from the gradient-based optimization these great packages provide. It contains custom layers, each shipping its own ChainRulesCore.rrule.
A single model contains many of these otherwise simple layers, and Zygote calculates the correct gradients out of the box.
Unfortunately, Zygote takes ages to calculate the first gradient. Compilation can easily exceed hours and take longer than the optimization itself.
Is there any way to avoid this painfully long compilation time? Why is the compilation of layers built into Flux significantly faster?
Here is an MWE, which already takes more than 20 s to compile:
import Flux
import ChainRulesCore
# custom layer
# Drift layer: stores the drift length `len` that is handed to `driftLinear`
# whenever the layer is called on a particle array.
mutable struct Drift
# Drift length; used as the factor `len` in `driftLinear`.
len::Float64
end
# Apply a linear drift of length `len` to `p` and return the result as a new array.
#
# Row 1 is advanced by `len` times row 2, and row 3 by `len` times row 4; every
# other row is returned unchanged. The input `p` itself is not modified.
function driftLinear(p::AbstractVecOrMat, len::Float64)
    shifted = copy(p)
    # (target row, source row) pairs; the sources are read from the untouched input.
    for (target, source) in ((1, 2), (3, 4))
        shifted[target, :] .+= len .* p[source, :]
    end
    return shifted
end
# Reverse-mode rule for `driftLinear`.
#
# Forward map (per column): y1 = x1 + len*x2, y3 = x3 + len*x4, all other rows
# are passed through unchanged.
#
# Returns the primal output together with a pullback that maps the output
# cotangent `Δ` to the tuple (∂self, ∂pold, ∂len):
#   * ∂pold — `Δ` with row 2 increased by `len` times row 1 of `Δ`, and row 4
#             increased by `len` times row 3 of `Δ` (transpose of the forward map);
#   * ∂len  — sum over columns of Δ[1,j]*pold[2,j] + Δ[3,j]*pold[4,j].
# `len` is a differentiable real input, so it must not be reported as `NoTangent()`.
function ChainRulesCore.rrule(::typeof(driftLinear), pold::AbstractVecOrMat, len::Float64)
    p = driftLinear(pold, len)
    function driftLinear_pullback(Δraw)
        # The incoming cotangent may be a thunk; materialize it before indexing.
        Δ = ChainRulesCore.unthunk(Δraw)
        newΔ = copy(Δ)
        newΔ[2, :] .+= len .* Δ[1, :]
        newΔ[4, :] .+= len .* Δ[3, :]
        # d(y1)/d(len) = x2 and d(y3)/d(len) = x4; contract with the matching rows of Δ.
        ∂len = @views sum(Δ[1, :] .* pold[2, :]) + sum(Δ[3, :] .* pold[4, :])
        return ChainRulesCore.NoTangent(), newΔ, ∂len
    end
    return p, driftLinear_pullback
end
# Calling a `Drift` layer on `p` forwards to `driftLinear` with the layer's stored `len`.
(e::Drift)(p::AbstractVecOrMat) = driftLinear(p, e.len)
# create model
# Run `batch` through `model` and collect the output of every layer.
#
# The per-layer activations from `Flux.activations` are concatenated horizontally,
# reshaped to (rows, :, length(model)) and returned as a lazily permuted array with
# dimensions (rows, length(model), :), where `rows` is the first dimension of the
# concatenated activations.
#
# The row count is taken from the data rather than hard-coded, so inputs whose
# first dimension is not 7 are handled as well.
function track(model::Flux.Chain, batch::AbstractVecOrMat)::AbstractArray
    stacked = reduce(hcat, Flux.activations(model, batch))
    # `hcat` preserves the first dimension, so it equals the per-layer row count.
    perlayer = reshape(stacked, size(stacked, 1), :, length(model))
    return PermutedDimsArray(perlayer, (1, 3, 2))
end
# Build a chain of 10 Drift layers, each constructed with len = 1.0.
model = Flux.Chain([Drift(1.) for _ in 1:10]...)
# take gradient
# Input batch: a 7×100 matrix of ones.
inp = ones(7,100)
# Declare `len` as the trainable field of a Drift layer.
# NOTE(review): `len` is a scalar Float64 — confirm that Flux.params actually
# collects it (it may only track arrays); otherwise `parameters` is empty.
Flux.trainable(e::Drift) = (e.len,)
parameters = Flux.params(model)
# Time the gradient evaluation; the reported output shows it is dominated by compilation.
@time grads = Flux.gradient(parameters) do
out = track(model, inp)
# Loss: `Flux.Losses.mse` of the tracked output against an all-ones array of the same size.
Flux.Losses.mse(out, ones(size(out)))
end
21.495620 seconds (63.22 M allocations: 3.324 GiB, 4.14% gc time, 99.99% compilation time)
Are there any tricks to make things faster with Zygote?