Yes, Float32 will often be twice as fast or better. It's the default for Flux layers but not for Julia (literals like `1.0` and functions like `randn` give Float64), and a very common performance problem is to mix the two, which can easily be 10x slower than either alone.

Edit – with an example:

```julia
julia> using Flux, BenchmarkTools

julia> const m1 = Chain(Dense(100,40,relu), Dense(40,40,relu), Dense(40,1,identity));

julia> x1 = randn(Float32, 100, 6000);

julia> @btime gradient(sum∘m1, $x1); # all Float32
  2.914 ms (139 allocations: 11.54 MiB)

julia> @btime gradient(sum∘m1, $(Float64.(x1))); # mixed: Float32 weights, Float64 data
  94.752 ms (3000224 allocations: 66.48 MiB)

julia> @btime gradient(sum∘$(Flux.f64(m1)), $(Float64.(x1))); # all Float64
  7.222 ms (140 allocations: 23.08 MiB)
```
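A minimal sketch of the usual fix, assuming Flux 0.13 or later (for the `in => out` form of `Dense`) and a hypothetical model and data: convert the data once so it matches the Float32 weights, or convert the model with `Flux.f64` if Float64 is really needed, so that no layer ever sees mixed element types.

```julia
using Flux

# Hypothetical small model; Flux initialises Dense weights as Float32 by default.
model = Chain(Dense(100 => 40, relu), Dense(40 => 1))

x64 = randn(100, 16)      # randn defaults to Float64
x32 = Float32.(x64)       # convert the data once, before training

# Float32 input with Float32 weights: every layer stays in Float32.
y32 = model(x32)
println(eltype(y32))

# Alternatively, convert the model to match Float64 data.
model64 = Flux.f64(model)
y64 = model64(x64)
println(eltype(y64))
```

Converting the data is usually the cheaper choice, since the all-Float32 case was the fastest in the timings above.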

If `m1` is the whole model, I think your loss is going to evaluate it twice. You also take `exp.()` twice, which is relatively expensive. You may want to try something like this:

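```julia
using Statistics  # for mean

function xent_loss(x0, y0)
    x = m1(x0)                    # evaluate the model only once
    xbar = mean(x)
    expx = exp.(x .- xbar)        # take exp only once; the shift cancels in expx ./ pden
    pden = expx * y0[2:end, :]    # all denominators from one matrix product
    lossc = -sum(y0[1, :]' .* log.(expx ./ pden))
    return lossc
end
```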

This will still allocate quite a bit. I think Zygote doesn't respect fused broadcasts, so it will create a lot of intermediate arrays here.
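One further option, not part of the reply above and offered only as a sketch: since log(exp(a) / b) == a - log(b), the ratio inside the log can be split so that the exponentiated numerator is never passed to `log` and the division disappears. The model, the sizes, and the layout of `y0` (row 1 as per-sample weights, rows `2:end` as an N×N nonnegative matrix) are assumptions inferred from the indexing in the code above, and the names `loss_ratio` and `loss_logdiff` are hypothetical. It assumes Flux 0.13 or later for the `in => out` syntax.

```julia
using Flux, Statistics

# Hypothetical stand-in model and data (not from the thread): 10 features, N = 8 samples.
model = Chain(Dense(10 => 16, relu), Dense(16 => 1))
N = 8
x0 = randn(Float32, 10, N)

# Assumed layout of y0: row 1 holds per-sample weights (fixed 0/1 values here),
# rows 2:end hold an N×N nonnegative matrix with a positive diagonal,
# so every entry of pden is strictly positive.
events = Float32[1 0 1 1 0 1 0 1]
risk = Float32[i <= j for i in 1:N, j in 1:N]
y0 = vcat(events, risk)

# As in the reply: one exp, then the log of a ratio.
function loss_ratio(m, x0, y0)
    x = m(x0)                        # 1×N scores
    expx = exp.(x .- mean(x))
    pden = expx * y0[2:end, :]       # 1×N denominators
    return -sum(y0[1, :]' .* log.(expx ./ pden))
end

# Same quantity, using log(exp(a) / b) == a - log(b).
function loss_logdiff(m, x0, y0)
    x = m(x0)
    xs = x .- mean(x)
    pden = exp.(xs) * y0[2:end, :]
    return -sum(y0[1, :]' .* (xs .- log.(pden)))
end

a = loss_ratio(model, x0, y0)
b = loss_logdiff(model, x0, y0)
println(isapprox(a, b; atol = 1f-3))  # the two agree up to Float32 rounding

# Both forms can be differentiated by Zygote through Flux.gradient.
g = Flux.gradient(m -> loss_logdiff(m, x0, y0), model)
```

This does not remove the unfused intermediates, but it is a little more robust numerically, because it never takes the log of an exponentiated value.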


Thanks, those are helpful, though my previous attempts to use elementwise operations with Zygote have been problematic.

I will try as many of your suggestions as I can. Thanks!

PS. I implemented your suggestions and now each batch iteration takes 10 seconds instead of 15!
PPS. I had tried using the functional gradients and rewriting the loss function as a series expansion, hoping to make the gradients easier for Zygote, but it didn't accept it. I also needed the `pullback` function as a way of getting the plain values out of the model (with Tracker it was simply `model(x).data`, which was fast). Is there any quicker way to get the values of the model, disentangled from Zygote?
Thanks again!
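On the last question, a minimal sketch assuming a recent Flux (0.13 or later) with Zygote installed as a direct dependency, and a hypothetical model and loss: Zygote has no tracked array type, so calling the model outside `gradient` already returns a plain array with nothing to unwrap, and `Zygote.pullback` or `Zygote.withgradient` return the loss value together with the gradient from a single forward pass.

```julia
using Flux, Zygote

# Hypothetical model and data for illustration.
model = Chain(Dense(10 => 16, relu), Dense(16 => 1))
x = randn(Float32, 10, 32)

# Outside of `gradient`, the output is an ordinary Array: no `.data` step is needed.
yhat = model(x)

# Hypothetical loss, used only to show the API.
loss(m) = sum(abs2, m(x))

# `pullback` returns the value and a closure that computes the gradient on demand.
val, back = Zygote.pullback(loss, model)
grads = back(one(val))            # tuple with one entry per argument of `loss`

# `withgradient` packages the same pattern as a NamedTuple (val, grad).
res = Zygote.withgradient(loss, model)
println(res.val ≈ val)
```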