You need to use the correct variable names, ps_trained and st_trained.
params, states = load("checkpoints/trained_model.jld2") doesn’t work because you have Turing.jl imported, and you cannot assign to a variable that has the same name as a function exported by a package you have loaded.
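As a rough sketch of the fix, here is one way to read the two values back under names that don't clash with anything exported. It assumes the checkpoint was written with JLD2 under the keys "ps_trained" and "st_trained". The NamedTuples and the temporary path are made-up stand-ins so the snippet runs on its own, and it uses jldopen rather than load to read the keys explicitly.

```julia
using JLD2

# Hypothetical stand-ins for the trained parameters and states.
ps = (weight = [1.0 2.0; 3.0 4.0], bias = [0.5, -0.5])
st = (training = false,)

# Write a throwaway checkpoint. The keys "ps_trained" and "st_trained"
# are an assumption about how the real file was saved.
path = joinpath(mktempdir(), "trained_model.jld2")
jldsave(path; ps_trained = ps, st_trained = st)

# Read the values back into names that do not collide with exported functions.
ps_trained, st_trained = jldopen(path, "r") do file
    file["ps_trained"], file["st_trained"]
end
```

With your own file you would replace path with "checkpoints/trained_model.jld2" and skip the jldsave step. If the keys in your file are different, keys(file) inside the jldopen block will list them.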