I believe that `fmap` returns a copy of the model with the updated parameters, instead of doing it in-place. If you do `mod = fmap(f64, mod)`, are all of `mod`'s parameters still `Float32`?
Edit: this seems to work for me, on Flux version 0.10.0:
julia> using Flux
[ Info: CUDAdrv.jl failed to initialize, GPU functionality unavailable (set JULIA_CUDA_SILENT or JULIA_CUDA_VERBOSE to silence or expand this message)
julia> m = Chain(Dense(4, 4, relu), Dense(4, 4), softmax)
Chain(Dense(4, 4, relu), Dense(4, 4), softmax)
julia> m[1].W
4×4 Array{Float32,2}:
0.823812 -0.593816 -0.799553 -0.570861
-0.242723 -0.529529 0.012944 0.740021
0.00234743 0.542825 -0.627849 -0.746003
0.640687 0.0562518 0.183272 0.300056
julia> m = fmap(f64, m)
Chain(Dense(4, 4, relu), Dense(4, 4), softmax)
julia> m[1].W
4×4 Array{Float64,2}:
0.823812 -0.593816 -0.799553 -0.570861
-0.242723 -0.529529 0.012944 0.740021
0.00234743 0.542825 -0.627849 -0.746003
0.640687 0.0562518 0.183272 0.300056
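To check the copy-versus-in-place point directly, here is a minimal sketch, assuming the same Flux 0.10 API as in the session above (`fmap` and `f64` exported, and the `W` field on `Dense`). The name `m64` is only for illustration. It binds the result of `fmap` to a new name instead of reassigning `m`, then queries the element type of each model's first-layer weights:

```julia
using Flux

# Same model as in the session above; its weights start out as Float32.
m = Chain(Dense(4, 4, relu), Dense(4, 4), softmax)

# Bind the converted model to a new (hypothetical) name rather than reassigning m.
m64 = fmap(f64, m)

# Compare element types of the first layer's weights in both models.
eltype(m[1].W)     # original model
eltype(m64[1].W)   # converted copy
```

If `fmap` really does return a copy, the first query should still report `Float32` and only the second should report `Float64`, which would explain why calling `fmap(f64, mod)` without assigning the result appears to do nothing.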
And, in fact, I've just learned that you can skip the `fmap` altogether and just call `f64` on the model directly:
julia> m = Chain(Dense(4, 4, relu), softmax) |> f64
Chain(Dense(4, 4, relu), softmax)
julia> m[1].W
4×4 Array{Float64,2}:
-0.378957 0.74928 0.776147 0.804017
-0.298902 -0.494922 -0.519403 -0.668515
0.0772961 -0.219601 0.776379 -0.067537
0.64462 -0.259356 -0.697054 -0.218149