[ANN] Lux.jl: Explicitly Parameterized Neural Networks in Julia

All sounds great.

I have compared the speed of Flux and Lux on the example from Lux’s readme as

using Flux, Random, Optimisers, Zygote, BenchmarkTools
rng = Random.default_rng()
Random.seed!(rng, 0)
model = Chain(
    BatchNorm(128),
    Dense(128, 256, tanh),
    BatchNorm(256),
    Chain(
        Dense(256, 1, tanh),
        Dense(1, 10)
    )
)

ps = Flux.params(model)
x = rand(rng, Float32, 128, 2)

@btime gradient(() -> sum(model(x)), ps)
  387.708 μs (1141 allocations: 300.03 KiB)

and Lux

using Lux, Random, Optimisers, Zygote, BenchmarkTools
rng = Random.default_rng()
Random.seed!(rng, 0)
model = Chain(
    BatchNorm(128),
    Dense(128, 256, tanh),
    BatchNorm(256),
    Chain(
        Dense(256, 1, tanh),
        Dense(1, 10)
    )
)
ps, st = Lux.setup(rng, model)
x = rand(rng, Float32, 128, 2)
@btime gradient(p -> sum(Lux.apply(model, x, p, st)[1]), ps)
  1.013 ms (2645 allocations: 366.52 KiB)

And Flux is noticeably faster here — roughly 388 μs vs. 1.01 ms, about 2.6×, with fewer allocations as well.

I really like how the models can be created 1:1 — no need for changes. I also wanted to try it with Diffractor. Finally, I wanted to try Enzyme v0.9.6 as

autodiff(p -> sum(Lux.apply(model, x, p, st)[1]), Active, Active(ps))

but that failed as well, with the following error:

Error message
ERROR: MethodError: no method matching parent(::LLVM.Argument)
Closest candidates are:
  parent(::LLVM.GlobalValue) at ~/.julia/packages/LLVM/gE6U9/src/core/value/constant.jl:529
  parent(::LLVM.Instruction) at ~/.julia/packages/LLVM/gE6U9/src/core/instructions.jl:45
  parent(::LLVM.BasicBlock) at ~/.julia/packages/LLVM/gE6U9/src/core/basicblock.jl:27
Stacktrace:
  [1] parent_scope(val::LLVM.Argument, depth::Int64) (repeats 2 times)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:2570
  [2] (::Enzyme.Compiler.var"#49#50"{LLVM.Argument})(io::IOBuffer)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:2585
  [3] sprint(::Function; context::Nothing, sizehint::Int64)
    @ Base ./strings/io.jl:114
  [4] sprint(::Function)
    @ Base ./strings/io.jl:108
  [5] julia_error(cstr::Cstring, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:2585
  [6] EnzymeCreateAugmentedPrimal(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnUsed::Bool, shadowReturnUsed::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{Bool}, forceAnonymousTape::Bool, width::Int64, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/Ctome/src/api.jl:147
  [7] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{var"#1#2", Tuple{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}}}}, mod::LLVM.Module, primalf::LLVM.Function, adjoint::GPUCompiler.FunctionSpec{var"#1#2", Tuple{Active{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}}}}, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, dupClosure::Bool, wrap::Bool, modifiedBetween::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:3176
  [8] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{var"#1#2", Tuple{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}}}}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, ctx::LLVM.Context, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:3991
  [9] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{var"#1#2", Tuple{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}}}})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:4397
 [10] cached_compilation(job::GPUCompiler.CompilerJob, key::UInt64, specid::UInt64)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:4435
 [11] #s512#108
    @ ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:4495 [inlined]
 [12] var"#s512#108"(F::Any, DF::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, specid::Any, ::Any, f::Any, df::Any, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
    @ Enzyme.Compiler ./none:0
 [13] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:580
 [14] thunk
    @ ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:4523 [inlined]
 [15] autodiff(f::var"#1#2", #unused#::Type{Active}, args::Active{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}})
    @ Enzyme ~/.julia/packages/Enzyme/Ctome/src/Enzyme.jl:320
 [16] top-level scope
    @ REPL[12]:1
 [17] top-level scope
    @ ~/.julia/packages/CUDA/fAEDi/src/initialization.jl:52