Loading Lux params and states in a new script

To save a model, the "Training a Simple LSTM" tutorial shows the following, using the JLD2 package:

# train.jl
using JLD2
@save "trained_model.jld2" ps_trained st_trained

@load "trained_model.jld2" ps_trained st_trained

But this seems to work only when the same variable names are used (which, as far as I can tell, can only be done within the same script). In the general case you would be doing inference in a new script with two new variables, but that leads to an error:

#test.jl
using JLD2
@load "trained_model.jld2" params states
>>>
KeyError: key "param" not found

# Alternatively trying this,
params, states = load("checkpoints/trained_model.jld2")
>>>
{
	"name": "ErrorException",
	"message": "cannot assign a value to imported variable Turing.params from module
...
}

I also tried the following, but that didn't work either:

vae = VAE(rng, h_dims, z_dims, imp_dims, out_dims)
params, states = Lux.setup(rng, vae)
@load "checkpoints/vae_params_states.jld2" params states
>>>
{
	"name": "ErrorException",
	"message": "cannot assign a value to imported variable Turing.params from module Main",
...

How would you do this?

You need to use the correct variable names ps_trained, st_trained.

params, states = load("checkpoints/trained_model.jld2") doesn’t work because you have Turing.jl imported, and you cannot assign to a variable with the same name as a function exported by an imported package (here, Turing.params).
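
If you want different local names in the inference script, one option is to look the values up by the keys they were saved under and bind them to whatever names you like. Here is a minimal sketch, assuming the file was written with the keys "ps_trained" and "st_trained"; the NamedTuples are hypothetical stand-ins for real Lux parameters and states, and the temporary path is only for illustration:

```julia
using JLD2

# Hypothetical stand-ins for trained Lux parameters and states.
ps_trained = (weight = [1.0 2.0; 3.0 4.0], bias = [0.5, -0.5])
st_trained = (training = false,)

# Save under the keys "ps_trained" and "st_trained" (illustrative temp path).
path = joinpath(mktempdir(), "trained_model.jld2")
jldsave(path; ps_trained, st_trained)

# In the inference script: read by stored key, bind to any local names.
# `ps` and `st` are used here instead of `params`, which clashes with
# Turing.params when Turing.jl is loaded.
ps, st = jldopen(path, "r") do file
    file["ps_trained"], file["st_trained"]
end
```

The stored key names are fixed at save time, while the names on the left-hand side are free to choose. Note that calling load(path) with only the file name returns a Dict keyed by those stored names, so destructuring it directly into two variables would not give you the parameters and states themselves.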
