Can I modify Flux parameters manually?

I want to be able to manually modify my Flux parameters, but when I have a Zygote.Params object, I can’t seem to change anything about it. I am trying to average the parameters of two networks that are the same shape (Polyak averaging).

Consider this REPL:

julia> using Flux

julia> nn = Dense(3, 3)
Dense(3, 3)

julia> params(nn)
Params([Float32[-0.8835063 0.7971344 -0.3184507; -0.38867068 -0.966625 -0.118569136; -0.39296985 -0.4625342 0.55543184], Float32[0.0, 0.0, 0.0]])

julia> params(nn)[1] *= 0
ERROR: MethodError: no method matching setindex!(::Zygote.Params, ::Array{Float32,2}, ::Int64)
Stacktrace:
 [1] top-level scope at REPL[4]:1

julia> the_params = params(nn)[1]
3×3 Array{Float32,2}:
 -0.883506   0.797134  -0.318451
 -0.388671  -0.966625  -0.118569
 -0.39297   -0.462534   0.555432

julia> typeof(the_params)
Array{Float32,2}

julia> the_params *= 0
3×3 Array{Float32,2}:
 -0.0   0.0  -0.0
 -0.0  -0.0  -0.0
 -0.0  -0.0   0.0

julia> the_params
3×3 Array{Float32,2}:
 -0.0   0.0  -0.0
 -0.0  -0.0  -0.0
 -0.0  -0.0   0.0

julia> params(nn)
Params([Float32[-0.8835063 0.7971344 -0.3184507; -0.38867068 -0.966625 -0.118569136; -0.39296985 -0.4625342 0.55543184], Float32[0.0, 0.0, 0.0]])

julia> params(nn)[1]
3×3 Array{Float32,2}:
 -0.883506   0.797134  -0.318451
 -0.388671  -0.966625  -0.118569
 -0.39297   -0.462534   0.555432

As you can see, either I cannot change the parameters at all, or, if I do manage to get hold of something I can change, it’s just a copy and doesn’t actually change the parameters.

Is there a way I can manually change the parameters?

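I needed to use .*= to “broadcast” the change in place instead of *=. In Julia, x *= 0 is shorthand for x = x * 0. That allocates a new array and rebinds the name x, so the_params *= 0 only rebound the_params and left the layer’s weights alone. params(nn)[1] returns the layer’s actual weight array, not a copy. For the same reason, params(nn)[1] *= 0 becomes params(nn)[1] = params(nn)[1] * 0, which needs a setindex! method that Params does not define. By contrast, x .*= 0 is shorthand for x .= x .* 0, which writes the result into the existing array.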
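The rebinding-versus-mutation distinction can be reproduced with plain Julia arrays, without Flux. In this sketch `w` stands in for the weight matrix that `params(nn)[1]` returns, and `alias` for a second name bound to it, like `the_params` in the question. Both names are illustrative only.

```julia
# `w` plays the role of a layer's weight matrix; `alias` is a second
# name bound to the *same* array, like `the_params = params(nn)[1]`.
w = Float32[1 2; 3 4]
alias = w

# `alias *= 0` lowers to `alias = alias * 0`: a new array is allocated
# and only the name `alias` is rebound, so `w` is untouched.
alias *= 0
rebind_left_original = (w == Float32[1 2; 3 4])   # true

# Bind `alias` back to the original array, then mutate in place.
alias = w
# `alias .*= 0` lowers to `alias .= alias .* 0`: the result is written
# into the existing array, so the change is visible through `w` as well.
alias .*= 0
inplace_changed_original = all(iszero, w)          # true
```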

Consider:

julia> using Flux

julia> nn = Dense(3, 3)
Dense(3, 3)

julia> params(nn)
Params([Float32[-0.1867888 0.45788264 0.42513847; 0.29547477 0.8505957 0.6445496; 0.31435943 -0.9087379 0.5611925], Float32[0.0, 0.0, 0.0]])

julia> params(nn)[1] .*= 0
3×3 Array{Float32,2}:
 -0.0   0.0  0.0
  0.0   0.0  0.0
  0.0  -0.0  0.0

julia> params(nn)
Params([Float32[-0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 -0.0 0.0], Float32[0.0, 0.0, 0.0]])
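The same in-place broadcast can be applied to every parameter array for the Polyak averaging mentioned in the question. What follows is a minimal sketch, assuming both networks share the same architecture, so that iterating their parameters yields same-shaped arrays in the same order. `polyak_update!` and `τ` are hypothetical names. Plain vectors of arrays stand in for `Params` so the block runs without Flux.

```julia
# Hypothetical helper: soft (Polyak) update of `target` toward `online`,
#   target ← τ * online + (1 - τ) * target,
# written with `.=` so the existing arrays are modified in place.
# Assumes both collections yield same-shaped arrays in the same order.
function polyak_update!(target, online, τ)
    for (pt, po) in zip(target, online)
        pt .= τ .* po .+ (1 - τ) .* pt
    end
    return target
end

# Plain arrays standing in for the parameters of two identical networks.
target = [Float32[1 2; 3 4], Float32[0, 0, 0]]
online = [Float32[3 4; 5 6], Float32[2, 2, 2]]

# τ = 0.5 gives a plain element-wise average of the two parameter sets.
polyak_update!(target, online, 0.5f0)
```

With Flux, the call would presumably be polyak_update!(params(target_nn), params(online_nn), 0.5f0), assuming Params iterates over the layer arrays in order. Writing pt = ... inside the loop instead of pt .= ... would reproduce the original problem, because it would only rebind the loop variable.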