Flux - LSTM - Issue with input format for multiple features

Hi community,

I’m diving into machine learning in Julia and am currently having some issues with the input format for an LSTM model. The goal is to predict a person’s current activity based on motion data from a smartphone:

https://archive.ics.uci.edu/ml/datasets/human+activity+recognition+using+smartphones

The data has 9 features and a sequence length of 128. With batch loading via DataLoader, the input data currently has the shape 9×128×batch_size as an Array{Float32,3}. No matter how I try to apply the model (currently with broadcasting), training quits with either a DimensionMismatch or MethodError: no method matching *(::Array{Float32,2}, ::Array{Float32,3}).
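
A stripped-down version of what I am doing (hidden size chosen arbitrarily here, not the exact model from the repo) reproduces the error:

julia> using Flux
julia> model = LSTM(9, 64);              # 9 input features, arbitrary hidden size
julia> x = rand(Float32, 9, 128, 32);    # features × sequence length × batch size
julia> model(x)                          # MethodError: no method matching *(::Array{Float32,2}, ::Array{Float32,3})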

I searched the documentation and examples for hints about the input shape for multiple features, but I did not find anything helpful.

The code of the project can be accessed at
https://github.com/thomaszub/human-activity-ml-julia

Thanks for your help!

Hi @thomaszub!

Flux RNNs and broadcasting work in a different way than the Python alternatives. In short, RNNs in Julia don’t work on 3D arrays (except when using CUDA, but I believe that is handled at compile time). Instead, they work on arrays of matrices, one matrix per time step.

What this means for your example can be summarized in the following snippet:

julia> using Flux
julia> rnn = RNN(9, 1);
julia> lstm = LSTM(9, 1);
julia> x = rand(9, 128, 32);
julia> map(rnn, [view(x, :, t, :) for t ∈ 1:128])
julia> map(lstm, [view(x, :, t, :) for t ∈ 1:128])
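
If it helps, the same slicing can also be written with Flux’s unstack helper (assuming it is available in your Flux version), and for a per-sequence prediction you would typically keep only the last step’s output:

julia> xs = Flux.unstack(x, 2);    # 128-element vector of 9×32 matrices, one per time step
julia> ys = map(lstm, xs);         # one output per time step
julia> ys[end]                     # e.g. keep only the final step for a per-sequence prediction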

I’m using map instead of the normal broadcast because of a bug when moving to Zygote. I believe it is fixed on master, but I don’t know which release it is part of. Effectively, the gradients were being truncated to only one step back in time.

Thanks! This solved the issue.

Is there any reason for Flux not to use 3D arrays? I think that would make the API more convenient.

Good question. I wasn’t around when Flux was originally written, but I assume it has to do with how matrix multiplication works in Julia and Flux’s overall goal of being just normal-looking Julia.

Imagine you have a chain

julia> model = Chain(Dense(2, 10), RNN(10, 1))

Chain is just a convenient operator for chaining information forward. The Dense layer is very simple under the hood, resulting in this operation for an input vector x (you can find this for 0.11.6 here):

function (a::Dense)(x::AbstractArray)
  W, b, σ = a.W, a.b, a.σ
  σ.(W*x .+ b)
end

Now, if we use 3D arrays this doesn’t work, because W (which is a matrix) and x (which is now a 3D array) don’t have default behavior for multiplication; in fact there are many possible conventions here, so it doesn’t make sense for * to have a default behavior. To make this work with 3D arrays in the most “Julian” way (at least by my reckoning), we would dispatch on the number of dimensions of the array and specialize. This would require every layer (i.e. dense, recurrent, conv, etc.) to have a special method for 3D arrays, adding to the code complexity of Flux. What that method would look like is almost exactly (I think) what I wrote above, but now repeated for each layer individually. In fact, what I wrote above should be just as fast as anything you would write here, because for loops are fast in Julia, so there isn’t much optimization to be had, and as far as I know there aren’t any specialized BLAS operations for this.
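
To make that concrete, here is a rough sketch (a hypothetical helper, not part of Flux) of what such a 3D-specialized method for Dense could look like; it just hides the same per-time-step loop inside the layer:

using Flux

# Hypothetical helper, not Flux API: apply a Dense layer to a features × time × batch array
# by reusing the ordinary matrix path on every time slice.
function dense3d(a::Dense, x::AbstractArray{<:Real,3})
    slices = [a(view(x, :, t, :)) for t in 1:size(x, 2)]  # each slice is out_features × batch
    permutedims(cat(slices...; dims = 3), (1, 3, 2))      # stack back to out_features × time × batch
end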

Instead, it makes more sense to use the broadcast or map functionality. This cuts down on specialized code in Flux and makes running a model seem very Julian/intuitive. There is no confusion about what the broadcast is doing, whereas in the 3D-array version I now have to check which dimension is the time index.

It may seem more convenient to have support for 3D arrays because this is what TensorFlow/PyTorch do. But they have to do this for efficiency, as Python for loops can be slow, so they just do the for loop in C instead. This isn’t a bad thing, and it is actually quite clever. It is just that Julia doesn’t have that constraint, so there is no reason to design it that way here.

Of course, this doesn’t apply to GPU versions, where CUDNN specializes for 3D arrays. But I’m 99% sure Julia handles this at compilation time/when the data gets moved to the GPU (which needs to happen anyway), so the user-facing code stays the same (I’m pretty sure) and dispatch is used to do all the onerous stuff.

Thanks for the explanation.

You gave good arguments about the design. From a mathematical point of view it now makes even more sense to me, as the temporal nature is handled by reapplying the network sequentially to the time series. So Julia maps more “naturally” from this perspective.

Hello, I had a look at the code you posted after switching to map. Do you mind if I ask why you only took the last row in the apply function? Isn’t that basically calculating only the last value and disregarding all the other values?
I thought a more sensible approach would be to take all the values produced by map and create your own loss function as below. Let me know what you think:

using Statistics: mean

function lossfn(x, y)
    Flux.reset!(model)
    y_model = eval_model(x)
    # adjusting the loss function for mean squared error, instead of Flux.mse(y_model, y)
    diff = [mean(abs2, y[i] .- [yt[i] for yt in y_model]) for i in 1:length(y)]
    return mean(diff)
end
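
For comparison, my understanding of the last-step-only variant (same model/eval_model names as above, hypothetical function name) is roughly:

function lossfn_last(x, y)
    Flux.reset!(model)
    y_model = eval_model(x)           # vector of outputs, one per time step
    return Flux.mse(y_model[end], y)  # score the sequence by its final output only
end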

I found this thread and I think it’s not accurate. First of all, Flux, which uses NNlib, dispatches to both the 2D and 3D APIs exposed by CUDNN (that has been the case for a long time). The 3D version is just not documented, but you can pass a three-dimensional array through LSTM/GRU or any RNN. It’s something CUDNN supports and it is correctly dispatched to. I don’t know about non-CUDA applications.

Secondly, the documentation states not to use map or the broadcast dot operator with recurrent layers, because the order of execution is not guaranteed and you want to feed your input in sequential order.

That hasn’t been the case for over 2 years. cuDNN was dropped because it turned out to be slower with the current RNN interface and was also a nightmare to interop with. Now that we have the (experimental) 3D path there has been talk of trying again (see Block parameters for RNNs (for cudnn paths) by mkschleg · Pull Request #1855 · FluxML/Flux.jl · GitHub from @mkschleg himself), but it is far from trivial and might require a complete rethinking of the RNN interface[1].

That warning was added in July, many months after this thread wrapped up. It was triggered by Strange result with gradient · Issue #1547 · FluxML/Flux.jl · GitHub, which happened in March/April. That said, thank you for bringing it up so that future readers are aware of the warning to not use map or broadcast with RNNs.


  1. To be clear, few of us like the current interface and it has a number of known issues/limitations. Designing a new interface for anything though can take some effort, especially when you have to consider stability + backwards compatibility like Flux does. See RNN design for efficient CUDNN usage · Issue #1365 · FluxML/Flux.jl · GitHub for some further reading. ↩︎

@ToucheSir is correct. Mapping to CUDNN is a mess, and from what I remember it wasn’t worth the performance gain with a naive version. CUDNN only really supports fast operations for a subset of RNN types (i.e. activations, architectures, etc.) anyway. I’ve thought a bit about how we might accomplish this interop at compile time rather than run time, but I haven’t had time to do anything about it.

The best option (imo) would be to have Julia kernels that follow the algorithms developed by Nvidia for fast RNN forward/backward passes. That way we could have the faster algorithms for a wider range of possible architectures. But this is even less trivial and requires a lot of engineering work (much of which I don’t have the time or expertise for).
