Deepcopy Flux Model

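Unfortunately this doesn’t work, because `cpu`/`gpu` don’t recurse into `Dict`s: the optimizer's `state` stays keyed by the old parameter arrays. But it led me to something that should work: re-key `opt.state` onto the moved parameters by hand :slight_smile: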

ps_gpu = params(model)
model = cpu(model)

# `cpu`/`gpu` do not recurse into `opt.state`, which is keyed by the GPU parameter
# arrays (`ps_gpu`). Re-key it onto the model's new CPU parameters, moving each state
# entry to the CPU. Parameters that have no entry in `opt.state` are skipped instead of
# raising a KeyError. The result is an untyped IdDict so that state for parameters of any
# array type can still be inserted later (a generator-built IdDict narrows its key type).
# (You could also create a new ADAM() here instead.)
opt.state = let old_state = opt.state, cpu_state = IdDict{Any,Any}()
    for (pc, pg) in zip(params(model), ps_gpu)
        haskey(old_state, pg) || continue
        cpu_state[pc] = cpu(old_state[pg])
    end
    cpu_state
end

BSON.@save "model.bson" model opt

#### Now let's get it back

BSON.@load "model.bson" model opt

ps_cpu = params(model)
model = gpu(model)

# The loaded `opt.state` is keyed by the CPU parameters just loaded (`ps_cpu`), so re-key
# it onto the GPU parameters of the moved model, moving each state entry to the GPU.
# Do NOT pair with `ps_gpu` from the saving step:
#   - looking up the new GPU params in a CPU-keyed state raises a KeyError;
#   - `ps_gpu` refers to arrays that no longer back `model`;
#   - `ps_gpu` does not exist at all in a fresh session.
# Parameters without a state entry are skipped. The IdDict is untyped so that later
# insertions of other array types still work.
# (You could also create a new ADAM() here instead.)
opt.state = let old_state = opt.state, gpu_state = IdDict{Any,Any}()
    for (pc, pg) in zip(ps_cpu, params(model))
        haskey(old_state, pc) || continue
        gpu_state[pg] = gpu(old_state[pc])
    end
    gpu_state
end

The re-keying step can be pulled out into a reusable function:

"""
    load_opt_state!(opt, ps_dest, ps_src; transform=identity)

Re-key the optimizer state `opt.state` from the parameter arrays in `ps_src` onto the
arrays in `ps_dest`, paired by position, applying `transform` (e.g. `cpu` or `gpu`) to
each state entry. The new dictionary replaces `opt.state` and is also returned.

Only `opt.state` is accessed, so any optimizer that keeps its per-parameter state in a
mutable `state` field keyed by parameter array can be passed (not just `ADAM`).

Parameters in `ps_src` without an entry in `opt.state` are skipped rather than raising a
`KeyError`. State entries whose key is not in `ps_src` are dropped.

# Throws
- `DimensionMismatch` if `ps_dest` and `ps_src` have different lengths (`zip` would
  otherwise silently drop the unmatched parameters).
"""
function load_opt_state!(opt, ps_dest, ps_src; transform=identity)
    n_dest, n_src = length(ps_dest), length(ps_src)
    n_dest == n_src || throw(DimensionMismatch(
        "ps_dest has $n_dest parameters but ps_src has $n_src"))
    old_state = opt.state
    # Untyped on purpose: building the dict from a generator would narrow its key/value
    # types to those seen here, breaking later insertions of other array types.
    new_state = IdDict{Any,Any}()
    for (p_dest, p_src) in zip(ps_dest, ps_src)
        haskey(old_state, p_src) || continue
        new_state[p_dest] = transform(old_state[p_src])
    end
    opt.state = new_state
    return new_state
end

# Example usage. When saving, after `model = cpu(model)`, move the state from the GPU
# parameters captured beforehand (`ps_gpu`) onto the model's new CPU parameters:
load_opt_state!(opt, params(model), ps_gpu; transform=cpu)
# When loading, after `ps_cpu = params(model); model = gpu(model)`, go the other way:
#   load_opt_state!(opt, params(model), ps_cpu; transform=gpu)
4 Likes