Flux: "cannot broadcast array to have fewer dimensions"

I’m trying to get into Flux and Julia, so I defined two simple models. The first one works; the second one throws an error that I don’t quite understand yet. I’m sure it’s relatively simple to fix, I just don’t see it.

The models:

module Models
    export Standard, Attention

    using Flux

    struct Standard
        dense1
        lstm
        dense2
    end

    Flux.@functor Standard

    function (m::Standard)(xᵢ)
        in  = m.dense1(xᵢ)         # project the embeddings into the hidden dimension
        hₙ  = m.lstm(in)[:, end]   # keep only the last hidden state
        out = m.dense2(hₙ)
        softmax(out)
    end

    struct Attention
        dense1
        lstm
        attnᵢ
        attn_query
        v
        dense2
    end

    Flux.@functor Attention

    function (attn::Attention)(xᵢ)
        in  = attn.dense1(xᵢ)                                        # project the embeddings into the hidden dimension
        h   = attn.lstm(in)                                          # hidden states for all time steps
        q   = attn.attn_query(h[:, end])                             # query built from the last hidden state
        αₙ  = softmax(attn.v' * tanh.(attn.attnᵢ(h) .+ q), dims=2)   # attention weights over the time steps
        hₛ  = reduce(+, eachcol(αₙ .* h))                            # attention-weighted sum of the hidden states
        out = attn.dense2(hₛ)
        softmax(out)
    end
end

With this I’m able to reproduce the error:

using Flux

classes = "pos", "neg"
labels = 1:2
emb_dim = 300
hidden_dim = 50
attn_dim = 30

model = Models.Attention(
    Dense(emb_dim, hidden_dim),
    LSTM(hidden_dim, hidden_dim),
    Dense(hidden_dim, attn_dim),
    Dense(hidden_dim, attn_dim),
    rand(attn_dim),
    Dense(hidden_dim, length(labels))
)

opt = ADAM()
loss(xᵢ, yᵢ) = -log(sum(model(xᵢ) .* Flux.onehot(yᵢ, labels)))  # negative log-likelihood of the true class
ps = params(model)

for (i, (xᵢ, yᵢ)) in enumerate([(rand(300, 55), 1), (rand(300, 176), 2), (rand(300, 13), 1)]) # random data
    Flux.reset!(model)
    gs = gradient(ps) do
        training_loss = loss(xᵢ, yᵢ)
        println(training_loss)
        training_loss
    end
    Flux.update!(opt, ps, gs)
end

The error:

ERROR: DimensionMismatch("cannot broadcast array to have fewer dimensions")
Stacktrace:
 [1] check_broadcast_shape(::Tuple{}, ::Tuple{Base.OneTo{Int64}}) at ./broadcast.jl:507
 [2] check_broadcast_shape(::Tuple{Base.OneTo{Int64}}, ::Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}) at ./broadcast.jl:510
 [3] check_broadcast_axes at ./broadcast.jl:512 [inlined]
 [4] check_broadcast_axes at ./broadcast.jl:516 [inlined]
 [5] instantiate at ./broadcast.jl:259 [inlined]
 [6] materialize!(::Array{Float64,1}, ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2},Nothing,typeof(+),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(*),Tuple{Float64,Array{Float64,1}}},Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2},Nothing,typeof(*),Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0},Nothing,typeof(-),Tuple{Int64,Float64}},LinearAlgebra.Adjoint{Float64,Array{Float64,2}}}}}}) at ./broadcast.jl:823
 [7] apply!(::ADAM, ::Array{Float64,1}, ::LinearAlgebra.Adjoint{Float64,Array{Float64,2}}) at /home/f/user/.julia/packages/Flux/Fj3bt/src/optimise/optimisers.jl:175
 [8] update!(::ADAM, ::Array{Float64,1}, ::LinearAlgebra.Adjoint{Float64,Array{Float64,2}}) at /home/f/user/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:25
 [9] update!(::ADAM, ::Zygote.Params, ::Zygote.Grads) at /home/f/user/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:31
 [10] top-level scope at REPL[47]:8

This seems to be related to “Issue with CRF loss function” (FluxML/Flux.jl issue #1087) on GitHub.

Hipshot as I’m on the phone: try removing that transpose of attn.v and initializing it as rand(1, attn_dim).
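
In case it helps, this is roughly what I mean (only the two spots that change, everything else stays the same):

model = Models.Attention(
    Dense(emb_dim, hidden_dim),
    LSTM(hidden_dim, hidden_dim),
    Dense(hidden_dim, attn_dim),
    Dense(hidden_dim, attn_dim),
    rand(1, attn_dim),                 # v stored as a 1×attn_dim Matrix instead of a Vector
    Dense(hidden_dim, length(labels))
)

and in the forward pass drop the adjoint, since attn.v is already a row:

αₙ  = softmax(attn.v * tanh.(attn.attnᵢ(h) .+ q), dims=2)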

That seems to be the solution, thanks. Why is the error message so uninformative? I’d never have guessed that that was the problem.
How do I transpose, then, if I have to? collect(v')? reshape?
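
i.e., something like this (both give back a plain Matrix rather than an Adjoint wrapper, if I understand correctly):

v = rand(attn_dim)
v_row = collect(v')       # copies into a 1×attn_dim Matrix
v_row = reshape(v, 1, :)  # reshapes into a 1×attn_dim Matrix that shares memory with v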

I would guess that it is uninformative because the error is caught at a low level, which in turn suggests that this should work and that there is a bug somewhere.

My guess is that line 175 in the ADAM optimiser does a lot of broadcasted operations in one go, and somehow the LinearAlgebra.Adjoint abstraction breaks down in that case.

It might be worthwhile to try to reproduce this without Flux and see whether an issue should be reported against LinearAlgebra.
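
Something along these lines should show it, I think (the shapes mimic the rand(attn_dim) parameter and the Adjoint gradient from the stack trace):

using LinearAlgebra

x  = rand(30)       # the parameter is stored as a Vector (rank 1)
Δ  = rand(1, 30)'   # the gradient arrives as a 30×1 Adjoint of a Matrix (rank 2)
mt = zero(x)

# Roughly the kind of in-place broadcast the optimiser does: the destination is
# rank 1 but the broadcast involving Δ is rank 2, so this throws
# "cannot broadcast array to have fewer dimensions".
@. mt = 0.9 * mt + 0.1 * Δ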

I had the same problem with Flux. Avoid operations that change the total rank of an array, like [1, 2, 3]', which increases the rank from 1 to 2.
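
For example:

v = [1, 2, 3]
ndims(v)    # 1, a plain Vector
ndims(v')   # 2, the adjoint wraps v as a 1×3 row, so the rank changes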