I am using Lux.jl and want to access the flattened parameter vector, which I get with ComponentArrays.jl. If I have a flattened parameter vector `m0`, how do I get `ps` back from it?
```julia
using ComponentArrays

# flattened vector from the ps (ps comes from Lux.setup)
ca = ComponentArray(ps)
m0 = getdata(ca)   # the underlying plain Vector
```
In this case, how do I get `ps` back from `m0`, given that I still have `ca`?
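Since `ca` is still available, a minimal sketch is to reuse its axes, which hold all the layout information (names, nesting, shapes), and wrap the flat vector again with `ComponentArray(m0, getaxes(ca))`. Lux layers accept a `ComponentArray` as the parameter container, so a plain `NamedTuple` is usually not needed. The model `Chain(Dense(2 => 3, tanh), Dense(3 => 1))`, the input `x`, and the perturbed vector `m1` are hypothetical stand-ins for illustration only.

```julia
using Lux, ComponentArrays, Random

rng = Random.default_rng()

# Hypothetical example model; any Lux model should work the same way
model = Chain(Dense(2 => 3, tanh), Dense(3 => 1))
ps, st = Lux.setup(rng, model)

# Flatten, as in the question
ca = ComponentArray(ps)
m0 = getdata(ca)                # plain Vector{Float32}

# Rebuild: wrap the flat vector with the axes (layout) stored in `ca`
ps_new = ComponentArray(m0, getaxes(ca))

# Named access works again, e.g. the first layer's weight matrix
W1 = ps_new.layer_1.weight      # 3×2

# Lux layers accept a ComponentArray as the parameter container
x = rand(rng, Float32, 2, 5)    # hypothetical input batch
y, _ = model(x, ps_new, st)

# Alternative: overwrite an existing ComponentArray in place
m1 = m0 .+ 0.1f0                # hypothetical updated flat vector
ca_updated = copy(ca)
copyto!(ca_updated, m1)
```

As far as I know, the rebuilding constructor wraps `m0` rather than copying it, so later changes to `m0` will show up in `ps_new`; pass `copy(m0)` if you need them to be independent.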