Flux: How to convert model weights to Float16

using Flux

m  = Dense(10, 2)
# the initialized weights are Float32 by default
m2 = Dense(rand(Float16, 2, 10)) # works: the Float16 weight matrix is used as-is

Initializing the weights directly in Float16 works fine, but I have an existing complex model with many layers and would like a Float16 version of it.


Looking at the source of f32 (via @less f32([1.0])) suggests that this ought to work:

f16(m) = Flux.paramtype(Float16, m)
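A usage sketch of the one-liner above, assuming a Flux version in which the unexported helper Flux.paramtype exists as it did in this thread (being internal, it may be renamed or removed in later releases). The two-layer Chain is a made-up stand-in for the real model, built with the same Dense(in, out) constructor style as the question:

```julia
using Flux

# Same one-liner as above; Flux.paramtype is the unexported helper behind f32.
f16(m) = Flux.paramtype(Float16, m)

# Hypothetical stand-in for an existing multi-layer model (Float32 by default).
model = Chain(Dense(10, 5, relu), Dense(5, 2))
model16 = f16(model)

# Weights and biases of every layer should now be Float16.
println(eltype(model16[1].weight))
println(eltype(model16[2].bias))

# Feed Float16 inputs too, so the forward pass stays in Float16.
x = rand(Float16, 10, 3)
y = model16(x)
```

Converting the inputs as well matters: mixing Float16 weights with Float32 or Float64 inputs lets Julia's type promotion pull the computation back to the wider type.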

This should probably be built into Flux, if someone wants to make a PR.
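For illustration only, here is a simplified alternative that avoids the internal helper: walk the model with fmap from Functors.jl (the library Flux uses to traverse layers, reachable as Flux.fmap) and convert every floating-point array it finds. The name to_f16 is hypothetical, not a Flux function, and this is not how f32 itself is implemented:

```julia
using Flux

# Hypothetical helper: fmap visits every leaf of the model's structure;
# arrays of floating-point numbers are converted to Float16 by broadcasting,
# while other leaves (activation functions, non-float values) are returned unchanged.
function to_f16(m)
    Flux.fmap(m) do x
        x isa AbstractArray{<:AbstractFloat} ? Float16.(x) : x
    end
end

model = Chain(Dense(10, 5, relu), Dense(5, 2))
model16 = to_f16(model)
println(eltype(model16[2].weight))  # Float16
```

Because fmap rebuilds each layer through its functor definition, custom layers are only converted if they have been registered with @functor (or the equivalent in newer Flux versions).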


Thanks, that does exactly what I need. Can you say more about what you mean by "this should be built in"?

Could we also have something analogous to device! in CUDA.jl?