How to manipulate gradients in Lux.jl

I am using Lux.jl, and I want to do some algebra on the gradients before applying them. Currently I convert them to a ComponentArray, do the necessary algebra, and convert them back to a NamedTuple. But this requires moving my ps to the CPU and back. Is there a way to manipulate the gradients directly, without offloading them to the CPU?
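Here is a minimal sketch of one way to avoid the round trip. It assumes the gradients are the nested NamedTuples that Zygote returns for Lux parameters and that every leaf is an array. `fmap` from Functors.jl walks one or more trees of the same structure and applies a function to each leaf, so broadcasted algebra runs on the existing arrays. Nothing is flattened or converted, so GPU leaves should stay on the GPU. The small CPU arrays, the layer names, and the helpers `scale_grads`, `combine_grads`, `grad_norm` and `clip_by_global_norm` are hypothetical, for illustration only.

```julia
using Functors  # fmap walks nested NamedTuples down to the leaf arrays

# Hypothetical gradients in the NamedTuple shape Lux + Zygote return for a
# two-layer Chain. On a GPU the leaves would be device arrays (e.g. CuArray)
# and every broadcast below would run on the device.
gs1 = (layer_1 = (weight = ones(Float32, 4, 2), bias = ones(Float32, 4)),
       layer_2 = (weight = fill(2f0, 1, 4), bias = fill(2f0, 1)))
gs2 = (layer_1 = (weight = fill(3f0, 4, 2), bias = fill(3f0, 4)),
       layer_2 = (weight = fill(4f0, 1, 4), bias = fill(4f0, 1)))

# α * gs as a new tree with the same structure. Converting α to eltype(g)
# keeps Float32 leaves Float32 instead of promoting them to Float64.
scale_grads(gs, α) = fmap(g -> eltype(g)(α) .* g, gs)

# a * g1 + b * g2, leaf by leaf (both trees must have arrays at the same leaves).
function combine_grads(g1, g2, a, b)
    fmap(g1, g2) do x, y
        T = eltype(x)
        T(a) .* x .+ T(b) .* y
    end
end

# Global L2 norm over all leaves, accumulated without flattening the tree.
function grad_norm(gs)
    total = Ref(0.0)
    fmap(g -> (total[] += sum(abs2, g); g), gs)
    return sqrt(total[])
end

# Example of "algebra before applying": clip by global norm.
function clip_by_global_norm(gs, maxnorm)
    n = grad_norm(gs)
    return n > maxnorm ? scale_grads(gs, maxnorm / n) : gs
end

gs_scaled = scale_grads(gs1, 0.5f0)
gs_mix = combine_grads(gs1, gs2, 0.5f0, 0.5f0)
gs_clipped = clip_by_global_norm(gs1, 1.0)
```

The result keeps the NamedTuple structure of the original gradient, so it can be handed to whichever optimiser step consumed the unmodified gradient. If Zygote returns `nothing` for some part of the tree (for example a layer without parameters), the leaf functions would need an extra case for it.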

A follow-up question: how can I keep track of the original ps across updates, so that I can reset the system to that state later? I currently keep a deepcopied ComponentArray for that, but I suspect that is not an efficient way to go about it.
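Keeping a reset point always means holding a second copy of the parameters. The cost can be kept down, though. Here is a sketch that assumes ps is a nested NamedTuple of mutable arrays. Allocate the snapshot once with `fmap(copy, ps)`, which copies each leaf on its current device. After that, refresh the snapshot or restore from it with `copyto!`, which reuses the existing buffers instead of allocating a new deep copy each time. Restoring in place also matters if other objects (for example a training state) hold references to the same arrays. The parameter tree and the helpers `save!` and `restore!` are hypothetical.

```julia
using Functors  # fmap applies a function to every leaf array of a nested NamedTuple

# Hypothetical parameter tree in the NamedTuple shape Lux.setup produces for a
# small Chain. On a GPU the leaves would be device arrays; copy and copyto!
# keep them on the device.
ps = (layer_1 = (weight = randn(Float32, 4, 2), bias = zeros(Float32, 4)),
      layer_2 = (weight = randn(Float32, 1, 4), bias = zeros(Float32, 1)))

# Allocate the snapshot once: same structure and array types, separate memory.
snapshot = fmap(copy, ps)

# Refresh the snapshot from the current parameters, reusing its buffers.
save!(snap, p) = (fmap(copyto!, snap, p); snap)

# Reset the parameters in place from the snapshot.
restore!(p, snap) = (fmap(copyto!, p, snap); p)

w_before = copy(ps.layer_1.weight)
ps.layer_1.weight .-= 0.1f0   # stand-in for an in-place optimiser update
restore!(ps, snapshot)        # ps.layer_1.weight now matches w_before again
```

If an optimiser state or the Lux state `st` also changes during training, it would need its own copy as well for a full reset.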