Reset model parameters in Flux.jl

Hi, I’d like to ask how to reset a model’s parameters without first saving the initial parameters (e.g. init_theta = Flux.params(model)) and loading them back into the model when resetting.

I would like to use the built-in Flux.reset!(), but it does not do what I want; the model parameters remain the same.

Question to @MikeInnes probably. Thanks a lot.

julia> best_m
Chain(Dense(11, 128), Dense(128, 64, relu), Dense(64, 16, relu), Dense(16, 10), #37)

julia> bm_pars = Flux.params(best_m)
Params([Float32[0.18607965 0.061578143 … -0.06005481 0.09483149; 0.19160587 0.12552795 … -0.19273274 -0.13584293; … ; -0.1559881 -0.15767108 … 0.05275579 0.09514192; 0.09612731 -0.039975233 … 0.074506715 -0.10788479], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.023085387 -0.16253611 … -0.043780163 -0.03891604; 0.11446182 0.00051014544 … 0.06537806 0.07422932; … ; -0.07333294 -0.14721991 … 0.14628817 0.17414984; 0.021892758 -0.034226023 … -0.038455416 -0.0754217], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.10153302 0.14251384 … 0.077751316 -0.22656487; 0.2341217 -0.20544885 … -0.074172765 -0.07469577; … ; 0.2531221 -0.2712811 … -0.18341343 -0.08414964; -0.09518282 -0.07602195 … -0.17687938 -0.049255677], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[-0.010512488 0.251019 … -0.29921338 0.27331173; -0.17924601 -0.4749688 … 0.39040837 -0.07162238; … ; 0.30139822 -0.30392915 … -0.25658587 -0.2025866; -0.2734321 -0.44143367 … 0.38832903 0.0912851], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])

julia> Flux.reset!(best_m)

julia> br_pars = Flux.params(best_m)
Params([Float32[0.18607965 0.061578143 … -0.06005481 0.09483149; 0.19160587 0.12552795 … -0.19273274 -0.13584293; … ; -0.1559881 -0.15767108 … 0.05275579 0.09514192; 0.09612731 -0.039975233 … 0.074506715 -0.10788479], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.023085387 -0.16253611 … -0.043780163 -0.03891604; 0.11446182 0.00051014544 … 0.06537806 0.07422932; … ; -0.07333294 -0.14721991 … 0.14628817 0.17414984; 0.021892758 -0.034226023 … -0.038455416 -0.0754217], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.10153302 0.14251384 … 0.077751316 -0.22656487; 0.2341217 -0.20544885 … -0.074172765 -0.07469577; … ; 0.2531221 -0.2712811 … -0.18341343 -0.08414964; -0.09518282 -0.07602195 … -0.17687938 -0.049255677], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[-0.010512488 0.251019 … -0.29921338 0.27331173; -0.17924601 -0.4749688 … 0.39040837 -0.07162238; … ; 0.30139822 -0.30392915 … -0.25658587 -0.2025866; -0.2734321 -0.44143367 … 0.38832903 0.0912851], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])

# the best model's parameters are identical before and after reset!
julia> bm_pars .== br_pars
8-element BitArray{1}:
 1
 1
 1
 1
 1
 1
 1
 1

Flux.reset! is just for RNNs (and the like): it resets their hidden state back to its initial value and doesn’t touch their weights.
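
To illustrate, a minimal sketch (assuming a stateful recurrent layer and a Flux version from around this thread; constructors and field names differ a bit between releases):

using Flux

rnn = RNN(2, 3)                                  # stateful recurrent layer
w_before = [copy(p) for p in Flux.params(rnn)]   # snapshot of the weight arrays
rnn(rand(Float32, 2))                            # a forward pass advances the hidden state
Flux.reset!(rnn)                                 # rewinds the hidden state to its initial value
all(collect(Flux.params(rnn)) .== w_before)      # true: the weights were never touched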

To reset the values I often just re-define my model. But I suppose you could also do something like:

map(p->p .= randn.(), params(best_m))
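
For the “re-define my model” route, one pattern (just a sketch; build_model is a made-up helper, not part of Flux) is to keep construction in a function, so a fresh set of default-initialised weights is one call away:

using Flux

# hypothetical helper: rebuilding the Chain gives fresh Glorot-initialised weights
build_model() = Chain(Dense(11, 128), Dense(128, 64, relu), Dense(64, 16, relu), Dense(16, 10))

best_m = build_model()   # train this ...
best_m = build_model()   # ... and "reset" it by rebuilding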

Thanks for your answer. Neat and clean solution.
So one can do it completely like this:

using Flux

m2 = Chain(Dense(10, 20, relu), Dense(20, 10), x -> σ.(x))
Flux.params(m2) # old params
# the map already mutates each parameter array in place, so loadparams! just reloads the new values
Flux.loadparams!(m2, map(p -> p .= randn.(), Flux.params(m2))) # new params
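
If you’d rather re-initialise with Flux’s default distribution than with a standard normal, something along these lines should also work (a sketch; it assumes Flux.glorot_uniform, the default Dense weight initialiser, is available; note that Flux itself initialises biases to zero):

using Flux

# refill each parameter array in place with Glorot-uniform values instead of randn
for p in Flux.params(m2)
    p .= Flux.glorot_uniform(size(p)...)
end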