On the future of Flux.destructure and SciML integration

Due to some technical limitations regarding Lux and coupling layers (1, 2, 3), and the existence of InvertibleNetworks.jl, I have chosen to use Flux for my designs. Still, my long-term objective is to use Optimization.jl and other SciML-related packages, which generally expect flat parameter vectors, so I’ll have to rely on destructure and Restructure.
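Concretely, the workflow I have in mind is roughly the following (a minimal sketch only; the model, loss, and optimiser here are placeholders, not my actual problem):

using Flux, Zygote, Optimization, OptimizationOptimisers, Optimisers

model = Chain(Dense(1 => 16, tanh), Dense(16 => 1))
p0, re = Flux.destructure(model)     # flat Float32 vector + Restructure closure

# placeholder loss: rebuild the model from the flat parameters, then evaluate it
loss(p, _) = sum(abs2, re(p)([1f0]))

optf = OptimizationFunction(loss, Optimization.AutoZygote())
prob = OptimizationProblem(optf, p0)
sol = solve(prob, Optimisers.Adam(1f-3); maxiters = 100)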

The SciML documentation recommends against using Flux, citing issues with performance and type promotion (Flux.jl-vs-Lux.jl). My question: are these issues the result of a deliberate design choice, perhaps arising from some compromise, or can one hope for a “fix” sometime in the future?

In other words, are these difficulties intrinsic to Flux’s design? If so, is there something one could do to partially bypass them?

Thanks in advance,


It’s important to know that these functions had an extremely bad reputation because the original implementation was very fragile and had zero tests.

Unfortunately, replacing such a thing means the ire is mostly redirected at the new version (which is much stricter, and much better tested). Perhaps removing it entirely would have been better.

This is not really true.

The complaint about type promotion was this. Flux regards a model’s precision, Float32 (or Float64 or Float16), as part of the model, and will re-construct it as such. It has explicit functions like f16 for changing between these. Supplying Float16 parameters to Restructure will not change the model’s precision. Perhaps that’s surprising compared to Base Julia, but GPUs are weird, and accidental Float64 on a GPU was the number one cause of “Flux is much much slower than PyTorch” threads. What “well known” really means is that one person (with a megaphone) was once surprised that the new, less sloppy implementation fixes this. It remains very easy to run models in Float64 if this is actually desired.
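Concretely, the behaviour is along these lines (written from memory as a sketch, not copied from a running session):

using Flux

m = Chain(Dense(2 => 3, tanh), Dense(3 => 1))   # Float32 by default
ps, re = Flux.destructure(m)

eltype(ps)                          # Float32
eltype(re(Float64.(ps))[1].weight)  # still Float32: Restructure keeps the model's precision
eltype(Flux.f64(m)[1].weight)       # Float64: the explicit way to change precision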

Performance should be OK? When I last tried, it was fairly easy to cook up benchmarks in which either destructure or ComponentArrays.jl comes out a bit faster; they do similar things but differ slightly in when they allocate copies. Usually neither is a large slice of the model’s run time. If you have particular performance concerns, they can probably be solved; there’s no intrinsic reason for large differences.
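The kind of benchmark I mean is roughly this (untested as written, and on plain arrays rather than a full model; exact numbers will depend on sizes):

using Optimisers, ComponentArrays, BenchmarkTools

nt = (W = randn(Float32, 32, 32), b = randn(Float32, 32))

flat, re = Optimisers.destructure(nt)   # copies into a flat vector, returns a Restructure
ca = ComponentArray(nt)                 # also copies into a flat vector, with named views

@btime $re($flat)       # rebuilding the NamedTuple copies each array
@btime ($ca.W, $ca.b)   # property access returns reshaped views, no data copies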


While I can’t really comment on what would be the most natural way to deal with types, I absolutely agree that destructure is a great tool and gets the job done. It also seems to interact much more nicely with tied parameters than ComponentArrays.jl does, which is a major reason I’ll use Flux this time.

However, in regimes where de/restructuring is very frequent (mostly as a hack to use Flux with explicit parameters), the performance cost does seem quite severe, unless I’m doing something wrong. On my machine, small PINNs seem to be 3-4x slower when destructured:

using Flux, Zygote, BenchmarkTools

C = Chain(Dense(1 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 1))
ps, re = Flux.destructure(C)
v = [1f0]

@btime $C($v)                                    # forward pass, structured model
@btime $re($ps)($v)                              # forward pass, restructured on the fly
@btime Zygote.gradient(c -> c($v)[1], $C)        # gradient w.r.t. the structured model
@btime Zygote.gradient(p -> $re(p)($v)[1], $ps)  # gradient w.r.t. the flat parameter vector

This is, again, probably stretching the tool beyond what it was made for; I’m glad it exists in any case. My only question is whether there’s anything I can do to reduce the slowdown, like designing custom layers or avoiding deep nesting of parameters.

(In any case, “well known” was a poor choice of words, and I have edited the OP to reflect that.)

In this regime it is expensive. While I haven’t checked this example today, my claim above is that using ComponentArrays.jl will also be expensive. I’d say that explicit/implicit is not the right axis here; it’s more like solid vector vs. nested structure. For sufficiently small networks, converting back and forth may take longer than the matrix multiplications etc.

There is more that could be done here, e.g. Optimisers.jl could quite easily add a destructure! which re-uses the same model. That might save quite a bit here. See e.g. issue 146.
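Very roughly, I mean something along these lines (a hypothetical sketch, not anything in Optimisers.jl today; a real version would need to respect trainable/non-trainable fields and match destructure’s flattening order exactly):

using Functors

# Hypothetical destructure!-style helper: write a flat vector back into the
# existing model's arrays in place, instead of allocating a fresh model each call.
function restructure!(model, flat::AbstractVector)
    offset = 0
    fmap(model; exclude = x -> x isa AbstractArray{<:Number}) do x
        n = length(x)
        copyto!(x, reshape(view(flat, offset+1:offset+n), size(x)))
        offset += n
        x
    end
    return model
end

Since fmap caches shared nodes, tied arrays would only be written once, which is roughly consistent with how destructure treats shared parameters.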

This is the regime where SimpleChains.jl is likely to be much faster. It’s a completely different design, which never uses a nested structure at all, but it has various other limitations.

The handling of tied parameters is a key difference from ComponentArrays.jl, which perhaps we should highlight. (It is also the main source of coding headaches!) I’m glad if it’s useful to someone.

julia> using ComponentArrays, Optimisers

julia> let twice = [1.0, 2.0]
        cv = ComponentArray(x=twice, y=twice, z=[1.0, 2.0])
        cv.x[1] += 999
        cv  # this has 6 independent scalar parameters
       end
ComponentVector{Float64}(x = [1000.0, 2.0], y = [1.0, 2.0], z = [1.0, 2.0])

julia> let twice = [1.0, 2.0]
        v, re = destructure((x=twice, y=twice, z=[1.0, 2.0]))
        @show v  # only 4 indep parameters
        v[1] += 999
        re(v)
       end
v = [1.0, 2.0, 1.0, 2.0]
(x = [1000.0, 2.0], y = [1000.0, 2.0], z = [1.0, 2.0])

Fantastic, thanks. Having re make views instead of copies seems like a relatively straightforward fix; I’ll try to get it running on my end.
