I want to define a convolutional layer in Flux.jl using the `groups` option. Ideally, I want to use `Conv((1,1), 3=>3, groups=3)`. For reference, the `Flux.Conv` documentation defines this argument as:
> Keyword `groups` is expected to be an `Int`. It specifies the number of groups to divide a convolution into.
As I understand it, `Conv((1,1), 3=>3, groups=3)` should define a 1×1 convolution taking 3 input channels and producing 3 output channels, but applied channel-wise. In other words, it should be equivalent to three convolutions applied independently to each channel. In particular, the number of parameters should be 3 × (1 + 1) = 6, since a `1=>1` 1×1 convolution has just 2 parameters (one weight and one bias).
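To make the expected behaviour concrete, here is a minimal pure-Julia sketch (no Flux involved) of a grouped 1×1 convolution as described above. It assumes the WHCN data layout used in the REPL calls and a weight array of shape `(1, 1, cin ÷ groups, cout)`; that weight layout is an assumption here, and `grouped_conv1x1` is a hypothetical name, not a Flux or NNlib function. With `groups = 3` and `3 => 3`, each output channel reads only its own input channel, so the layer reduces to a per-channel scale plus bias with 3 × (1 + 1) = 6 parameters.

```julia
# Hypothetical reference implementation of a grouped 1×1 convolution.
# x: input in WHCN layout, size (W, H, cin, N)
# w: weights, size (1, 1, cin ÷ groups, cout)  (assumed Flux-style layout)
# b: bias, length cout
function grouped_conv1x1(x::AbstractArray{T,4}, w::AbstractArray{T,4},
                         b::AbstractVector{T}, groups::Int) where {T}
    W, H, cin, N = size(x)
    cin_g = size(w, 3)               # input channels per group
    cout = size(w, 4)
    @assert cin == cin_g * groups "input must have cin_g * groups channels"
    @assert cout % groups == 0 "cout must be divisible by groups"
    cout_g = cout ÷ groups           # output channels per group
    y = zeros(T, W, H, cout, N)
    for n in 1:N, o in 1:cout
        g = (o - 1) ÷ cout_g         # 0-based group index of output channel o
        for c in 1:cin_g
            # Output channel o only reads the input channels of its own group.
            y[:, :, o, n] .+= w[1, 1, c, o] .* x[:, :, g * cin_g + c, n]
        end
        y[:, :, o, n] .+= b[o]
    end
    return y
end

x = rand(Float32, 10, 10, 3, 1)
w = rand(Float32, 1, 1, 1, 3)        # cin ÷ groups = 1, cout = 3
b = rand(Float32, 3)
y = grouped_conv1x1(x, w, b, 3)

# With groups = 3 this is just an independent affine map per channel:
y_ref = similar(x)
for c in 1:3
    y_ref[:, :, c, :] .= w[1, 1, 1, c] .* x[:, :, c, :] .+ b[c]
end

println(size(y))                     # (10, 10, 3, 1)
println(y ≈ y_ref)                   # true
println(length(w) + length(b))       # 6
```

The inner loop over `c` runs only once here because each group holds a single input channel; with `groups = 1` and a `(1, 1, 3, 3)` weight, the same function computes an ordinary dense 1×1 convolution.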
julia> T = Conv((1,1), 3=>3, groups=3)
Conv((1, 1), 1 => 3) # 6 parameters
The number of parameters matches, but the in/out channels (`1 => 3`) seem bizarre. Then, when I apply `T` to an array with 3 channels, I get `ERROR: AssertionError: DimensionMismatch("Data input channel count (3 vs. 3)")` (see below for the full stacktrace).
For the sake of testing, I also tried applying `T` to an array with 1 channel, and got `ERROR: DimensionMismatch("Input channels must match! (1 vs. 1)")`.
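The channel bookkeeping behind the `1 => 3` display and the two error messages can be written out explicitly. The sketch below is hypothetical: `check_channels` is not a Flux or NNlib function, and it assumes the weight is stored as `(1, 1, cin ÷ groups, cout)`, i.e. `(1, 1, 1, 3)` for `groups = 3`. Under that assumption, reading the last two weight dimensions gives `1 => 3`, while the input the layer is meant to accept has `1 * 3 = 3` channels, so a 3-channel input should pass a channel check and a 1-channel input should not.

```julia
# Hypothetical helper: which input channel count a grouped conv layer expects,
# assuming weights of size (kw, kh, cin ÷ groups, cout).
function check_channels(wsize::NTuple{4,Int}, groups::Int, x_channels::Int)
    cin_per_group = wsize[3]
    cout = wsize[4]
    expected_cin = cin_per_group * groups
    return (stored = cin_per_group => cout,  # what reading weight dims 3 and 4 gives
            full = expected_cin => cout,      # what Conv((1,1), 3=>3, groups=3) means
            accepts = x_channels == expected_cin)
end

wsize = (1, 1, 1, 3)                  # assumed weight size for Conv((1,1), 3=>3, groups=3)
println(check_channels(wsize, 3, 3))  # 3-channel input
println(check_channels(wsize, 3, 1))  # 1-channel input
```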
Can someone explain how the `groups` argument works in `Flux.Conv`?
Thanks!
Stacktrace with 3 channels (this should work but does not):
julia> T(rand(10, 10, 3, 1))
ERROR: AssertionError: DimensionMismatch("Data input channel count (3 vs. 3)")
[1] check_dims(x::NTuple{5, Int64}, w::NTuple{5, Int64}, y::NTuple{5, Int64}, cdims::DenseConvDims{3, (1, 1, 1), 3, 3, 3, (1, 1, 1), (0, 0, 0, 0, 0, 0), (1, 1, 1), false})
@ NNlib ~/.julia/packages/NNlib/P9BhZ/src/dim_helpers/DenseConvDims.jl:73
[2] conv_direct!(y::Array{Float64, 5}, x::Array{Float64, 5}, w::Array{Float32, 5}, cdims::DenseConvDims{3, (1, 1, 1), 3, 3, 3, (1, 1, 1), (0, 0, 0, 0, 0, 0), (1, 1, 1), false}; alpha::Float64, beta::Bool)
@ NNlib ~/.julia/packages/NNlib/P9BhZ/src/impl/conv_direct.jl:51
[3] conv_direct!
@ ~/.julia/packages/NNlib/P9BhZ/src/impl/conv_direct.jl:51 [inlined]
[4] conv!(y::Array{Float64, 5}, in1::Array{Float64, 5}, in2::Array{Float32, 5}, cdims::DenseConvDims{3, (1, 1, 1), 3, 3, 3, (1, 1, 1), (0, 0, 0, 0, 0, 0), (1, 1, 1), false}; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ NNlib ~/.julia/packages/NNlib/P9BhZ/src/conv.jl:293
[5] conv!(y::Array{Float64, 5}, in1::Array{Float64, 5}, in2::Array{Float32, 5}, cdims::DenseConvDims{3, (1, 1, 1), 3, 3, 3, (1, 1, 1), (0, 0, 0, 0, 0, 0), (1, 1, 1), false})
@ NNlib ~/.julia/packages/NNlib/P9BhZ/src/conv.jl:291
[6] conv!(y::Array{Float64, 4}, x::Array{Float64, 4}, w::Array{Float32, 4}, cdims::DenseConvDims{2, (1, 1), 3, 3, 3, (1, 1), (0, 0, 0, 0), (1, 1), false}; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ NNlib ~/.julia/packages/NNlib/P9BhZ/src/conv.jl:151
[7] conv!
@ ~/.julia/packages/NNlib/P9BhZ/src/conv.jl:151 [inlined]
[8] conv(x::Array{Float64, 4}, w::Array{Float32, 4}, cdims::DenseConvDims{2, (1, 1), 3, 3, 3, (1, 1), (0, 0, 0, 0), (1, 1), false}; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ NNlib ~/.julia/packages/NNlib/P9BhZ/src/conv.jl:91
[9] conv(x::Array{Float64, 4}, w::Array{Float32, 4}, cdims::DenseConvDims{2, (1, 1), 3, 3, 3, (1, 1), (0, 0, 0, 0), (1, 1), false})
@ NNlib ~/.julia/packages/NNlib/P9BhZ/src/conv.jl:89
[10] (::Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}})(x::Array{Float64, 4})
@ Flux ~/.julia/packages/Flux/ZnXxS/src/layers/conv.jl:163
[11] top-level scope
@ REPL[48]:1
Stacktrace with 1 channel (this should not work, and indeed does not):
julia> T(rand(10,10,1,1))
ERROR: DimensionMismatch("Input channels must match! (1 vs. 1)")
[1] DenseConvDims(x_size::NTuple{4, Int64}, w_size::NTuple{4, Int64}; stride::Tuple{Int64, Int64}, padding::NTuple{4, Int64}, dilation::Tuple{Int64, Int64}, flipkernel::Bool, groups::Int64)
@ NNlib ~/.julia/packages/NNlib/P9BhZ/src/dim_helpers/DenseConvDims.jl:30
[2] #DenseConvDims#7
@ ~/.julia/packages/NNlib/P9BhZ/src/dim_helpers/DenseConvDims.jl:60 [inlined]
[3] (::Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}})(x::Array{Float64, 4})
@ Flux ~/.julia/packages/Flux/ZnXxS/src/layers/conv.jl:162
[4] top-level scope
@ REPL[49]:1