RNN only works with Float32?

Hello everyone,

I wanted to build an RNN model, and it seems that Flux only accepts Float32 inputs. For instance, if I run the following code:

```julia
using Flux
rnn = RNN(1, 1)               # input size 1, hidden size 1
rnn([1])                      # Vector{Int64} input: gives an error
rnn([convert(Float32, 1)])    # Vector{Float32} input: works
```

I was wondering why this is the case and how one could change it to accept other types.

Thanks!


I’m not sure the error is intentional, but Flux tends to warn about mismatched datatypes, because mixing them can make models much slower and more memory-hungry.
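Plain Julia's promotion rules give a feel for why mixed element types are costly. This is a minimal base-Julia sketch (no Flux involved, and not necessarily how Flux layers handle it; `W` and the inputs are made-up values): multiplying a Float32 weight matrix by an Int vector stays Float32, but a Float64 vector promotes the result to Float64, and each Float64 element takes twice the memory of a Float32 one.

```julia
# Made-up Float32 "weights", standing in for a layer's parameters
W = Float32[0.5 -1.0; 2.0 0.25]

x_int = [1, 2]          # Vector{Int64}
x_f32 = Float32[1, 2]   # Vector{Float32}
x_f64 = [1.0, 2.0]      # Vector{Float64}

# Int64 promotes to Float32, so the result stays single precision
y_int = W * x_int
# Matching element types: no promotion at all
y_f32 = W * x_f32
# Float64 input drags the whole product up to double precision
y_f64 = W * x_f64

println(eltype(y_int), " ", eltype(y_f32), " ", eltype(y_f64))
# Float64 storage takes twice as many bytes per element
println(sizeof(y_f32), " bytes vs ", sizeof(y_f64), " bytes")
```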

I’m also a bit surprised that the second example works, as I thought Flux RNNs require 2D input (nfeatures x batchsize).

If you want to use Float64 instead of Float32, you can use `rnn64 = Flux.f64(rnn)` to get a copy of `rnn` with all parameters converted to Float64. There is also `Flux.paramtype`, which can be used to convert the parameters to other element types.
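Conceptually, this kind of conversion walks the model and produces a copy in which every floating-point parameter array has the new element type. The sketch below is not Flux's implementation: it applies the same idea to a hypothetical NamedTuple of RNN-style parameters (the names `Wi`, `Wh`, `b` and the helper `to_eltype` are made up for illustration), returning a converted copy and leaving the original untouched.

```julia
# Hypothetical RNN-style parameters: input weights, recurrent weights, bias
params32 = (Wi = rand(Float32, 1, 1),
            Wh = rand(Float32, 1, 1),
            b  = zeros(Float32, 1))

# Hypothetical helper: copy every array in the NamedTuple with element type T
to_eltype(::Type{T}, ps::NamedTuple) where {T<:Real} = map(p -> T.(p), ps)

params64 = to_eltype(Float64, params32)   # similar in spirit to Flux.f64
params16 = to_eltype(Float16, params32)   # any other real element type works too

println(eltype(params64.Wi), " ", eltype(params16.b))
```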

Using integer parameters will most likely fail if you try to train the model, though.
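A plain-Julia illustration of why integer parameters get in the way of training: a gradient-descent step produces fractional values, which cannot be stored back into an integer array in place. The update below is a hand-written stand-in for an optimiser step (the learning rate and gradient are made up), not Flux's optimiser code.

```julia
W_int = [1, 2]            # integer "parameters"
W_f32 = Float32[1, 2]     # floating-point parameters
grad  = [0.5, 0.5]        # made-up gradient
lr    = 0.1               # made-up learning rate

# In-place SGD step on Float32 parameters works fine
W_f32 .-= lr .* grad

# The same step on Int parameters fails: 1 - 0.05 is not an integer,
# so storing it back into the Int array throws
err = try
    W_int .-= lr .* grad
    nothing
catch e
    e
end

println(W_f32)
println(typeof(err))
```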

A few shorter ways to get Float32 input are:

```julia
rnn(Float32[1])

rnn(ones(Float32, 1, 1))  # 2D input: 1 feature x batch size 1
```
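If the data already exists as integers or Float64 (say, loaded from a file), broadcasting the `Float32` constructor converts it in one step. A base-Julia sketch with made-up data; the resulting arrays are the kind of Float32 input the `rnn` calls above expect, and the last step lays the samples out in the nfeatures x batchsize shape.

```julia
x_int = [1, 2, 3]          # made-up integer data
x_f64 = [0.1, 0.2, 0.3]    # made-up Float64 data

# Broadcast the Float32 constructor over each array
a = Float32.(x_int)
b = Float32.(x_f64)

# Equivalent explicit conversion of the whole array
c = convert(Vector{Float32}, x_int)

# Lay the samples out as a 1 x 3 matrix (1 feature x batch of 3)
X = reshape(a, 1, :)

println(typeof(a), " ", typeof(X), " ", size(X))
```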
