To save a model, the "Training a Simple LSTM" tutorial shows the following, using the JLD2 package:
```julia
# train.jl
using JLD2
@save "trained_model.jld2" ps_trained st_trained
@load "trained_model.jld2" ps_trained st_trained
```
But this only works when the same variable names are used, which in practice means within the same script. In the general case you would run inference in a new script and load into two new variables, but that leads to an error:
```julia
# test.jl
using JLD2
@load "trained_model.jld2" params states
```

```
KeyError: key "param" not found
```

Alternatively, trying this:

```julia
params, states = load("checkpoints/trained_model.jld2")
```

```
{
"name": "ErrorException",
"message": "cannot assign a value to imported variable Turing.params from module
...
}
```
I also tried the following, but that didn't work either:
```julia
vae = VAE(rng, h_dims, z_dims, imp_dims, out_dims)
params, states = Lux.setup(rng, vae)
@load "checkpoints/vae_params_states.jld2" params states
```

```
{
"name": "ErrorException",
"message": "cannot assign a value to imported variable Turing.params from module Main",
...
```
How would you do this?
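A minimal sketch of one possible approach, assuming the file was written with the tutorial's `@save` call (so it stores the keys `"ps_trained"` and `"st_trained"`). The parameter and state values below are made-up stand-ins, not real Lux output. The idea is to open the file with `jldopen`, read the stored keys explicitly, and bind them to names of your choice that do not clash with the `params` that Turing brings into scope.

```julia
using JLD2

# Hypothetical stand-ins for the trained Lux parameters and states;
# in the tutorial these come out of the training loop.
ps_trained = (weight = [1.0 2.0; 3.0 4.0], bias = [0.5, -0.5])
st_trained = (training = false,)

# train.jl: @save stores each variable under its own name as the key,
# so the file contains the keys "ps_trained" and "st_trained".
@save "trained_model.jld2" ps_trained st_trained

# test.jl: read those keys explicitly and bind them to any local names.
# `ps`/`st` are used instead of `params`/`states` because `params`
# clashes with the name Turing brings into scope.
ps, st = jldopen("trained_model.jld2", "r") do file
    file["ps_trained"], file["st_trained"]
end

# Inspect which keys are actually stored in the file.
saved_keys = jldopen("trained_model.jld2", "r") do file
    sort(collect(keys(file)))
end
```

If you also control the training script, `jldsave("trained_model.jld2"; ps = ps_trained, st = st_trained)` lets you choose the stored key names explicitly instead of inheriting the variable names.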