Type stability of `Lux.batched_jacobian`

How can I make a call to `Lux.batched_jacobian` on a `StatefulLuxLayer` type-stable? I am using Lux v1.12.4.

Consider this MWE:

```julia
using Lux
using Random; rng = Xoshiro(42)
using InteractiveUtils  # for @code_warntype when running as a script

input, output = 6, 2
model = Chain(Dense(input => input^2, tanh), Dense(input^2 => output));
ps, st = Lux.setup(rng, model);

n = 100
x = rand(rng, Float32, input, n);
f = StatefulLuxLayer{true}(model, ps, st)

@code_warntype f(x) # Type stable, return type `Matrix{Float32}`
const backend = AutoForwardDiff()
@code_warntype batched_jacobian(f, backend, x) # Type unstable, return type `Any`
```

Am I missing something?

Edit: typo in MWE

I think you’re missing `m = 6` in the MWE?

Fixed, thanks!

You are missing the `chunksize`. With the default `AutoForwardDiff()` the chunk size is picked at runtime from the input size, so the `ForwardDiff.Dual` type is not known to the compiler; fixing it in the backend's type makes the call infer:

```julia
julia> @code_warntype batched_jacobian(f, AutoForwardDiff(; chunksize=8), x)
MethodInstance for batched_jacobian(::StatefulLuxLayer{Static.True, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}, ::AutoForwardDiff{8, Nothing}, ::Matrix{Float32})
  from batched_jacobian(f::F, backend::AutoForwardDiff, x::AbstractArray) where F @ Lux /mnt/.julia/packages/Lux/L2VO7/src/autodiff/api.jl:121
Static Parameters
  F = StatefulLuxLayer{Static.True, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}
Arguments
  #self#::Core.Const(Lux.batched_jacobian)
  f::StatefulLuxLayer{Static.True, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}
  backend::Core.Const(AutoForwardDiff(chunksize=8))
  x::Matrix{Float32}
Body::Array{Float32, 3}
1 ─ %1 = Lux.AutoDiffInternalImpl::Core.Const(Lux.AutoDiffInternalImpl)
│   %2 = Base.getproperty(%1, :batched_jacobian)::Core.Const(Lux.AutoDiffInternalImpl.batched_jacobian)
│   %3 = (%2)(f, backend, x)::Array{Float32, 3}
└──      return %3
```
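
For a quick check without reading `@code_warntype` output, `Test.@inferred` throws unless the return type is concretely inferred. A minimal sketch reusing `f` and `x` from the MWE above:

```julia
using Test

# Errors if inference fails; otherwise returns the result as usual.
J = @inferred batched_jacobian(f, AutoForwardDiff(; chunksize=8), x)
size(J)  # should be (2, 6, 100) here: one output × input slice per batch element
```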

That’s strange. Shouldn’t this at least be documented? Or better, shouldn’t `AutoForwardDiff` get a default `chunksize` when used in this context?
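
In the meantime, a possible workaround is to compute the chunk size yourself once at setup time and bake it into the backend's type. A sketch, assuming `ForwardDiff.pickchunksize` keeps its current heuristic of choosing the chunk size from the input length:

```julia
using ForwardDiff

# Constructing the backend is itself dynamic (the chunk size is a runtime
# value), but the result has the concrete type AutoForwardDiff{6, Nothing},
# so later batched_jacobian calls specialize on it and infer cleanly.
const stable_backend = AutoForwardDiff(; chunksize=ForwardDiff.pickchunksize(input))

batched_jacobian(f, stable_backend, x)  # inferred as Array{Float32, 3}
```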