All sounds great.
I have compared the speed of Flux and Lux on the example from Lux’s readme, as follows:
using Flux, Random, Optimisers, Zygote, BenchmarkTools
rng = Random.default_rng()
Random.seed!(rng, 0)
model = Chain(
BatchNorm(128),
Dense(128, 256, tanh),
BatchNorm(256),
Chain(
Dense(256, 1, tanh),
Dense(1, 10)
)
)
ps = Flux.params(model)
x = rand(rng, Float32, 128, 2)
@btime gradient(() -> sum(model(x)), ps)
387.708 μs (1141 allocations: 300.03 KiB)
and Lux
using Lux, Random, Optimisers, Zygote, BenchmarkTools
rng = Random.default_rng()
Random.seed!(rng, 0)
model = Chain(
BatchNorm(128),
Dense(128, 256, tanh),
BatchNorm(256),
Chain(
Dense(256, 1, tanh),
Dense(1, 10)
)
)
ps, st = Lux.setup(rng, model)
x = rand(rng, Float32, 128, 2)
julia> @btime gradient(p -> sum(Lux.apply(model, x, p, st)[1]), ps)
1.013 ms (2645 allocations: 366.52 KiB)
And Flux is a bit faster.
I really like how the models can be created 1:1, with no need for changes. I also wanted to try it with Diffractor. Finally, I wanted to try Enzyme v0.9.6 as
autodiff(p -> sum(Lux.apply(model, x, p, st)[1]), Active, Active(ps))
but that failed as well, with the following error:
Error message
ERROR: MethodError: no method matching parent(::LLVM.Argument)
Closest candidates are:
parent(::LLVM.GlobalValue) at ~/.julia/packages/LLVM/gE6U9/src/core/value/constant.jl:529
parent(::LLVM.Instruction) at ~/.julia/packages/LLVM/gE6U9/src/core/instructions.jl:45
parent(::LLVM.BasicBlock) at ~/.julia/packages/LLVM/gE6U9/src/core/basicblock.jl:27
Stacktrace:
[1] parent_scope(val::LLVM.Argument, depth::Int64) (repeats 2 times)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:2570
[2] (::Enzyme.Compiler.var"#49#50"{LLVM.Argument})(io::IOBuffer)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:2585
[3] sprint(::Function; context::Nothing, sizehint::Int64)
@ Base ./strings/io.jl:114
[4] sprint(::Function)
@ Base ./strings/io.jl:108
[5] julia_error(cstr::Cstring, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:2585
[6] EnzymeCreateAugmentedPrimal(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnUsed::Bool, shadowReturnUsed::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{Bool}, forceAnonymousTape::Bool, width::Int64, atomicAdd::Bool)
@ Enzyme.API ~/.julia/packages/Enzyme/Ctome/src/api.jl:147
[7] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{var"#1#2", Tuple{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}}}}, mod::LLVM.Module, primalf::LLVM.Function, adjoint::GPUCompiler.FunctionSpec{var"#1#2", Tuple{Active{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}}}}, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, dupClosure::Bool, wrap::Bool, modifiedBetween::Bool)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:3176
[8] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{var"#1#2", Tuple{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}}}}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, ctx::LLVM.Context, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:3991
[9] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{var"#1#2", Tuple{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}}}})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:4397
[10] cached_compilation(job::GPUCompiler.CompilerJob, key::UInt64, specid::UInt64)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:4435
[11] #s512#108
@ ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:4495 [inlined]
[12] var"#s512#108"(F::Any, DF::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, specid::Any, ::Any, f::Any, df::Any, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
@ Enzyme.Compiler ./none:0
[13] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
@ Core ./boot.jl:580
[14] thunk
@ ~/.julia/packages/Enzyme/Ctome/src/compiler.jl:4523 [inlined]
[15] autodiff(f::var"#1#2", #unused#::Type{Active}, args::Active{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4, :layer_5), Tuple{NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:scale, :bias), Tuple{Vector{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}})
@ Enzyme ~/.julia/packages/Enzyme/Ctome/src/Enzyme.jl:320
[16] top-level scope
@ REPL[12]:1
[17] top-level scope
@ ~/.julia/packages/CUDA/fAEDi/src/initialization.jl:52