Yes, Float32 will often be twice as fast or better. It’s the default for Flux’s layers, but Julia’s constructors like randn default to Float64, and a very common performance problem is to mix the two (which can easily be 10x slower than either alone).
Edit – with an example:
julia> using Flux, Statistics, BenchmarkTools  # mean is from Statistics, @btime from BenchmarkTools

julia> const m1 = Chain(Dense(100,40,relu), Dense(40,40,relu), Dense(40,1,identity));
julia> x1 = randn(Float32, 100, 6000);
julia> @btime gradient(sum∘m1, $x1); # all 32
2.914 ms (139 allocations: 11.54 MiB)
julia> @btime gradient(sum∘m1, $(Float64.(x1))); # mixed
94.752 ms (3000224 allocations: 66.48 MiB)
julia> @btime gradient(sum∘$(Flux.f64(m1)), $(Float64.(x1))); # all 64
7.222 ms (140 allocations: 23.08 MiB)
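The usual fix is to convert one side explicitly, rather than letting promotion happen silently inside the model (Flux.f32 is the inverse of the Flux.f64 used above). Here is a sketch of where Float64 tends to sneak in; x64 is just an illustrative stand-in, and the eltype results are written out by hand rather than copied from a session:

julia> x64 = randn(100, 6000);      # randn defaults to Float64

julia> eltype(m1(Float32.(x64)))    # convert the data at the boundary...
Float32

julia> eltype(Flux.f64(m1)(x64))    # ...or the model, as above; pick one side
Float64

julia> eltype(0.5 .* x1)            # a bare Float64 literal also promotes
Float64

julia> eltype(0.5f0 .* x1)          # Float32 literals keep everything narrow
Float32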
If m1 is the whole model, I think this is going to evaluate it twice. You also take exp.() twice, which is relatively expensive. You may want to try something like this:
function xent_loss(x0, y0)
    x = m1(x0)                  # evaluate the model once
    xbar = mean(x)              # mean is from Statistics (loaded above)
    expx = exp.(x .- xbar)      # take exp.() once; xbar cancels in the ratio below, it's only there for numerical range
    pden = expx * y0[2:end, :]  # denominator terms via a single matrix multiplication
    lossc = -sum(y0[1, :]' .* log.(expx ./ pden))
end
This will still allocate quite a bit… I think Zygote doesn’t preserve broadcast fusion, so it will make a lot of intermediate arrays here.
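For intuition, here’s a toy version of the same broadcast pattern (not your loss, just the exp.(x .- mean(x)) part): the forward pass fuses into a single loop with one output array, while I believe Zygote’s pullback stores an intermediate for each broadcast it differentiates. Timings omitted since they will vary by machine:

julia> v = randn(Float32, 10_000);

julia> @btime exp.($v .- mean($v));                        # fused: one output array

julia> @btime gradient(x -> sum(exp.(x .- mean(x))), $v);  # pullback materializes intermediates

If the normalization were a plain softmax, logsoftmax (from NNlib, re-exported by Flux) has a hand-written gradient that avoids such intermediates, but your pden term looks more custom than that.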