I am using Lux.jl, and I want to do some algebra on the gradients before applying them. Currently I convert them to a `ComponentArray` and convert them back to a `NamedTuple` after doing the necessary algebra. But this requires me to move my `ps` to the CPU and back. Is there any way to manipulate the gradients without offloading them to the CPU?
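For concreteness, here is a minimal sketch of doing the algebra directly on the nested `NamedTuple` instead of flattening it. It assumes the gradients have the same nested structure as `ps`, with plain arrays as leaves. `treemap` is a hypothetical helper written for this illustration, not a Lux API (Functors.jl's `fmap` plays a similar role). Because each leaf is handled with broadcasting, the work would run on whatever device the arrays already live on.

```julia
# Hypothetical helper: apply `f` leaf-by-leaf to one or more trees that share
# the same nested NamedTuple structure (arrays are the leaves).
treemap(f, x::AbstractArray, ys::AbstractArray...) = f(x, ys...)
treemap(f, x::NamedTuple, ys::NamedTuple...) =
    map((a, bs...) -> treemap(f, a, bs...), x, ys...)

# Toy stand-ins for Lux-style parameters and gradients (nested NamedTuples).
ps = (layer_1 = (weight = ones(Float32, 2, 2), bias = zeros(Float32, 2)),
      layer_2 = (weight = ones(Float32, 1, 2), bias = zeros(Float32, 1)))
g1 = treemap(x -> fill!(similar(x), 0.5f0), ps)
g2 = treemap(x -> fill!(similar(x), 1.5f0), ps)

# Example algebra: average the two gradients and scale, without flattening.
g = treemap((a, b) -> 0.1f0 .* (a .+ b) ./ 2, g1, g2)

# Apply the combined gradient as a plain descent step.
ps_new = treemap((p, d) -> p .- d, ps, g)
```

The result keeps the original `NamedTuple` layout, so it can be passed straight back to the model without a conversion step.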
A follow-up question: how can I keep track of the original `ps` even after updates, so that I can reset the system afterwards? I currently keep a deepcopied `ComponentArray` for that, but I suspect that is not an efficient approach.
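One possible pattern, sketched under the same assumption that `ps` is a nested `NamedTuple` of arrays: snapshot each leaf array once, then reset by copying the snapshot back into the existing buffers. `treemap` is again a hypothetical helper, not part of Lux, and this only covers the parameters themselves.

```julia
# Hypothetical helper: apply `f` leaf-by-leaf to one or more trees that share
# the same nested NamedTuple structure (arrays are the leaves).
treemap(f, x::AbstractArray, ys::AbstractArray...) = f(x, ys...)
treemap(f, x::NamedTuple, ys::NamedTuple...) =
    map((a, bs...) -> treemap(f, a, bs...), x, ys...)

# Toy stand-in for Lux-style parameters.
ps = (layer_1 = (weight = ones(Float32, 2, 2), bias = zeros(Float32, 2)),)

# Snapshot: `copy` each leaf, which keeps the copy the same array type as the source.
ps_saved = treemap(copy, ps)

# Simulate an in-place update of the live parameters.
ps.layer_1.weight .-= 0.5f0

# Reset in place: copy the snapshot back into the existing buffers.
treemap(copyto!, ps, ps_saved)
```

Copying back with `copyto!` reuses the live arrays instead of allocating new ones on every reset.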