Hi,
Is there a way to make Zygote gradients type stable when optimizing over a `ComponentArray` of parameters? If not, what is a good alternative that is compatible with the SciML ecosystem?
Using a `@NamedTuple` of parameters works as expected, but that is not an option for my problem because a NamedTuple is not a SciMLStructure.
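```julia
using Lux, Zygote, ComponentArrays, StableRNGs
using BenchmarkTools    # provides @benchmark
using InteractiveUtils  # provides @code_warntype (already loaded in the REPL)

const model = Chain(
    Dense(2, 8, tanh),
    Dense(8, 1))
ps, st = Lux.setup(StableRNG(1111), model)
const _st = st
const P = ComponentArray(ps=ps)
const x = rand(StableRNG(0), Float32, 2)  # assumed example input: any 2-element Float32 vector

# Type stable: gradient w.r.t. the plain NamedTuple of parameters
@benchmark Zygote.gradient(p -> model(x, p, _st)[1][1], $ps)  # fast
@code_warntype Zygote.gradient(p -> model(x, p, _st)[1][1], ps)

# NOT type stable, for some reason: same parameters, stored in a ComponentArray
@benchmark Zygote.gradient(p -> model(x, p, _st)[1][1], $(P.ps))  # slow
@code_warntype Zygote.gradient(p -> model(x, p, _st)[1][1], P.ps)
```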
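One way to compare the two cases without reading `@code_warntype` output is `Test.@inferred`, which throws an `ErrorException` when the inferred return type of a call does not match its runtime type. A minimal sketch, assuming `x` is a 2-element `Float32` vector (the original snippet does not define it); `loss` and `gradient_inferred` are hypothetical helper names, and the sketch makes no claim about which result each call returns.

```julia
using Lux, Zygote, ComponentArrays, StableRNGs, Test

const model = Chain(Dense(2, 8, tanh), Dense(8, 1))
ps, st = Lux.setup(StableRNG(1111), model)
const _st = st
const P = ComponentArray(ps = ps)
const x = rand(StableRNG(0), Float32, 2)  # assumed input, not from the original post

# Hypothetical helper: the same scalar loss used in the benchmarks
loss(p) = model(x, p, _st)[1][1]

# Hypothetical helper: true if the return type of the gradient call is inferred exactly.
# Test.@inferred throws an ErrorException when inferred and runtime types disagree.
function gradient_inferred(p)
    try
        @inferred Zygote.gradient(loss, p)
        return true
    catch err
        err isa ErrorException || rethrow()
        return false
    end
end

@show gradient_inferred(ps)    # NamedTuple parameters
@show gradient_inferred(P.ps)  # ComponentArray parameters
```

Because `loss` is a named function over `const` globals, any inference failure it reports should come from the parameter container rather than from non-constant global access.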
Best regards.