How to control output type with Flux?

Here is an MWE:

using Flux

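function NeuralNetwork(; nb_in=1, nb_out=1, layer_size=8, nb_hid_layers=1)
    # Note the use of Any[] since the Dense layers are of different types
    # Input layer
    layers = Any[Flux.Dense(nb_in => layer_size, tanh)]
    # Hidden layers (tanh activation assumed, matching the input layer)
    for _ in 1:nb_hid_layers
        push!(layers, Flux.Dense(layer_size => layer_size, tanh))
    end
    # Output layer (no activation)
    push!(layers, Flux.Dense(layer_size => nb_out))
    return Chain(layers...)
end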

# Set up Neural Network (input layer, hidden layers, output layer)
model_univ = NeuralNetwork(nb_in=9, nb_out=9, layer_size=8, nb_hid_layers=0)
p_model, re = Flux.destructure(model_univ)
p_model = zeros(Float64, length(p_model))

input = rand(Float64, 9)

model_univ(input)

The last line produces output of type Float32, even though the weight vector p_model is of type Float64. I cannot figure out how to get output of type Float64. Any help is appreciated.

Flux now standardises on the eltype of the weight arrays, because accidental promotion to Float64 was a very common and severe performance bug. Your network has Float32 weights because that is the default for Dense (and all other layers).
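
One way to confirm this is to inspect the element type of the parameters directly. The following is a minimal sketch, assuming a recent Flux release in which Dense stores its parameters in fields named weight and bias:

```julia
using Flux

# Dense layers are created with Float32 parameters unless told otherwise.
layer = Dense(9 => 8, tanh)

eltype(layer.weight)   # Float32
eltype(layer.bias)     # Float32

# Flux.destructure flattens the parameters into one vector of the same eltype.
p, re = Flux.destructure(layer)
eltype(p)              # Float32
length(p)              # 80, i.e. 9*8 weights plus 8 biases
```

Note that rebinding p to a new Float64 vector afterwards, as the MWE does with p_model, only changes the variable and leaves the model's own weights untouched.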

If you want to run the network with Float64 everywhere, then use f64(model) or Chain(layers...) |> f64, just once, after construction.
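
As a rough sketch of the first option, with a small two-layer network standing in for the MWE's model (the variable names model, model64, x and y are only for illustration):

```julia
using Flux

# Same shape as the network in the question: 9 inputs, 8 hidden units, 9 outputs.
model = Chain(Dense(9 => 8, tanh), Dense(8 => 9))

# Convert all floating-point parameters to Float64, once, after construction.
model64 = f64(model)          # equivalently: model |> f64

x = rand(Float64, 9)
y = model64(x)

eltype(model64[1].weight)     # Float64
eltype(y)                     # Float64
```

Since the weights and the input now share the same eltype, this should also avoid the mixed-precision warning.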

If you want Float64 numbers out of a network with Float32 weights, you could for instance write Chain(layers..., f64). The input should then ideally be Float32 (it will be converted if not).
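
The second option might look like the sketch below. It keeps the Float32 weights and only converts the final result; the layer sizes are taken from the question, and the variable names are illustrative:

```julia
using Flux

# Float32 weights, with f64 appended as a final conversion step.
model = Chain(Dense(9 => 8, tanh), Dense(8 => 9), f64)

# Float32 input matches the weights, so no conversion is needed on the way in.
x = rand(Float32, 9)
y = model(x)

eltype(model[1].weight)       # Float32
eltype(y)                     # Float64
```

If the data arrives as Float64, converting it up front with Float32.(input) keeps the whole forward pass in Float32 until the last step.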

Thanks! I never would have figured this out. I am debugging code, and this is a temporary phase.

The first time I run this, it prints a warning:

julia> model_univ(input)
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(9 => 8, tanh)  # 80 parameters
│   summary(x) = "9-element Vector{Float64}"
└ @ Flux ~/.julia/packages/Flux/Nzh8J/src/layers/stateless.jl:50
9-element Vector{Float32}:
 -0.19098403
  0.20039776
  0.031824134

Maybe the warning should suggest f64. It is certainly easy to miss, but making it print on every call would be pretty noisy.

Is there an obvious place in the documentation to mention this?