Flux.params does not recognize parameters with `x -> layer(x)` syntax

Hi all,

I have a basic question about `params` in Flux that confuses me.
Why do I get different parameters in these two cases? Shouldn’t they be equivalent?

```julia
using Flux

m1 = x -> Dense(4, 2)(x)
println(params(m1))

m2 = Dense(4, 2)
println(params(m2))
```

This prints

```
Params([])
Params([Float32[0.23799 0.457357 0.632557 0.435667; -0.869019 -0.760002 0.018492 -0.615113], Float32[0.0, 0.0]])
```

In other words, the parameters are not “recognized” in the first case.

I am using Julia Version 1.0.5 (2019-09-09) and [587475ba] Flux v0.10.3.

The reason I ask is related to question #18877, where the proposed solution has the same problem: the parameters of the first layers, which use the `x -> layer(x)` syntax, are not recognized.

Thanks in advance.

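`Dense(4, 2)` creates a new dense layer with freshly initialized weights. Because that call sits inside the body of the anonymous function, `m1` builds a brand-new `Dense`, applies it to `x`, and throws it away every time it’s called. There are no stored parameters for `params` to find, and nothing you trained would persist between calls, so there’s pretty much no way this is what you want.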
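To make this concrete, here is a small sketch, assuming a Flux version that (like 0.10) provides `Dense(in, out)` and `Flux.params`; the names `xin`, `d` and `m1_fixed` are just for illustration. The closure from the question draws new weights on every call, while a layer built once and captured keeps its weights:

```julia
using Flux

xin = rand(Float32, 4)

# As in the question: every call constructs (and then discards) a new Dense
m1 = x -> Dense(4, 2)(x)
m1(xin) == m1(xin)              # false (almost surely): two independent random layers

# Build the layer once and capture it: the same weights are used on every call
d = Dense(4, 2)
m1_fixed = x -> d(x)
m1_fixed(xin) == m1_fixed(xin)  # true

# The persistent layer exposes its weight matrix and bias vector
length(Flux.params(d))          # 2
```

Note that `params(m1_fixed)` may still come back empty: as far as I know, `params` in Flux 0.10 does not look inside anonymous functions, so collect the parameters from `d` itself (or drop the closure, since `d` is already callable).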

You could overload `params` with a bit of type piracy to get the parameters of this thing, but I think it’s better to first figure out exactly what you want to do, and then find the best way to do that.
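One way to avoid the piracy, sketched here assuming the Flux 0.10-era `Flux.@functor` macro: replace the closure with a small callable struct that stores the layer in a field and register that struct with Flux. The name `Wrapped` is hypothetical.

```julia
using Flux

# Hypothetical wrapper standing in for `x -> layer(x)`: the layer is a field,
# so it is created once and persists between calls
struct Wrapped{L}
    layer::L
end

# Register the type so that Flux's traversal (and hence `params`) looks inside
Flux.@functor Wrapped

# Make instances callable, just like the anonymous function
(w::Wrapped)(x) = w.layer(x)

w = Wrapped(Dense(4, 2))
y = w(rand(Float32, 4))   # 2-element output
ps = Flux.params(w)       # the Dense weight matrix and bias vector
length(ps)                # 2
```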

Also, I suggest upgrading to Julia 1.5. The difference is huge, and unlike the standard libraries (which Flux is not part of), most packages only guarantee compatibility with the latest or nearly latest Julia version, not necessarily with the LTS version.

Hi @tomerarnon,
thanks for the prompt clarification!

As I mentioned in the question, the issue is related to #18877, where in the proposed solution the parameters are “recognized” only in layers D and E, and not in A*, B*, and C*. From that discussion it seems this wasn’t the intended behavior, since presumably the other layers should be trainable too.

I see. In that case I don’t recommend the solution proposed in that thread, since overloading `params` to make it work would involve some pretty severe type piracy. I think the `Concat` struct is actually the way to go. That post is quite old and Flux has changed somewhat since then (e.g. `@treelike` is deprecated in favor of `@functor`, to my knowledge), so the following should be the way to do it now:

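```julia
using Flux

# Holds several branches (layers or chains) whose outputs get concatenated
struct Concat{T}
    catted::T
end
Concat(xs...) = Concat(xs)

# Let Flux (and therefore `params`) look inside the `catted` field
Flux.@functor Concat

# if the behavior you're looking for is to produce
# [f1(x); f2(x) ...], which is more image-processing-like
# function (C::Concat)(x)
#     mapreduce(f -> f(x), vcat, C.catted)
# end

# if the behavior you're looking for is to produce
# [f1(x1); f2(x2) ...], like in the original issue
# (mapreduce over several collections needs Julia >= 1.2)
function (C::Concat)(xs)
    mapreduce((f, x) -> f(x), vcat, C.catted, xs)
end
```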
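The two forms differ only in how inputs are paired with branches. Here is a plain-Julia sketch (no Flux needed) with ordinary functions standing in for the layers; `fs`, `concat_shared` and `concat_per_branch` are illustrative names, not part of Flux:

```julia
# Two plain functions stand in for the two branches of a Concat
fs = ((x -> 2 .* x), (x -> x .+ 1))

# Commented-out form: every branch sees the same input, giving [f1(x); f2(x)]
concat_shared(x) = mapreduce(f -> f(x), vcat, fs)

# Active form: branch i gets input i, giving [f1(x1); f2(x2)]
concat_per_branch(xs) = mapreduce((f, x) -> f(x), vcat, fs, xs)

concat_shared([1, 2])              # [2, 4, 2, 3]
concat_per_branch(([1, 2], [10]))  # [2, 4, 11]
```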
The example from the post:

```julia
julia> A1 = Dense(5, 5); B1 = Dense(5, 5); A2 = Dense(5, 5);

julia> B2 = Dense(5, 5); C2 = Dense(5, 5); D = Dense(10, 5);

julia> E = Dense(5, 5);

julia> model = Chain(Concat(Chain(A1, B1), Chain(A2, B2, C2)), D, E);

# the expected input to `Concat` with the f1(x1) form is
# now an iterable of arrays
julia> input = [rand(5), rand(5)];

julia> model(input)
5-element Array{Float32,1}:
  0.71026456
 -0.7175311
 -0.30744648
 -0.8034078
  0.8368984
```

And `params` now finds the parameters inside the `Concat`, as expected:

```julia
julia> Flux.params(model[1])
Params([Float32[0.23016948 0.05715016 … 0.62341815 -0.5434392; -0.37009057 -0.3970632 … 0.62382203 0.41673365; … ; 0.061818454 0.16528758 … -0.2425396 0.603057; -0.260415 0.33849806 … -0.32674253 0.50213516], Float32[0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.57999533 -0.34143996 … -0.14378327 0.4262758; -0.764463 -0.55587304 … -0.45619386 -0.2658353; … ; -0.45302996 0.5440691 … -0.6935754 -0.68062; -0.5139856 -0.12330929 … 0.17966792 -0.40127572], Float32[0.0, 0.0, 0.0, 0.0, 0.0], Float32[-0.061877552 0.5428227 … -0.31964645 0.76445305; 0.015927391 0.21079396 … -0.6141678 -0.17098361; … ; 0.34704348 -0.64504695 … -0.67650115 0.28716063; -0.5823784 0.16855657 … 0.4026331 -0.18516172], Float32[0.0, 0.0, 0.0, 0.0, 0.0], Float32[-0.42521372 0.13450116 … 0.21042147 -0.65950245; -0.5163227 0.31146872 … 0.5260253 -0.059268236; … ; -0.01898363 -0.010103932 … 0.5870273 -0.6593026; -0.23822755 -0.76355386 … -0.4103931 0.7052716], Float32[0.0, 0.0, 0.0, 0.0, 0.0], Float32[-0.48998186 0.76523256 … 0.42303228 -0.017605193; 0.09143734 0.28700328 … 0.24258374 0.46789637; … ; 0.749469 0.26842743 … 0.36736602 0.13056715; 0.44405162 -0.24533193 … -0.63523906 0.18275999], Float32[0.0, 0.0, 0.0, 0.0, 0.0]])
```
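To see the difference from the closure-based approach side by side, here is a self-contained sketch that repeats the `Concat` definition and builds two models from the same layers: one using `Concat`, and one whose first stage is an anonymous function. The closure model is my own reconstruction in the spirit of #18877, not the code from that thread, and `good`/`bad` are illustrative names. Since the parameter count for the closure model may depend on the Flux version, it is only described in a comment.

```julia
using Flux

# Concat as defined above (per-branch form)
struct Concat{T}
    catted::T
end
Concat(xs...) = Concat(xs)
Flux.@functor Concat
(C::Concat)(xs) = mapreduce((f, x) -> f(x), vcat, C.catted, xs)

A1, B1, A2, B2, C2 = [Dense(5, 5) for _ in 1:5]
D = Dense(10, 5)
E = Dense(5, 5)

# Struct-based model: every Dense is reachable through struct fields
good = Chain(Concat(Chain(A1, B1), Chain(A2, B2, C2)), D, E)

# Closure-based model (hypothetical reconstruction): the first stage is opaque
bad = Chain(xs -> vcat(Chain(A1, B1)(xs[1]), Chain(A2, B2, C2)(xs[2])), D, E)

input = [rand(Float32, 5), rand(Float32, 5)]
good(input) ≈ bad(input)         # true: same layers, same computation

length(Flux.params(good[1]))     # 10: weight and bias of the 5 Dense layers in Concat
length(Flux.params(good))        # 14: all 7 Dense layers
length(Flux.params(bad))         # with Flux 0.10, 4: only D and E, as reported in this thread
```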

Thanks @tomerarnon, it works!
(I had to update Julia to version 1.5.1 to get your code working.)

I will link your answer to the other issue.
