Currently, cpu(optimiser) won't move the optimizer state: the state still consists of arrays on the GPU.
It should work if you add Flux.@functor ADAM (or the equivalent for whichever optimizer you are using) to your code. You should also open an issue in Flux.jl describing your use case, to see whether this feature is worth adding.
Flux.@functor ADAM doesn’t work.
I thought this was fairly basic functionality, since large-scale training has to be restarted from checkpoints, unless we stick to small problems that can be finished within a couple of hours.
Without loading the previously saved optimizer state, training cannot be resumed properly.
I tried, but couldn't figure out how to use it with Flux models. The tested examples are not for neural network models.
Getting what you want here might require a bit of extra effort.
Flux's current optimizers use an IdDict to map each weight array to its optimizer state, and moving parameters to and from the GPU creates new copies. As a result, the IdDict will not recognize them as the same weights, and you end up with a memory-leak-like (weight-leak?) situation.
Depending on how you store the model and the optimizer state, you might run into the same problem before moving anything at all (e.g. the weights in the optimizer are no longer the same objects as the weights in the model).
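To make the identity problem concrete, here is a minimal sketch in plain Julia (no Flux or GPU needed). It assumes that copy and deepcopy are reasonable stand-ins for a gpu/cpu round trip and for saving and reloading a model; the get! call only mimics how an optimizer lazily creates state for a weight it has not seen before.

```julia
# Stand-in for an optimizer's state table: an IdDict keyed by object identity (===).
W = rand(Float32, 3, 3)
state = IdDict{Any,Any}(W => zeros(Float32, 3, 3))   # e.g. a momentum buffer for W

haskey(state, W)     # true: the very same array object

# Stand-in for gpu(cpu(W)): a new array with identical values.
W2 = copy(W)
W2 == W              # true: equal values
haskey(state, W2)    # false: a different object, so its state is not found

# Stand-in for saving and reloading the model.
W3 = deepcopy(W)
haskey(state, W3)    # false

# An optimizer that lazily creates state for unseen weights now adds a second
# entry instead of reusing the first one: the "weight-leak" situation.
get!(() -> zeros(Float32, 3, 3), state, W2)
length(state)        # 2
```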
I haven't followed the development of the new optimizers very carefully, but I suppose both the new and the current optimizers would require you to either compare weight values manually (hoping that there are no duplicates) or identify the weights some other way, and then remap the weights to the optimizer state.
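A sketch of the manual value-based remapping, assuming the old state table is an IdDict keyed by weight arrays and that old and new weights live on the same device (so == can compare them). remap_state_by_value is a hypothetical helper, not part of Flux; it refuses to guess when several old weights hold identical values.

```julia
# Hypothetical helper: rebuild an identity-keyed state table for `new_params`
# by matching each new array against the old keys by value (==).
function remap_state_by_value(old_state::IdDict, new_params)
    new_state = IdDict{Any,Any}()
    for p in new_params
        matches = [k for k in keys(old_state)
                   if k isa AbstractArray && size(k) == size(p) && k == p]
        # Duplicate values (e.g. two freshly zeroed biases) make the match ambiguous.
        length(matches) == 1 ||
            error("expected exactly one matching weight, found $(length(matches))")
        new_state[p] = old_state[only(matches)]
    end
    return new_state
end

W, b = rand(3, 3), rand(3)
old_state = IdDict{Any,Any}(W => :state_W, b => :state_b)   # placeholder state values
W2, b2 = copy(W), copy(b)                                   # e.g. weights after reloading
new_state = remap_state_by_value(old_state, [W2, b2])
new_state[W2]   # :state_W
```

The duplicate check is the weak point the post alludes to: identical weights (common right after initialization) cannot be told apart by value alone.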
@DrChainsaw Wonderful insights. Indeed, manually re-mapping the weights and optimizer state is just too much work. In the same amount of time, I could re-implement everything in PyTorch.
To get around the issue, I guess the Flux optimizer would have to key the optimizer state by a string rather than by the CUDA array itself, which is very unreliable. For example, when a is a CUDA array, gpu(cpu(a)) is no longer the same object as a.
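One way to realize the string-key idea, sketched under the assumption that the model is a nested NamedTuple/Tuple of arrays (real Flux layers are structs, so walking them would need Functors.jl instead). flatten_paths is a hypothetical helper that gives each array a path such as "layers.1.weight"; state stored under those paths survives any copy of the model.

```julia
# Hypothetical helper: collect every array in a nested NamedTuple/Tuple model
# under a string path such as "layers.1.weight".
flatten_paths(x::AbstractArray, prefix, out) = (out[prefix] = x; out)
function flatten_paths(x::Union{NamedTuple,Tuple}, prefix, out)
    for (k, v) in pairs(x)
        key = isempty(prefix) ? string(k) : string(prefix, ".", k)
        flatten_paths(v, key, out)
    end
    return out
end
flatten_paths(x, prefix, out) = out          # skip non-array leaves (functions, numbers, ...)
flatten_paths(x) = flatten_paths(x, "", Dict{String,Any}())

# A toy "model": one Dense-like layer stored as a NamedTuple.
model = (layers = ((weight = rand(Float32, 2, 3), bias = zeros(Float32, 2), σ = tanh),),)

# Optimizer state keyed by path instead of by array identity.
state = Dict(k => zero(v) for (k, v) in flatten_paths(model))

# A copy of the model (stand-in for gpu(cpu(model)) or a reload) still finds its state.
moved = flatten_paths(deepcopy(model))
all(haskey(state, k) for k in keys(moved))   # true
```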
The new optimizers work off of Zygote's support for structural gradients. That is, you get back a nested (named) tuple that has the same structure as your model. If you've used JAX-based libraries recently, this may look familiar (likewise for state_dict in PyTorch). You can try out Optimisers.jl today, and there should be zero IdDicts stored anywhere when you use it.
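To illustrate the structural approach, here is a toy momentum optimizer in plain Julia whose state has the same nested shape as the model. It is in the spirit of Optimisers.jl but is not its actual API: init_state and momentum_step are hypothetical names. Because state and parameters are paired by position in the structure rather than by object identity, the state keeps working after the model is copied.

```julia
# Hypothetical "structural" optimizer state: same nested shape as the model.
init_state(x::AbstractArray) = zero(x)                    # momentum buffer per array
init_state(x::Union{NamedTuple,Tuple}) = map(init_state, x)
init_state(x) = nothing                                   # non-trainable leaf (e.g. σ)

# Returns (new_state, new_model); `g` is a gradient with the model's structure.
function momentum_step(st::AbstractArray, x::AbstractArray, g::AbstractArray; η=0.1, ρ=0.9)
    v = ρ .* st .+ g
    return v, x .- η .* v
end
function momentum_step(st::Union{NamedTuple,Tuple}, x::Union{NamedTuple,Tuple},
                       g::Union{NamedTuple,Tuple}; kw...)
    res = map((s, xi, gi) -> momentum_step(s, xi, gi; kw...), st, x, g)
    return map(first, res), map(last, res)
end
# No gradient (`nothing`) or no state: leave the leaf and its state unchanged.
momentum_step(st, x, g; kw...) = (st, x)

model = (weight = [1.0, 2.0], bias = [0.5], σ = tanh)
grad  = (weight = [1.0, 1.0], bias = [2.0], σ = nothing)  # Zygote-style `nothing` for σ
st, model = momentum_step(init_state(model), model, grad)
# Copying the model (stand-in for gpu(cpu(model))) does not break the pairing.
st, model = momentum_step(st, deepcopy(model), grad)
```

Saving such a state is just saving a nested tuple of arrays, which is why no identity bookkeeping is needed when training is resumed.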
How do I use Optimisers.jl with Flux models such as Dense? Any example code would be appreciated.
Currently a bit more internal plumbing is needed (Optimisers.jl is still experimental) to get most Flux layers working out of the box. The GitHub issue "Missing definition in the basic usage example" (FluxML/Optimisers.jl, Issue #26) has a good summary. In the meantime, you can try something like these functions (warning: untested!):
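```julia
using Flux

# Change the optimizer type (and the IdDict field name, here `state`) for whatever you're using.
# Walk the model and collect the optimizer state of every leaf into a nested
# structure that mirrors the model.
function extract_opt_state(opt::ADAM, model)
    children = Flux.Functors.children(model)
    map(children) do child
        if Flux.Functors.isleaf(child)
            get(opt.state, child, nothing)   # `nothing` for leaves without state
        else
            extract_opt_state(opt, child)
        end
    end
end

# Put saved state back into `opt`, keyed by the arrays of the (possibly reloaded) `model`.
function restore_opt_state!(opt::ADAM, model, state)
    children = Flux.Functors.children(model)
    map(children, state) do child, st
        if Flux.Functors.isleaf(child)
            # Skip leaves that had no state instead of recursing into them.
            st === nothing || (opt.state[child] = st)
        else
            restore_opt_state!(opt, child, st)
        end
    end
    return opt
end
```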