Unfortunately, this doesn't work because `cpu`/`gpu` don't recurse into `Dict`s, but it led me to something that should:
# Hold on to the current (GPU) parameters: the optimiser state is looked up by them below.
ps_gpu = params(model)
model = model |> cpu
# Rebuild the state under the CPU parameters, passing each entry through `cpu`.
# (Alternatively, start over with a fresh ADAM() at this point.)
opt.state = IdDict(
    p_cpu => cpu(opt.state[p_gpu]) for (p_cpu, p_gpu) in zip(params(model), ps_gpu)
)
BSON.@save "model.bson" model opt
#### now let's get it back
BSON.@load "model.bson" model opt
# Capture the freshly loaded (CPU) parameters before moving the model.
# NOTE(review): assumes the loaded `opt.state` is keyed by these same arrays — confirm
# that BSON preserves that identity across save/load.
ps_cpu = params(model)
model = gpu(model)
# you could also create a new ADAM() here
# Look the state up by the CPU parameters and store it under the GPU ones;
# `params(model)` now yields the GPU arrays, so it supplies the new keys.
opt.state = IdDict(pg => gpu(opt.state[pc]) for (pc, pg) in zip(ps_cpu, params(model)))
This can be pulled out into a function:
"""
    load_opt_state!(opt::ADAM, ps_dest, ps_src; transform=identity)

Re-key the state of `opt` from the parameters in `ps_src` to those in `ps_dest`.

`ps_dest` and `ps_src` are paired positionally. For every pair, the entry
`opt.state[p_src]` is passed through `transform` and stored under `p_dest` in a new
`IdDict`, which then replaces `opt.state`.

# Keywords
- `transform=identity`: function applied to each state entry (for example `cpu` or `gpu`).

# Returns
The newly built `IdDict`.

# Throws
- `DimensionMismatch`: if `ps_dest` and `ps_src` contain a different number of parameters.
- Whatever `opt.state[p_src]` throws when a source parameter has no entry in the state.
"""
function load_opt_state!(opt::ADAM, ps_dest, ps_src; transform=identity)
    dest = collect(ps_dest)
    src = collect(ps_src)
    # `zip` stops at the shorter input, which would silently drop state entries.
    length(dest) == length(src) || throw(
        DimensionMismatch(
            "ps_dest has $(length(dest)) parameters but ps_src has $(length(src))"
        ),
    )
    new_state = IdDict(
        p_dest => transform(opt.state[p_src]) for (p_dest, p_src) in zip(dest, src)
    )
    opt.state = new_state
    return new_state
end
# Example usage: re-key the optimiser state from the parameters in `ps_gpu` onto those in
# `ps_cpu`, passing every state entry through `cpu`.
load_opt_state!(opt, ps_cpu, ps_gpu; transform=cpu)