Saving a model built with Transformers.jl

Hi there,
I am pretty new to anything related to this topic in general, but I found Flux to be quite accessible and wanted to have a look at the Transformers.jl package as well. I’ve been playing around with the copy task (the toy example found here: Transformers.jl/example/AttentionIsAllYouNeed at master · chengchingwen/Transformers.jl · GitHub) and was wondering how one could save the model to use in a future session. I know that there is the @save macro from the BSON.jl package but something like

@save "/path/to/folder/" trf_model

or saving every single layer did not do the trick for me.
I’d appreciate any hint on how to save and load the model in the above-mentioned example (@chengchingwen).

Cheers

What error did you get?


Hi,
thanks for the quick response!
When I did not move the model to the CPU before saving, I got a ReadOnlyMemoryError(). So I moved it to the CPU before saving and back to the GPU after loading, but then the decode_loss() function returned NaNs while training.
But from your response it seems that saving trf_model alone (and not all single layers on top) should suffice?

Cheers

It should suffice. Did you move the model back to the CPU before saving, and move it to the GPU after loading? BSON does not support saving arrays directly from the GPU, so an extra copy to the CPU is needed.
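
For anyone following along, here is a minimal sketch of that round trip. It assumes Flux's `cpu` and `gpu` functions and uses a small `Chain` as a hypothetical stand-in for trf_model; the file name and the variable names are made up for illustration. Note that the target is a file path ending in `.bson`, not a folder.

```julia
using Flux
using BSON

# Hypothetical stand-in for trf_model; the real model comes from the example script.
model = Chain(Dense(4 => 8, relu), Dense(8 => 2))
x = rand(Float32, 4, 3)
expected = model(x)

# Copy the model to the CPU first: BSON cannot serialize GPU arrays directly.
cpu_model = cpu(model)

# Save to a file (not a directory). The key in the file is :cpu_model.
path = joinpath(mktempdir(), "trf_model.bson")
BSON.@save path cpu_model

# In a later session: load the file, then move the model back to the GPU.
# Passing @__MODULE__ tells BSON where to resolve the Flux types.
loaded = BSON.load(path, @__MODULE__)[:cpu_model]
restored = gpu(loaded)   # leaves the model on the CPU if no GPU backend is available

# The restored model should give the same output as the original.
output = cpu(restored(gpu(x)))
```

If the loaded model is kept in a global variable that the training code reads, declaring it `const` (as in the original script) avoids the untyped-global pitfall.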

Yes, I did that. I think I found the mistake: I just did not set the model as constant again when loading it. Now it seems to have loaded the model, but this brings me to a follow-up question: in the example models you used translate(), which uses embed, encoder etc. - which are not saved as such when only saving trf_model. How can I access their trained parameters in a subsequent session?

They are in trf_model. trf_model is just a wrapper around embed, encoder, etc. You should be able to access those models through trf_model’s fields. (Or you can just save embed/encoder/… instead.)


Right! I was struggling with this because fieldnames() didn’t work but propertynames() did the trick and after a little bit of maneuvering I found all I needed to make translate() work with the trained parameters. Many thanks!

Cheers
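
A note for later readers on the fieldnames() versus propertynames() point: the thread does not say why fieldnames() failed, but one likely cause is that `fieldnames` expects a type while `propertynames` accepts an instance. The sketch below uses a hypothetical `ModelWrapper` struct as a stand-in for trf_model; the field names are assumptions chosen to match the discussion, not the actual layout in Transformers.jl.

```julia
# Hypothetical stand-in for a wrapper like trf_model.
struct ModelWrapper{E,C,D}
    embed::E
    encoder::C
    decoder::D
end

wrapper = ModelWrapper([1.0, 2.0], [3.0], [4.0])

# propertynames works directly on the instance.
props = propertynames(wrapper)

# fieldnames needs the type, so pass typeof(wrapper).
fields = fieldnames(typeof(wrapper))

# Calling fieldnames on the instance itself raises an error.
threw = try
    fieldnames(wrapper)
    false
catch
    true
end

# Pull the sub-models back out by name, e.g. to rebuild what translate() needs.
embed = getproperty(wrapper, :embed)   # same as wrapper.embed
encoder = wrapper.encoder
```

For types that overload `getproperty`, `propertynames` may list names that are not fields, which would be another reason to prefer it here.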