A custom type-stable regularization function in Lux & Enzyme

We are trying to develop a custom regularization function for the biases and weights of a simple dense neural network in Lux, but we have not found a way to get a type-stable gradient computation out of Lux and Enzyme. In Custom loss functions in `Lux.jl` it was said that calling Functors.fleaves(ps) yields the parameters of the neural network, but the resulting vector is type-unstable. We cannot use WeightDecay either, because our custom regularization function is not a simple L2 penalty. Our regularization function is written in Julia; in the toy code below we replace it with a plain abs2 for demonstration purposes. The first issue is that the code does not even run: inside the mapreduce function it indexes the ps entries through scalar indexing, which raises the error Scalar indexing is disallowed.
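For reference, this is the kind of type instability we mean (a minimal sketch; ps is the parameter NamedTuple returned by Lux.setup in the code below):

q = Functors.fleaves(ps)   # the leaves mix Matrix{Float32} (weights) and Vector{Float32} (biases)
eltype(q)                  # abstract element type, so the mapreduce over q is not type-stable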

What is the correct way to use Lux and Enzyme for custom regularization of the parameters in a type-stable and GPU-friendly way? The recent discussion in Custom loss functions in `Lux.jl` makes it sound as if custom regularization of the parameters is currently not even possible.

using Lux
using AMDGPU
using Reactant
using Random
using Functors
using Enzyme
using FiniteDifferences
using ComponentArrays

# import .EnzymeRules: reverse, augmented_primal
# using .EnzymeRules

rng = Random.default_rng()
Random.seed!(rng, 0)

model = @compact(
    w1 = Dense(12 => 2),
    w2 = Dense(2 => 1),
    w3 = Dense(1 => 1),
    act = tanh,
) do x
    embed = act.(w1(x))
    embed = act.(w2(embed))
    out = w3(embed)
    @return out
end

ps, st = Lux.setup(rng, model) 

x = rand(rng, Float32, 12, 2)
y, st1 = Lux.apply(model, x, ps, st)
y = y .+ Float32.(randn(rng, size(y)))   # noisy targets

const xdev = reactant_device()
# const xdev = cpu_device()

x_ra = x |> xdev
y_ra = y |> xdev
ps_ra = ps |> ComponentArray |> xdev
st_ra = st |> xdev


function regulari(ps)
    # collect every parameter array (weights and biases) as a flat vector of leaves
    q = Functors.fleaves(ps)
    # @show typeof(q)
    # sum of squares over each leaf; abs2 stands in for the real regularizer
    v = mapreduce(t -> sum(abs2.(t)), +, q)
    return v
end


function customloss(model, ps, st, x, y)
    # pred, _ = model(x, ps, st)
    pri = regulari(ps)
    # return MSELoss()(pred, y) + pri
    return pri
end


function enzyme_gradient(model, ps, st, x, y)
    # differentiate only with respect to ps; everything else is held constant
    return Enzyme.gradient(Enzyme.Reverse, Const(customloss), Const(model),
        ps, Const(st), Const(x), Const(y))[2]
end

function fd_gradient(model, ps, st, x, y)
    function f0(p)
        return customloss(model, p, st, x, y)
    end
    return FiniteDifferences.grad(central_fdm(3, 1), f0, ps)
end



enzyme_gradient_compiled = @compile enzyme_gradient(model, ps_ra, st_ra, x_ra, y_ra)
gr0 = enzyme_gradient_compiled(model, ps_ra, st_ra, x_ra, y_ra)
gr = gr0 |> cpu_device()
gr_fdm = fd_gradient(model, ps, st, x, y)
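
For reference, a fleaves-free variant we could imagine would reduce over the flat data of a ComponentArray instead (sketch below; regulari_flat is just an illustrative name, and getdata comes from ComponentArrays). It avoids scalar indexing and has a concrete element type, but it only works for regularizers that can operate on the flattened parameter vector, so whether this is the intended, GPU-friendly pattern is part of the question:

# Sketch: reduce over the flat parameter vector of a ComponentArray (as ps_ra above).
function regulari_flat(ps)
    return sum(abs2, ComponentArrays.getdata(ps))   # no scalar indexing, concrete eltype
end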