[ANN] Lux.jl: Explicitly Parameterized Neural Networks in Julia

Lux is a new Julia deep learning framework that decouples models and parameterization using deeply nested named tuples.

  • Functional Layer API – Pure Functions and Deterministic Function Calls.
  • No more implicit parameterization
  • Compiler and AD-friendly Neural Networks

using Lux, Random, Optimisers, Zygote

# Seeding
rng = Random.default_rng()
Random.seed!(rng, 0)

# Construct the layer
model = Chain(
    BatchNorm(128),
    Dense(128, 256, tanh),
    BatchNorm(256),
    Chain(
        Dense(256, 1, tanh),
        Dense(1, 10)
    )
)

# Parameter and State Variables
ps, st = Lux.setup(rng, model) .|> gpu

# Dummy Input
x = rand(rng, Float32, 128, 2) |> gpu

# Run the model
y, st = Lux.apply(model, x, ps, st)

# Gradients
gs = gradient(p -> sum(Lux.apply(model, x, p, st)[1]), ps)[1]

# Optimization
st_opt = Optimisers.setup(Optimisers.ADAM(0.0001), ps)
st_opt, ps = Optimisers.update(st_opt, ps, gs)

To learn more about the design principles, see https://lux.csail.mit.edu/dev/introduction/overview/. For tutorials, see the documentation and examples.

We also provide a number of pretrained model weights for computer vision applications via another package, Boltz, which is scheduled to be released in a few days.

Why Lux when we already had Flux?

Lux is essentially Flux without any form of internal mutation and with the model decoupled from its parameters and state (it is also more type-stable). Quoting from SciMLStyle:

Mutating codes and non-mutating codes fall into different worlds. When a code is fully immutable, the compiler can better reason about dependencies, optimize the code, and check for correctness. However, many times a code making the fullest use of mutation can outperform even what the best compilers of today can generate. That said, the worst of all worlds is when code mixes mutation with non-mutating code. Not only is this a mishmash of coding styles, it has the potential non-locality and compiler-proof issues of mutating code while not fully benefiting from the mutation.
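The contrast drawn in the quote can be sketched in plain Julia, independent of either library. The example below is only an illustration under assumed names (`MutatingMean`, `PureMean`, and `apply` are hypothetical and are not Lux or Flux code): one layer updates a running mean in place, the other takes its state as an argument and returns the updated state.

```julia
# Hypothetical names throughout; this is not how Lux or Flux implement normalisation.

# Mutating style: the layer owns its running mean and updates it in place.
mutable struct MutatingMean
    mean::Vector{Float64}
    momentum::Float64
end

function (l::MutatingMean)(x::AbstractMatrix)
    batch_mean = vec(sum(x; dims=2)) ./ size(x, 2)
    # Hidden side effect: calling the layer changes the layer.
    l.mean .= (1 - l.momentum) .* l.mean .+ l.momentum .* batch_mean
    return x .- batch_mean
end

# Functional style: the layer is immutable; state goes in and comes out explicitly.
struct PureMean
    momentum::Float64
end

function apply(l::PureMean, x::AbstractMatrix, st::NamedTuple)
    batch_mean = vec(sum(x; dims=2)) ./ size(x, 2)
    new_mean = (1 - l.momentum) .* st.mean .+ l.momentum .* batch_mean
    return x .- batch_mean, (mean = new_mean,)
end

x = [1.0 3.0; 2.0 6.0]          # 2 features, batch of 2
st = (mean = zeros(2),)
y, st_new = apply(PureMean(0.5), x, st)

m = MutatingMean(zeros(2), 0.5)
y_mut = m(x)                    # same output, but `m.mean` has changed
```

In the functional version the caller decides what to do with `st_new`, and the original `st` is left untouched, which is the property the announcement refers to as deterministic function calls.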

Finally, one cherry-picked point from State of machine learning in Julia

54 Likes

This was a tremendous effort. I like the functional point of view.

May I ask whether it has any effect on, for example, speed? I went through the docs, and you write that the functional view is “nicer” to the Julia optimizer.

Another question: would Lux.jl stimulate the adoption of Diffractor? As I understand it, one problem preventing the use of Diffractor in Flux was implicit parameters.

And one more question: is it possible to write modules that are compatible with both Lux and Flux?

Thanks in advance for your answers.

1 Like

There are some improvements, but they are mostly due to fixing type-stability issues and the like, which are not Lux-specific. Once I get some time I will try to upstream those to NNlib.

The "niceness" I mentioned will come from things like immutable arrays (Issue #8, LuxDL/Lux.jl) once they are ready.

I have not tested Diffractor, so I can't say much on that topic. However, I have tested Enzyme (with its new BLAS support and so on), and we will mostly be moving to the Enzyme route once the rules system is ready.

Nothing in Flux prevents using a functional layer, so something like the following should work:

struct WrappedDense
    d::Lux.Dense
    ps
    st
end

@functor WrappedDense

trainable(wd::Wra...) = wd.ps

(wd::WrappedDense)(x) = (wd.d)(x, wd.ps, wd.st)

So it is possible to go from the Lux side to the Flux side (Lux → Flux), but the other direction, though possible, is strongly discouraged.

4 Likes

Adding on to what Avik said:

As of Flux 0.13, there is official support for training models end-to-end without ever touching implicit params. Indeed Flux and Lux share most of the core code that enables this (namely Optimisers.jl and Zygote’s “explicit parameters” mode). The only reason implicit params haven’t been dropped completely is that it would be one of the biggest compat breaks in recent history.

In addition to the WrappedDense example, you can implement the Lux AbstractExplicitLayer API for layers you control. Here’s a (not robust or comprehensive) example for Flux.Dense:

# Reuse the weights already stored in the Flux layer
initialparameters(rng::AbstractRNG, d::Dense) = (weight=d.weight, bias=d.bias)
# or draw fresh weights of the same size
initialparameters(rng::AbstractRNG, d::Dense) = (weight=Flux.glorot_uniform(rng, size(d.weight)...), bias=d.bias)

# Alternatively, overload Lux.apply
function (d::Dense)(x::AbstractVecOrMat, ps::Union{ComponentArray, NamedTuple}, ::NamedTuple)
  σ = NNlib.fast_act(d.σ, x)  # replaces tanh => tanh_fast, etc
  return σ.(ps.weight * x .+ ps.bias)
end

Since neither library enforces an abstract layer supertype, no additional dependencies should be required.

Longer term, the hope is to pull out as much common functionality as possible from both libraries. This includes functional routines for NNlib, initializations, loss functions and possibly even part of the layer interface. For example, we’ve talked about having an apply function in Flux for some time now (in large part to avoid mutable layer structs and any mutation in the forward pass), but it was never seriously attempted because of the backwards compat headache. Now that Lux has blown that door open, we can better understand the implications of such an API and how to “backport” it.

9 Likes

All sounds great.

I have compared the speed of Flux and Lux on the example from Lux’s README, first with Flux:

using Flux, Random, Optimisers, Zygote, BenchmarkTools
rng = Random.default_rng()
Random.seed!(rng, 0)
model = Chain(
    BatchNorm(128),
    Dense(128, 256, tanh),
    BatchNorm(256),
    Chain(
        Dense(256, 1, tanh),
        Dense(1, 10)
    )
)

ps = Flux.params(model)
x = rand(rng, Float32, 128, 2)

@btime gradient(() -> sum(model(x)), ps)
  387.708 μs (1141 allocations: 300.03 KiB)

and then with Lux:

using Lux, Random, Optimisers, Zygote, BenchmarkTools
rng = Random.default_rng()
Random.seed!(rng, 0)
model = Chain(
    BatchNorm(128),
    Dense(128, 256, tanh),
    BatchNorm(256),
    Chain(
        Dense(256, 1, tanh),
        Dense(1, 10)
    )
)
ps, st = Lux.setup(rng, model)
x = rand(rng, Float32, 128, 2)
julia> @btime gradient(p -> sum(Lux.apply(model, x, p, st)[1]), ps)
  1.013 ms (2645 allocations: 366.52 KiB)

So Flux is faster on this small example: roughly 0.39 ms versus 1.01 ms.

I really like how the models can be created 1:1, with no need for changes. I also wanted to try it with Diffractor. Finally, I wanted to try Enzyme v0.9.6 as

autodiff(p -> sum(Lux.apply(model, x, p, st)[1]), Active, Active(ps))

but that failed as well, with the following error:

Error message
ERROR: MethodError: no method matching parent(::LLVM.Argument)
Closest candidates are:
  parent(::LLVM.GlobalValue) at ~/.julia/packages/LLVM/gE6U9/src/core/value/constant.jl:529
  parent(::LLVM.Instruction) at ~/.julia/packages/LLVM/gE6U9/src/core/instructions.jl:45
  parent(::LLVM.BasicBlock) at ~/.julia/packages/LLVM/gE6U9/src/core/basicblock.jl:27
Stacktrace:
  [1] parent_scope(val::LLVM.Argument, depth::Int64) (repeats 2 times)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:2570
  [2] (::Enzyme.Compiler.var"#49#50"{LLVM.Argument})(io::IOBuffer)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:2585
  [3] sprint(::Function; context::Nothing, sizehint::Int64)
    @ Base ./strings/io.jl:114
  [4] sprint(::Function)
    @ Base ./strings/io.jl:108
  [5] julia_error(cstr::Cstring, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:2585
  [6] EnzymeCreateAugmentedPrimal(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnUsed::Bool, shadowReturnUsed::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{Bool}, forceAnonymousTape::Bool, width::Int64, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/Ctome/src/api.jl:147
  [7] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{var"#1#2", Tuple{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}}}}, mod::LLVM.Module, primalf::LLVM.Function, adjoint::GPUCompiler.FunctionSpec{var"#1#2", Tuple{Active{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}}}}, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, dupClosure::Bool, wrap::Bool, modifiedBetween::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:3176
  [8] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{var"#1#2", Tuple{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}}}}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, ctx::LLVM.Context, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:3991
  [9] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{var"#1#2", Tuple{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}}}})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:4397
 [10] cached_compilation(job::GPUCompiler.CompilerJob, key::UInt64, specid::UInt64)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:4435
 [11] #s512#108
    @ ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:4495 [inlined]
 [12] var"#s512#108"(F::Any, DF::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, specid::Any, ::Any, f::Any, df::Any, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
    @ Enzyme.Compiler ./none:0
 [13] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:580
 [14] thunk
    @ ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:4523 [inlined]
 [15] autodiff(f::var"#1#2", #unused#::Type{Active}, args::Active{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}})
    @ Enzyme ~/.julia/packages/Enzyme/Ctome/src/Enzyme.jl:320
 [16] top-level scope
    @ REPL[12]:1
 [17] top-level scope
    @ ~/.julia/packages/CUDA/fAEDi/src/initialization.jl:52

The performance at that array size is not unexpected. Since the test case is small, the allocation overhead is relatively higher (the cost of having no mutation in BatchNorm). But once you go to larger arrays, say x of size 128 × 1024, the results look like this:


julia> @benchmark  gradient(() -> sum(model2(x)), ps2)
BenchmarkTools.Trial: 527 samples with 1 evaluation.
 Range (min … max):  7.678 ms … 15.833 ms  ┊ GC (min … max):  0.00% … 23.92%
 Time  (median):     8.053 ms              ┊ GC (median):     0.00%
 Time  (mean ± σ):   9.475 ms ±  2.172 ms  ┊ GC (mean ± σ):  12.17% ± 14.10%

  ▁█▇▄▂  ▁                      ▁▁   ▂ ▂▂ ▁                   
  █████▇██▇█▇▆▁▁▁▁▄▁▁▁▄▁▁▁▁▁▁▆▄████▇▇███████▇▄▆▇▄▅▁▄▁▁▁▁▁▁▄▇ ▇
  7.68 ms      Histogram: log(frequency) by time     15.2 ms <

 Memory estimate: 29.88 MiB, allocs estimate: 1187.

julia> @benchmark gradient(p -> sum(Lux.apply(model, x, p, st)[1]), ps)
BenchmarkTools.Trial: 609 samples with 1 evaluation.
 Range (min … max):  6.767 ms … 15.384 ms  ┊ GC (min … max):  0.00% … 21.73%
 Time  (median):     7.175 ms              ┊ GC (median):     0.00%
 Time  (mean ± σ):   8.202 ms ±  1.891 ms  ┊ GC (mean ± σ):  10.46% ± 13.38%

  ▂██▆▅▄▄▂                        ▂ ▁ ▁  ▁ ▁▁                 
  █████████▆▇▆▇▄▁▅▁▁▄▁▄▁▁▁▁▁▁▄▁▁▅████▆█▅▇████▅▇▇█▇▄▅█▇▁▇▄▇▇▄ █
  6.77 ms      Histogram: log(frequency) by time       13 ms <

 Memory estimate: 23.96 MiB, allocs estimate: 2679.

Regarding Enzyme, I will have to test the latest version.

Note that what AD systems like Zygote and Diffractor call “explicit parameter” mode is not what Lux refers to as “explicit parameters” (nor, I think, what Chris meant in his comment above).

In Lux, “explicit parameters” means parameters that are kept separate from the model structure and that behave like a flat vector. This is great for SciML use cases that involve black-box optimization packages expecting the parameters in this form.

What is meant by implicit/explicit parameters in Zygote and Diffractor is whether you explicitly write out what you are taking the gradient with respect to:

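# explicit: the model is an argument of the differentiated function
gradient(m -> loss(m(x), y), model)

# implicit: the closure captures the model; params(model) says what to track
gradient(() -> loss(model(x), y), params(model))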

Note that model can be any deeply nested struct, not just a “vector.”

As far as speed is concerned, there should not be major differences for deep learning cases, since Flux and Lux share the same kernels, optimizers, etc. For some wrapper layers, like Chain, Lux uses generated functions to work around some pain points with Zygote and control flow. So Lux might be faster because of this, especially for compile times or time to first gradient (TTFG). It is hard to say, though, because I’ve found benchmarking Zygote in this regard to vary widely. Lux does have DDP training through MPI, though, so it’s worth keeping in mind when you choose a framework.
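To make the remark about generated functions concrete, here is a minimal sketch of the general technique in plain Julia. It is not Lux's actual code; `apply_chain` is a hypothetical name, and the layers are ordinary functions. The generator unrolls the tuple of layers into one nested call expression, so the compiled body contains no loop for the AD system to trace through.

```julia
# Hypothetical sketch of unrolling a chain at compile time; not Lux's implementation.
@generated function apply_chain(layers::Tuple{Vararg{Any,N}}, x) where {N}
    # Inside the generator, `layers` and `x` are types; only N is used.
    ex = :x
    for i in 1:N
        # Builds layers[N](...layers[2](layers[1](x))...)
        ex = :(layers[$i]($ex))
    end
    return ex
end

layers = (x -> x .+ 1, x -> 2 .* x, sum)
result = apply_chain(layers, [1.0, 2.0])   # sum(2 .* ([1.0, 2.0] .+ 1)) = 10.0
```

Because the unrolling depends only on the length and element types of the tuple, each distinct chain gets its own specialised, loop-free method.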

7 Likes

I will still be maintaining FluxMPI.jl in a way to ensure both the frameworks can leverage it (at least Flux through its destructure API). So there shouldn’t be a significant difference in that regard.

3 Likes

The major speed (and memory) difference at this size seems to be about how the variance is calculated within the normalisation layer, fixed here.

In this example, Lux’s parameters are a nested set of NamedTuples, thus the storage is many individual arrays, exactly like Flux. I think you need ComponentArray(ps) to get a flat vector which looks like the NamedTuples. Or you can use Optimisers.destructure on this nested parameter object (instead of on the Flux model) to get a flat vector; this seems to be a bit more efficient in this example.
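The difference between nested NamedTuples and a flat vector can be illustrated without any packages. The helpers below (`flatten` and `rebuild`) are hand-rolled, hypothetical stand-ins written only for this sketch; ComponentArrays.jl and Optimisers.destructure are the real tools mentioned above and handle many more cases.

```julia
# Hand-rolled illustration only; not the ComponentArrays or Optimisers implementation.

# Collect every array leaf of a nested NamedTuple into one flat vector.
flatten(a::AbstractArray) = vec(a)
flatten(ps::NamedTuple) = reduce(vcat, (flatten(p) for p in values(ps)))

# Rebuild the nested structure from a flat vector, using `template` for shapes.
# Each method returns the rebuilt value and the next unread index.
function rebuild(template::AbstractArray, flat::AbstractVector, i::Int)
    n = length(template)
    return reshape(flat[i:i+n-1], size(template)), i + n
end

function rebuild(template::NamedTuple, flat::AbstractVector, i::Int)
    vals = []
    for p in values(template)
        v, i = rebuild(p, flat, i)
        push!(vals, v)
    end
    return NamedTuple{keys(template)}(Tuple(vals)), i
end

# Parameters shaped like the output of a setup call: one NamedTuple per layer.
ps = (layer_1 = (weight = [1.0 2.0; 3.0 4.0], bias = [5.0, 6.0]),
      layer_2 = (weight = [7.0 8.0], bias = [9.0]))

flat = flatten(ps)                       # 9-element Vector{Float64}
ps2, _ = rebuild(ps, 2 .* flat, 1)       # same structure, values doubled
```

A black-box optimizer would only ever see `flat`; the model code keeps working with the named, nested form.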

6 Likes

I do not know a lot on this topic, and have a couple of questions:

  1. From this post, it seems like explicit parameters are plain better than implicit ones. Are there cases where implicit parameters might be better?
  2. Do other artificial neural network frameworks use explicit parameters? I am thinking of PyTorch, TensorFlow, Keras, etc.
  3. I am now unsure whether I should use Flux or Lux for constructing my next artificial neural network. This uncertainty is a necessary evil of the free exploration being done by packages that do the same thing differently, and it is perfectly fine as Lux is new. But what do you imagine the endgame relationship to be between Lux and Flux? Should Lux become more widely adopted due to its explicit parameters, should Flux adopt explicit parameters in some future breaking update, or will the packages coexist?

4 Likes

The home page of the Zygote documentation has some of the original motivation for implicit params. With the benefit of hindsight, I think we’ve converged on this path not being the right choice. Hence Flux is now trying to dispose of implicit parameters as soon as possible (you can use it without them already), and Lux never supported them to begin with.

It’s a little hard to classify them because “implicit params” is very much a Flux/Zygote-ism. PyTorch is similar in that gradients are attached to their parameter arrays, but for multiple reasons their system is less error-prone than the Zygote one. If you’re familiar with Tracker (Flux’s old AD), that would be a better comparison. PyTorch and TF/Keras (same library these days) also expose an “explicit” autograd API which doesn’t accumulate in-place. JAX’s API probably comes the closest to what Flux/Lux with explicit params would look like.

The dream would be to merge both layer systems into a single library. This doesn’t seem technically intractable, but I think it should wait until Flux drops implicit params entirely and Lux has had some more proving time in and outside of the SciML-verse. As for short-medium term collaboration, see my earlier comment above.

On the final question about future breaking updates, the plan is to drop implicit params as soon as possible for Flux v0.14. Many of us would like nothing more than to see them go, but there are still a couple of design bits to be worked out, and one cannot migrate 3+ years of documentation, tutorials and 3rd party code using implicit params overnight :)

8 Likes

With the functional/immutable data angle this strikes me as similar in spirit to JAX/Haiku. Am I wrong?

1 Like

That is right.

1 Like

I am revisiting Lux vs Flux, and it seems like Lux is gaining stars more slowly and generally has only one main contributor (@avikpal). Also, the code in the Lux examples is generally more verbose, due to the explicit setup and the direct use of two extra packages, one for AD to calculate gradients and one for the subsequent optimisation (as opposed to a single call to train!).

With these off-putting facts, I do not see myself using Lux, despite liking the sound of it. I just wanted to share my experience, as I feel like Lux would deserve more love than I see it getting. But when the rubber hits the road, I would also choose Flux.

What are the reasons (1) today and (2) 1-2 years from now to choose Lux vs Flux?

I’m really happy with Lux so far, but statements like the one above worry me a bit :)

(1) Entire SciML Ecosystem defaults to it.
(2) Entire SciML Ecosystem will keep defaulting to it.

4 Likes

That's a pretty nice endorsement :) I think you told me that before, sorry for the doubt.

Every six months we do a maintenance round through SciML. In the past, it usually took about 3 weeks to get most of the packages back to green on everything, including GPUs and all doc examples. The vast majority of that time generally went into fixing weird Zygote and Flux interactions. This year, it took 4 days. One of those days was spent handling Flux data loader changes (i.e. rebasing and further testing “Fix data loader in MNIST Neural ODE example” by adrhill, Pull Request #835, SciML/DiffEqFlux.jl), along with Flux changes related to the eltype problem (this is related to “Potential gradient issues with Flux chains when changing parameter type”, Issue #533, SciML/NeuralPDE.jl, and “Consistency in the type behavior of restructure”, Issue #95, FluxML/Optimisers.jl), and especially https://github.com/FluxML/Flux.jl/pull/2156, which broke a lot of our tests. So in the last update round we had quite the issue getting things up to speed.

But this time around, we just updated those tests to Lux, and we had no issues with Lux. In fact, the only doc example that didn’t get fixed is the one that we couldn’t update to Lux, because it’s the doc example that says we still support Flux (and shows how to do it), and this example just spits out NaNs (“Neural Ordinary Differential Equations with Flux produces NaNs”, Issue #859, SciML/SciMLSensitivity.jl). This isn’t a knock on Flux, though (as I’ll explain further down). It’s that the API assumptions of Flux don’t match what we generally expect from a function in terms of type promotion and genericness, and thus it needs to be special-cased.

I think the key is that Flux is fine if you’re sticking to its “normal” interface. But if you want to get arrays of parameters in order to, say, interface with C functions (i.e. optimizers that aren’t written in Julia) or do nonlinear stuff the way SciML does all of the time, destructure and restructure are a much less tested interface in Flux and have some really odd behaviors that make it hard to maintain. Lux makes it really easy to just get parameter vectors and give it a vector of parameters (via ComponentArrays), and so this interfacing is well-supported, maintained, and ergonomic. As a result, the SciML codes are much more stable using Lux than Flux for what we do.

But is it fracturing the ecosystem?

In the end they are the same kernels under the hood, so it doesn’t really matter. They both use NNlib.jl, CUDA.jl, the same JIT compiler, etc. So a good way to describe it interface-wise is that it’s Jax/Haiku vs PyTorch. However, PyTorch and Jax have a different numpy definition, different GPU kernels, different JIT compilers, different contributors and AD systems and … they are completely different.

Lux is not much different from Flux. They are two different APIs on the same background. If you use the same neural network in the two, you should get the same answer. I think focusing on the unification of underlying tooling in the ML space makes a lot more sense than focusing on the unification of the higher-level API because, in this domain, the set of choices people make is fairly limited but how it’s used varies drastically. There are dense networks, convnets, recurrent models, transformers, graph nets, etc. There’s basically a fixed list of layers that people spend forever optimizing and then use in different ways. All of that work is shared between the two projects.

The reason for the different APIs is that there are some major differences in expectation. Even the Flux change that I disagree with the most, and that absolutely wreaked havoc on our SciMLSensitivity tests (“Match layer output to weights” by mcabbott, Pull Request #2156, FluxML/Flux.jl), has a very good justification for the majority of the people that Flux is targeted at. It assumes that when someone puts a Float64 value into a Float32 network, they likely did so by accident, and that can be a big performance accident on GPUs. Flux is targeting naïve high-level standard deep learning users who most likely are running on standard GPUs, so for that use case that choice makes sense. However, in SciML we like to just treat neural networks as functions within generic code, many times using Float64 values and such, and so having an operator which acts differently from standard Julia promotion semantics is odd and can easily lead to issues that are hard to debug. SciML has different use cases: we use non-Julia optimizers all of the time, which we want to interface with via arrays, and we do linear algebra on Jacobians and Hessians of the network parameters for implicit methods, so we often want easy access to the flattened structure. Flux is an interface that targets other use cases.
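For readers unfamiliar with the “standard Julia promotion semantics” mentioned here, a small plain-Julia check may help. It uses only Base arrays, with no Flux or Lux layer involved, and the variable names are made up for the illustration: multiplying Float32 weights by a Float64 input widens the result to Float64.

```julia
# Plain Base Julia; `W`, `x32`, and `x64` are illustrative names, not library objects.
W   = rand(Float32, 3, 2)    # "network weights" in single precision
x32 = rand(Float32, 2)       # input matching the weights
x64 = rand(Float64, 2)       # input in double precision

y32 = W * x32                # stays Float32
y64 = W * x64                # standard promotion widens the result to Float64

println(eltype(y32), " ", eltype(y64))
```

Generic SciML code relies on this widening, whereas the Flux change discussed above is described as making a layer's output follow its weights instead.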

But with Lux vs Flux, this is just a difference of high-level API semantics and not re-engineering. I still talk to the same people, open MWEs on the same AD library, contribute back to the same CUDA.jl. We’re all still in the same world, sharing the internals but not sharing the same expectations on the user and thus not the same high level API. We don’t need to impose the semantics and assumptions of SciML users onto Flux in order to work together on what really matters. I think this is a fantastic way for us to be exploring the space of deep learning.

37 Likes

Thanks Chris for the great summary of your experience and reasons to use Lux, that’s helpful, and justifies our focus on Lux too.

2 Likes

Absolutely. I have known that a lot of internals are shared, but your perspective still helps a lot in understanding why both libraries exist and when to use which one. Thanks!