I have a pretrained custom model whose parameters I use as the input for a second model.
I can send the first model to the GPU without any issue and save its parameters as expected. The second model just copies the parameters of the first model and builds another custom model from them. But when I send this second custom model to the GPU with `|> gpu`, its parameters are simply not moved.
Here is a minimal working example (MWE):
```julia
using Flux                              # provides gpu, cpu and Flux.truncated_normal
using Flux: Embedding, EmbeddingBag
using BSON: @save

# this is from the first model
struct GEmbeddings
    WE::Embedding
    CE::Embedding
    wbias::Embedding
    cbias::Embedding
end

function createParams(VSIZE::Int64, IN_DIM::Int64)
    WE    = Embedding(VSIZE => IN_DIM; init=Flux.truncated_normal())
    CE    = Embedding(VSIZE => IN_DIM; init=Flux.truncated_normal())
    wbias = Embedding(VSIZE => 1; init=Flux.truncated_normal())
    cbias = Embedding(VSIZE => 1; init=Flux.truncated_normal())
    return WE, CE, wbias, cbias
end

# this is for the second model
struct GEmbeddings2
    WE::EmbeddingBag
    CE::EmbeddingBag
    wbias::EmbeddingBag
    cbias::EmbeddingBag
end

# copy the parameters of the first model into EmbeddingBags
# (copy and create parameters)
function ccParams(model::GEmbeddings)
    names  = fieldnames(typeof(model))
    fields = map(name -> getfield(model, name), names)
    # Embedding weights are stored as (out, in), so size(w, 2) => size(w, 1) is in => out
    WE, CE, wbias, cbias = map(x -> EmbeddingBag(size(x.weight, 2) => size(x.weight, 1);
                                                 init=(_, _) -> x.weight), fields)
    return GEmbeddings2(WE, CE, wbias, cbias)
end

# creating the 1st model (Glove) with dummy inputs
WE, CE, wbias, cbias = createParams(10, 20)
model = GEmbeddings(WE, CE, wbias, cbias) |> gpu   # works perfectly fine

# returns all parameters as expected
model2 = ccParams(model)

# does not send the parameters to gpu
model2_gpu = model2 |> gpu
```
To send the parameters of the second model to the GPU, I have to call `|> gpu` on every `EmbeddingBag` individually: `WE`, `CE`, … I am OK with that. The annoying part is saving: to save the second model's parameters I have to move them back to the CPU first, and then I cannot see them. While the model file runs, the parameters appear to be saved, but when training finishes the saved `.bson` files do not contain any values.
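For reference, here is a minimal sketch of the pattern I would expect to work. It rests on my assumption (not confirmed) that `gpu`/`cpu` only recurse into structs registered with Flux, so `GEmbeddings2` is registered with `Flux.@functor` (newer Flux versions prefer `Flux.@layer` for this). The model is then moved back to the CPU before `@save`. The layer sizes are hypothetical dummy values.

```julia
using Flux
using Flux: EmbeddingBag
using BSON: @save, @load

struct GEmbeddings2
    WE::EmbeddingBag
    CE::EmbeddingBag
    wbias::EmbeddingBag
    cbias::EmbeddingBag
end

# Assumption: registering the struct lets gpu/cpu walk into its four fields
Flux.@functor GEmbeddings2

# Hypothetical dummy sizes (vocabulary 10, embedding dim 20), for illustration only
model2 = GEmbeddings2(EmbeddingBag(10 => 20), EmbeddingBag(10 => 20),
                      EmbeddingBag(10 => 1), EmbeddingBag(10 => 1))

model2_gpu = model2 |> gpu        # moves every EmbeddingBag weight (no-op if no GPU is available)
model2_cpu = model2_gpu |> cpu    # back to plain Arrays before saving
@save "model2.bson" model2_cpu

@load "model2.bson" model2_cpu    # reload to check that the weights were actually stored
```

If this works, the per-field `|> gpu` calls should no longer be needed, because a single `|> gpu` or `|> cpu` on the whole struct would then move all four weights.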
Could someone help?