[ANN] Lux.jl: Explicitly Parameterized Neural Networks in Julia

Lux is a new Julia deep learning framework that decouples models and parameterization using deeply nested named tuples.

  • Functional Layer API – Pure Functions and Deterministic Function Calls
  • No more implicit parameterization
  • Compiler- and AD-friendly Neural Networks
using Lux, Random, Optimisers, Zygote

# Seeding
rng = Random.default_rng()
Random.seed!(rng, 0)

# Construct the layer
model = Chain(
    BatchNorm(128),
    Dense(128, 256, tanh),
    BatchNorm(256),
    Chain(
        Dense(256, 1, tanh),
        Dense(1, 10)
    )
)

# Parameter and State Variables
ps, st = Lux.setup(rng, model) .|> gpu

# Dummy Input
x = rand(rng, Float32, 128, 2) |> gpu

# Run the model
y, st = Lux.apply(model, x, ps, st)

# Gradients
gs = gradient(p -> sum(Lux.apply(model, x, p, st)[1]), ps)[1]

# Optimization
st_opt = Optimisers.setup(Optimisers.ADAM(0.0001), ps)
st_opt, ps = Optimisers.update(st_opt, ps, gs)
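
To put these pieces together, here is a minimal training-loop sketch built from the calls above. The quadratic loss and the `data` iterable of (x, y) batches are placeholder assumptions, not part of the announcement, and for brevity the layer state returned by apply is not threaded through.

# Hypothetical training loop assembled from the snippet above
function train(model, ps, st, data; epochs = 10)
    st_opt = Optimisers.setup(Optimisers.ADAM(0.0001), ps)
    for _ in 1:epochs, (x, y) in data
        gs = gradient(p -> sum(abs2, Lux.apply(model, x, p, st)[1] .- y), ps)[1]
        st_opt, ps = Optimisers.update(st_opt, ps, gs)
    end
    return ps, st_opt
end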

To learn more about the design principles, see https://lux.csail.mit.edu/dev/introduction/overview/. For tutorials, see the documentation and examples.

We also provide a collection of pretrained model weights for computer vision applications via another package, Boltz, scheduled to be released in a few days.

Why Lux when we already had Flux?

Lux is essentially Flux without any form of internal mutation, with the model decoupled from its parameters and state (and with better type stability). Quoting from SciMLStyle:

Mutating codes and non-mutating codes fall into different worlds. When a code is fully immutable, the compiler can better reason about dependencies, optimize the code, and check for correctness. However, many times a code making the fullest use of mutation can outperform even what the best compilers of today can generate. That said, the worst of all worlds is when code mixes mutation with non-mutating code. Not only is this a mishmash of coding styles, it has the potential non-locality and compiler-proof issues of mutating code while not fully benefiting from the mutation.

Finally, one cherry-picked point from the State of machine learning in Julia thread:

43 Likes

This was a tremendous effort. I like the functional point of view.

May I ask whether it has any effect, for example, on speed? I went through the docs and you write that the functional view is "nicer" to the Julia optimizer.

Another question: would Lux.jl stimulate the adoption of Diffractor? According to my understanding, one problem preventing the use of Diffractor in Flux was implicit parameters.

And another question: is there a chance / possibility of writing modules that would be compatible with both Lux and Flux?

Thanks in advance for your answers.

1 Like

There are some improvements, but those are mostly due to fixing some type-stability issues and the like, which are not Lux-specific; once I get some time, I will try to upstream them to NNlib.

The "nice"ss I mentioned will come from things like Immutable Arrays Β· Issue #8 Β· avik-pal/Lux.jl Β· GitHub once they are ready.

I have not tested Diffractor, so I can't say much on that topic. However, I have tested Enzyme (with its new BLAS support and so on), and we will mostly be moving down the Enzyme route once its rules system is ready.

Nothing in Flux prevents using a functional layer, so something like the following works:

# Wrapping a Lux layer so it can be used from the Flux side.
# @functor comes from Functors.jl; trainable is the Flux/Optimisers.jl hook
# that marks which fields hold trainable parameters.
struct WrappedDense
    d::Lux.Dense
    ps
    st
end

@functor WrappedDense

trainable(wd::WrappedDense) = wd.ps

# Note: the Lux call returns an (output, state) tuple
(wd::WrappedDense)(x) = (wd.d)(x, wd.ps, wd.st)

So it is possible to go from the Lux side to the Flux side, but the other direction, though possible, is strongly not recommended.
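
For completeness, a hedged usage sketch of the wrapper above (keeping the definitions as written, so the call returns Lux's (output, state) tuple):

using Lux, Random

rng = Random.default_rng()
d = Lux.Dense(4, 2)               # a Lux layer
ps, st = Lux.setup(rng, d)
wd = WrappedDense(d, ps, st)      # now usable from the Flux side

x = rand(Float32, 4, 8)
y, _ = wd(x)                      # discard the (unchanged) state returned by the Lux call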

3 Likes

Adding on to what Avik said:

As of Flux 0.13, there is official support for training models end-to-end without ever touching implicit params. Indeed, Flux and Lux share most of the core code that enables this (namely Optimisers.jl and Zygote's "explicit parameters" mode). The only reason implicit params haven't been dropped completely is that doing so would be one of the biggest compat breaks in recent history.
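
For concreteness, a minimal sketch of that explicit-params workflow in Flux 0.13 with Optimisers.jl (the model, data, and loss here are illustrative assumptions, not code from Flux's docs):

using Flux, Optimisers, Zygote

model = Chain(Dense(2, 4, tanh), Dense(4, 1))
opt_state = Optimisers.setup(Optimisers.ADAM(1e-3), model)

x, y = rand(Float32, 2, 16), rand(Float32, 1, 16)
# gradient w.r.t. the model itself -- no Flux.params anywhere
gs = gradient(m -> sum(abs2, m(x) .- y), model)[1]
opt_state, model = Optimisers.update(opt_state, model, gs)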

In addition to the WrappedDense example, you can implement the Lux AbstractExplicitLayer API for layers you control. Here’s a (not robust or comprehensive) example for Flux.Dense:

# Extend Lux's initialization API for the existing Flux layer
Lux.initialparameters(rng::AbstractRNG, d::Dense) = (weight=d.weight, bias=d.bias)
# or, re-initializing instead of reusing Flux's arrays:
Lux.initialparameters(rng::AbstractRNG, d::Dense) = (weight=Flux.glorot_uniform(rng, size(d.weight)...), bias=d.bias)

# Alternatively, define the Lux-style (x, ps, st) call
function (d::Dense)(x::AbstractVecOrMat, ps::Union{ComponentArray, NamedTuple}, st::NamedTuple)
  σ = NNlib.fast_act(d.σ, x)  # replaces tanh => tanh_fast, etc.
  return σ.(ps.weight * x .+ ps.bias), st  # Lux layers return (output, state)
end

Since neither library enforces an abstract layer supertype, no additional dependencies should be required.
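
Continuing that sketch: if Lux.initialstates is also overloaded, the adapted layer can be driven with the same (x, ps, st) calling convention. This is a hedged extension of the example above (untested, and assuming only the two kinds of overloads shown):

using Flux, Lux, Random

# Dense here means Flux.Dense, as in the overloads above (qualified to avoid
# the name clash with Lux.Dense when both packages are loaded).
Lux.initialstates(::Random.AbstractRNG, ::Flux.Dense) = NamedTuple()   # Dense carries no state

rng = Random.default_rng()
d = Flux.Dense(128, 10)
ps = Lux.initialparameters(rng, d)   # the overload defined above
st = Lux.initialstates(rng, d)
x = rand(Float32, 128, 2)
y, st = d(x, ps, st)                 # the (x, ps, st) method defined above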

Longer term, the hope is to pull out as much common functionality as possible from both libraries. This includes functional routines for NNlib, initializations, loss functions and possibly even part of the layer interface. For example, we’ve talked about having an apply function in Flux for some time now (in large part to avoid mutable layer structs and any mutation in the forward pass), but it was never seriously attempted because of the backwards compat headache. Now that Lux has blown that door open, we can better understand the implications of such an API and how to "backport" it.

9 Likes

All sounds great.

I have compared the speed of Flux and Lux on the example from Lux’s README. With Flux:

using Flux, Random, Optimisers, Zygote, BenchmarkTools
rng = Random.default_rng()
Random.seed!(rng, 0)
model = Chain(
    BatchNorm(128),
    Dense(128, 256, tanh),
    BatchNorm(256),
    Chain(
        Dense(256, 1, tanh),
        Dense(1, 10)
    )
)

ps = Flux.params(model)
x = rand(rng, Float32, 128, 2)

@btime gradient(() -> sum(model(x)), ps)
  387.708 μs (1141 allocations: 300.03 KiB)

and with Lux:

using Lux, Random, Optimisers, Zygote, BenchmarkTools
rng = Random.default_rng()
Random.seed!(rng, 0)
model = Chain(
    BatchNorm(128),
    Dense(128, 256, tanh),
    BatchNorm(256),
    Chain(
        Dense(256, 1, tanh),
        Dense(1, 10)
    )
)
ps, st = Lux.setup(rng, model)
x = rand(rng, Float32, 128, 2)
julia> @btime gradient(p -> sum(Lux.apply(model, x, p, st)[1]), ps)
  1.013 ms (2645 allocations: 366.52 KiB)

And Flux is a bit faster.

I really like how the models can be created 1:1, with no need for changes. I also wanted to try it with Diffractor. Finally, I wanted to try Enzyme v0.9.6 as

autodiff(p -> sum(Lux.apply(model, x, p, st)[1]), Active, Active(ps))

but that failed as well, with the following error:

Error message
ERROR: MethodError: no method matching parent(::LLVM.Argument)
Closest candidates are:
  parent(::LLVM.GlobalValue) at ~/.julia/packages/LLVM/gE6U9/src/core/value/constant.jl:529
  parent(::LLVM.Instruction) at ~/.julia/packages/LLVM/gE6U9/src/core/instructions.jl:45
  parent(::LLVM.BasicBlock) at ~/.julia/packages/LLVM/gE6U9/src/core/basicblock.jl:27
Stacktrace:
  [1] parent_scope(val::LLVM.Argument, depth::Int64) (repeats 2 times)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:2570
  [2] (::Enzyme.Compiler.var"#49#50"{LLVM.Argument})(io::IOBuffer)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:2585
  [3] sprint(::Function; context::Nothing, sizehint::Int64)
    @ Base ./strings/io.jl:114
  [4] sprint(::Function)
    @ Base ./strings/io.jl:108
  [5] julia_error(cstr::Cstring, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:2585
  [6] EnzymeCreateAugmentedPrimal(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnUsed::Bool, shadowReturnUsed::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{Bool}, forceAnonymousTape::Bool, width::Int64, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/Ctome/src/api.jl:147
  [7] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{var"#1#2", Tuple{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}}}}, mod::LLVM.Module, primalf::LLVM.Function, adjoint::GPUCompiler.FunctionSpec{var"#1#2", Tuple{Active{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}}}}, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, dupClosure::Bool, wrap::Bool, modifiedBetween::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:3176
  [8] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{var"#1#2", Tuple{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}}}}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, ctx::LLVM.Context, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:3991
  [9] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{var"#1#2", Tuple{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}}}})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:4397
 [10] cached_compilation(job::GPUCompiler.CompilerJob, key::UInt64, specid::UInt64)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:4435
 [11] #s512#108
    @ ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:4495 [inlined]
 [12] var"#s512#108"(F::Any, DF::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, specid::Any, ::Any, f::Any, df::Any, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
    @ Enzyme.Compiler ./none:0
 [13] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:580
 [14] thunk
    @ ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:4523 [inlined]
 [15] autodiff(f::var"#1#2", #unused#::Type{Active}, args::Active{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}})
    @ Enzyme ~/.julia/packages/Enzyme/Ctome/src/Enzyme.jl:320
 [16] top-level scope
    @ REPL[12]:1
 [17] top-level scope
    @ ~/.julia/packages/CUDA/fAEDi/src/initialization.jl:52

The performance at that array size is not unexpected. Since the test case is small, there is a higher allocation overhead (the cost of having no mutation in batchnorm). But once you go to larger arrays, say x → 128 × 1024, then:


julia> @benchmark  gradient(() -> sum(model2(x)), ps2)
BenchmarkTools.Trial: 527 samples with 1 evaluation.
 Range (min … max):  7.678 ms … 15.833 ms  ┊ GC (min … max):  0.00% … 23.92%
 Time  (median):     8.053 ms              ┊ GC (median):     0.00%
 Time  (mean ± σ):   9.475 ms ±  2.172 ms  ┊ GC (mean ± σ):  12.17% ± 14.10%

  [histogram: log(frequency) by time, 7.68 ms … 15.2 ms]

 Memory estimate: 29.88 MiB, allocs estimate: 1187.

julia> @benchmark gradient(p -> sum(Lux.apply(model, x, p, st)[1]), ps)
BenchmarkTools.Trial: 609 samples with 1 evaluation.
 Range (min … max):  6.767 ms … 15.384 ms  ┊ GC (min … max):  0.00% … 21.73%
 Time  (median):     7.175 ms              ┊ GC (median):     0.00%
 Time  (mean ± σ):   8.202 ms ±  1.891 ms  ┊ GC (mean ± σ):  10.46% ± 13.38%

  [histogram: log(frequency) by time, 6.77 ms … 13 ms]

 Memory estimate: 23.96 MiB, allocs estimate: 2679.
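
(For reference, the exact setup behind these numbers is not shown; presumably it is the two snippets from the previous post with the input enlarged, model2/ps2 being the Flux model and its params. Something along the lines of the following assumption:)

# Presumed setup (an assumption, not code from the original reply):
# model/ps/st are the Lux objects and model2/ps2 the Flux model and its params
# from the earlier snippets, with the input enlarged to a batch of 1024.
x = rand(rng, Float32, 128, 1024)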

Regarding Enzyme, I will have to test the latest version.

Note that what AD systems like Zygote and Diffractor call "explicit parameter" mode is not what Lux refers to as "explicit parameters" (nor, I think, what Chris meant in his comment above).

"Explicit parameters" in Lux means parameters that are separate from the model structure and behave like a flat vector. That is great for SciML use cases that might involve black-box optimization packages expecting the parameters in this form.

What is meant by implicit/explicit parameters in Zygote and Diffractor is whether you explicitly write out what you are taking the gradient w.r.t.

# explicit: the model is passed to gradient, which differentiates w.r.t. it
gradient(m -> loss(m(x), y), model)

# implicit: the closure captures the global model; params() declares what to differentiate
gradient(() -> loss(model(x), y), params(model))

Note that model can be any deeply nested struct, not just a "vector".
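
For example (an illustrative sketch, not from the original post; TwoHeads is a made-up container struct):

using Flux, Zygote

struct TwoHeads           # an arbitrary nested struct acting as the "model"
    backbone
    head_a
    head_b
end
Flux.@functor TwoHeads    # lets Flux/Optimisers walk the struct; Zygote itself doesn't need it

m = TwoHeads(Dense(4, 8, relu), Dense(8, 2), Dense(8, 1))
x = rand(Float32, 4, 16)

g = gradient(m_ -> sum(m_.head_a(m_.backbone(x))) + sum(m_.head_b(m_.backbone(x))), m)[1]
g.backbone.weight         # the gradient mirrors the struct's nesting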

As far as speed is concerned, there should not be major differences for deep learning cases, since Flux and Lux share the same kernels, optimizers, etc. For some wrapper layers, like Chain, Lux uses generated functions to work around some pain points with Zygote and control flow, so Lux might be faster because of this, especially for compile times or TTFG (time to first gradient). It is hard to say, though, because I’ve found benchmarking Zygote in this regard to vary widely. Lux does have DDP (distributed data-parallel) training through MPI, though, so that is worth keeping in mind when you choose a framework.
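
To illustrate the generated-function idea (a sketch of the general technique, not Lux's actual implementation): the application of a tuple of layers is unrolled into straight-line code, so the AD never sees a runtime loop over heterogeneously typed layers.

# Sketch only: unrolled chain application via a generated function.
# `layers`, `ps`, and `st` are tuples of equal length.
@generated function apply_chain(layers::Tuple, x, ps::Tuple, st::Tuple)
    N = length(layers.parameters)
    calls = [:((x, $(Symbol(:st_, i))) = layers[$i](x, ps[$i], st[$i])) for i in 1:N]
    states = Expr(:tuple, (Symbol(:st_, i) for i in 1:N)...)
    return quote
        $(calls...)
        return x, $states
    end
end

Each layer call becomes an explicit statement in the generated body, which tends to be friendlier to Zygote than a for loop or foldl over a heterogeneous tuple.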

7 Likes

I will still be maintaining FluxMPI.jl in a way that ensures both frameworks can leverage it (Flux at least through its destructure API), so there shouldn’t be a significant difference in that regard.

2 Likes

The major speed (and memory) difference at this size seems to be about how the variance is calculated within the normalisation layer, fixed here.

In this example, Lux’s parameters are a nested set of NamedTuples, thus the storage is many individual arrays, exactly like Flux. I think you need ComponentArray(ps) to get a flat vector which looks like the NamedTuples. Or you can use Optimisers.destructure on this nested parameter object (instead of on the Flux model) to get a flat vector; this seems to be a bit more efficient in this example.
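
A small sketch of the two flattening routes mentioned here (reusing model/ps/st/x from the earlier Lux snippet; ComponentArrays.jl is an extra dependency, and this is an assumption about typical usage rather than code from the reply):

using ComponentArrays, Optimisers

# 1. ComponentArray: flat underlying storage, NamedTuple-style access preserved
ps_ca = ComponentArray(ps)
ps_ca.layer_2.weight                  # still addressable by layer name
y, st_ = Lux.apply(model, x, ps_ca, st)

# 2. Optimisers.destructure on the parameter NamedTuple: flat vector plus a reconstructor
flat, re = Optimisers.destructure(ps)
ps_again = re(flat)                   # rebuild the nested NamedTuple from the flat vector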

5 Likes

I do not know a lot on this topic, and have a couple of questions:

  1. From this post, it seems like explicit parameters are plain better than implicit ones. Are there cases where implicit parameters might be better?
  2. Do other artificial neural network frameworks use explicit parameters? I am thinking of PyTorch, TensorFlow, Keras, etc.
  3. I am now unsure whether I should use Flux or Lux for constructing my next artificial neural network. This uncertainty is a necessary evil of the free exploration being done by packages that do the same thing differently, and it is perfectly fine as Lux is new. But what do you imagine the endgame relationship between Lux and Flux to be? Should Lux become more widely adopted due to its explicit parameters, should Flux adopt explicit parameters in some future breaking update, or will the packages coexist?
3 Likes

The Zygote documentation (Home · Zygote) has some of the original motivation for implicit params. With the benefit of hindsight, I think we’ve converged on this path not being the right choice. Hence Flux is now trying to dispose of implicit parameters as soon as possible (you can already use it without them), and Lux never supported them to begin with.

It’s a little hard to classify them because "implicit params" is very much a Flux/Zygote-ism. PyTorch is similar in that gradients are attached to their parameter arrays, but for multiple reasons their system is less error-prone than Zygote's. If you’re familiar with Tracker (Flux’s old AD), that would be a better comparison. PyTorch and TF/Keras (the same library these days) also expose an "explicit" autograd API which doesn’t accumulate in-place. JAX’s API probably comes the closest to what Flux/Lux with explicit params would look like.

The dream would be to merge both layer systems into a single library. This doesn’t seem technically intractable, but I think it should wait until Flux drops implicit params entirely and Lux has had some more proving time in and outside of the SciML-verse. As for short-medium term collaboration, see my earlier comment above.

On the final question about future breaking updates, the plan is to drop implicit params as soon as possible for Flux v0.14. Many of us would like nothing more than to see them go, but there are still a couple of design bits to be worked out, and one cannot migrate 3+ years of documentation, tutorials, and third-party code using implicit params overnight 🙂

5 Likes