A possible way to improve training in Flux?

Let’s say I need my weights to be determined with an accuracy of 10^-10 and hence want to train my network using Float64. Does it make sense to train my network with Float32 first and then (once I reach around 10^-7 accuracy) switch to Float64?
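For reference, the "around 10^-7" figure is essentially the machine epsilon of Float32. Here is a quick check in plain Julia, with no Flux involved (the variable names are just for illustration). It also shows that a 1e-10 change to a weight is rounded away entirely in Float32:

```julia
# Machine epsilon: the gap between 1.0 and the next larger representable number.
eps16 = eps(Float16)   # 2^-10, about 9.8e-4
eps32 = eps(Float32)   # 2^-23, about 1.2e-7 (the "around 10^-7" above)
eps64 = eps(Float64)   # 2^-52, about 2.2e-16

for (T, e) in ((Float16, eps16), (Float32, eps32), (Float64, eps64))
    println(T, ": eps = ", Float64(e))
end

# A weight perturbation of 1e-10 is rounded away in Float32 but kept in Float64.
w = 0.1
lost_in_f32 = Float32(w + 1e-10) == Float32(w)   # true
kept_in_f64 = (w + 1e-10) != w                   # true
```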

Also, I have a feeling that first training with lower-precision numbers (even Float16) and then moving on to higher-precision numbers would speed up training a lot. Is there a reason why this isn’t built in yet?

I can’t comment on your first paragraph, but the truth is that most DL applications just don’t need the level of precision afforded by 64-bit floats. Normalization and other forms of regularization generally encourage smaller-magnitude weights, which make more effective use of the limited precision of float32 or even float16. Models that are sensitive to small weight perturbations are also more likely to be susceptible to adversarial attacks, while training with noisy data can improve generalization.

With regard to speeding up training, mixed precision has been gaining traction lately (see e.g. torch.cuda.amp).
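A rough plain-Julia illustration of one ingredient of mixed-precision training. This is not how torch.cuda.amp is implemented, and it leaves out loss scaling. A small update can be rounded away entirely in low precision, which is why mixed-precision schemes typically keep a higher-precision "master" copy of the weights. The function names are hypothetical:

```julia
using LinearAlgebra

# Apply the same small gradient step `n` times to a weight stored in precision T.
function repeated_updates(::Type{T}, w0, step, n) where {T<:AbstractFloat}
    w = T(w0)
    for _ in 1:n
        w -= T(step)
    end
    return w
end

w16 = repeated_updates(Float16, 1.0, 1e-4, 100)   # each step is below half an ulp of 1.0, so w stays 1.0
w32 = repeated_updates(Float32, 1.0, 1e-4, 100)   # ends up close to 0.99

# "Master weights" pattern: compute the gradient with a Float16 working copy,
# but accumulate the update into a Float32 master copy.
function mixed_step!(w_master::Vector{Float32}, X::Matrix{Float16}, y::Vector{Float16}, lr::Float32)
    w_half = Float16.(w_master)                             # low-precision working copy
    g_half = X' * (X * w_half .- y) ./ Float16(size(X, 1))  # gradient of ‖Xw - y‖² / (2n), in Float16
    w_master .-= lr .* Float32.(g_half)                     # update applied in Float32
    return w_master
end
```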

I’ll second the use of mixed precision. But this raises a gotcha that has burned me a few times. Flux’s default floating-point precision is, I believe, 32-bit. So feeding in a Float64 doesn’t mean you will get Float64 all the way through your chain. Someone more involved in the guts of Flux/Zygote can comment further, but a couple of versions back I had serious trouble with this until I converted everything to Float32. One thing to check is that your inputs and your reference/ground-truth values are also Float64.
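A small plain-Julia illustration of that gotcha. The arrays are stand-ins for a Dense layer’s weights and bias, not Flux’s actual layer. Julia’s type promotion makes the output Float64 as soon as the input is Float64, which can hide the fact that the parameters themselves only carry Float32 precision:

```julia
using LinearAlgebra

W = rand(Float32, 4, 3)   # stand-in for Float32 layer weights (Flux's default)
b = zeros(Float32, 4)     # stand-in for the bias
x = rand(Float64, 3)      # Float64 input

y = W * x .+ b            # promoted: the result is a Vector{Float64}...
@show eltype(y)           # Float64
@show eltype(W)           # ...but the parameters are still Float32

# Rounding to Float32 discards information that converting back to Float64 cannot restore.
@show Float64(Float32(0.1)) == 0.1   # false
```

So both the model (e.g. via `f64` or `fmap`) and the data need converting.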

Most applications don’t actually require 64-bit floats, so this is usually not an issue. But consider error propagation through a deep net: I’m not sure 1e-10 is easy to reach, though I’m too lazy to run the numbers. If I recall correctly, even chaining a few simple LAPACK operations can make 1e-11 difficult.
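One rough way to "run the numbers" on a toy network is to evaluate the same random tanh network in Float32, Float64 and BigFloat, and compare against the BigFloat result. The network, its size and the seed are arbitrary assumptions, and this measures the forward pass only, not training. The Float32 error also includes the rounding of the weights themselves:

```julia
using LinearAlgebra, Random

# Evaluate a chain of tanh layers with every operation carried out in precision T.
function forward(Ws, x, ::Type{T}) where {T<:AbstractFloat}
    h = T.(x)
    for W in Ws
        h = tanh.(T.(W) * h)
    end
    return h
end

rng = MersenneTwister(1)
n, depth = 64, 20
Ws = [randn(rng, n, n) ./ sqrt(n) for _ in 1:depth]   # hypothetical random weights
x = randn(rng, n)

ref = forward(Ws, x, BigFloat)                         # high-precision reference (256-bit by default)
relerr(T) = Float64(norm(forward(Ws, x, T) .- ref) / norm(ref))

err32 = relerr(Float32)
err64 = relerr(Float64)
println("relative error, Float32: ", err32)
println("relative error, Float64: ", err64)
```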

Again, I’m not a pro; someone else can probably give more actionable insight.

You’re right about this, but I really do seem to need that extra precision. I am trying to use Flux to solve a PDE with numerical techniques, and I need this extra precision to compare schemes. In fact, I plan to use an NVIDIA Quadro for this, since it supports double-precision computation.

Since the dataset is generated by code using Float64 values, the only noise in my solution comes from machine-precision (rounding) errors.

I just got Flux running with Float64 successfully (and probably fixed the issues you may be referring to), so I haven’t had any problems so far.

As an example of why I need Float64, consider solving the Laplace equation using higher-order stencils.

If I want to check whether my network can capture the properties of the stencils along with their precision, I really need Float64. For my research, I need a precision of at least 10^-9. I hope this example is clear enough.
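To make the precision argument concrete, here is a small sketch (an illustration only, not the poster’s setup; the helper names are hypothetical). It compares a 2nd-order and a 4th-order finite-difference stencil for the second derivative, in Float32 and in Float64. With h = 0.01 the 4th-order truncation error is on the order of 1e-10, which Float64 can resolve. In Float32, round-off amplified by the 1/h² factor is far larger and swamps it:

```julia
# Hypothetical helpers: second derivative of f at x from two finite-difference stencils.
d2_3pt(f, x, h) = (f(x - h) - 2f(x) + f(x + h)) / h^2                                    # 2nd order
d2_5pt(f, x, h) = (-f(x - 2h) + 16f(x - h) - 30f(x) + 16f(x + h) - f(x + 2h)) / (12h^2)  # 4th order

# Absolute errors of both stencils for f = sin at x = 1 (exact f'' = -sin), evaluated in precision T.
function stencil_errors(::Type{T}; h = 0.01) where {T<:AbstractFloat}
    x, hT = T(1), T(h)
    exact = -sin(1.0)
    e3 = abs(Float64(d2_3pt(sin, x, hT)) - exact)
    e5 = abs(Float64(d2_5pt(sin, x, hT)) - exact)
    return e3, e5
end

for T in (Float32, Float64)
    e3, e5 = stencil_errors(T)
    println(T, ": 3-point error = ", e3, ", 5-point error = ", e5)
end
```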

I suspected there was a SciML component at play here! I’m about as far from an expert in that area as you can get, so I’ll let those who actually know their stuff chime in. My original comment still stands as a response to “Is there a reason why it isn’t inherently coded yet?”, however.

Same, time to call in the pros… @ChrisRackauckas @JeffreySarnoff @dhairyagandhi96

I’m not a pro and can’t comment on whether this would be useful, but doing it is easy in Flux. Just use the same functor-mapping function (fmap) that is used for CPU → GPU conversion:


```julia
julia> using Flux

julia> m = Chain(Dense(3,4), Dense(4,5));

julia> typeof.(params(m))
4-element Array{DataType,1}:
 Array{Float32,2}
 Array{Float32,1}
 Array{Float32,2}
 Array{Float32,1}

julia> cfun(x::AbstractArray) = Float64.(x);  # convert every array (weights, biases) to Float64

julia> cfun(x) = x;  # no-op for anything that is not an array (e.g. activation functions)

julia> m64 = Flux.fmap(cfun, m)  # apply cfun to every leaf of the model
Chain(Dense(3, 4), Dense(4, 5))

julia> typeof.(params(m64))
4-element Array{DataType,1}:
 Array{Float64,2}
 Array{Float64,1}
 Array{Float64,2}
 Array{Float64,1}
```

I did consider SciML based on @ChrisRackauckas’s suggestions (on Slack), and found that its approach was not what I had in mind. SciML trains a network (based on, say, the 5-point stencil for the Poisson equation) to obtain a solution, whereas I want to learn the numerical scheme from the solution itself (e.g. the coefficients used in the 5-point stencil).

Flux also ships with f64 and f32, which work like the gpu function. You could use those directly as well.

I know that it is possible; I have seen many posts explaining it.
My question is whether training performance can be improved by first training the network in, say, Float32 and then switching to Float64 (or Float16 → Float32 → Float64). This way, I think, we could get a large speedup in the first few epochs.

This will speed up training, but could make it miss some very narrow optima.

I think that would happen anyway, since the training step size in the first few epochs is the same for a given network regardless of the data type (all else being equal). The effect of Float64 only shows up once the weights need to be accurate to 10^-7 or better.

Time a few good examples of your computation (or a simplified version of it), coded each way, using BenchmarkTools.jl, and find out. That’s what I would do.
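Here is a minimal timing sketch in that spirit. It uses only Base’s `@elapsed` as a crude stand-in, whereas BenchmarkTools.jl’s `@btime`/`@benchmark` handle warm-up and statistics more carefully and are the better choice. Dense matrix multiplication stands in for the cost of a Dense layer, and the matrix size is an arbitrary assumption. The Float32/Float64 ratio depends heavily on the hardware:

```julia
using LinearAlgebra

# Crude stand-in for BenchmarkTools: run `f` once to compile, then report the fastest of `reps` runs.
function min_time(f; reps = 5)
    f()
    return minimum(@elapsed(f()) for _ in 1:reps)
end

n = 512
A64, B64 = rand(n, n), rand(n, n)
A32, B32 = Float32.(A64), Float32.(B64)

t64 = min_time(() -> A64 * B64)
t32 = min_time(() -> A32 * B32)
println("Float64 matmul: ", t64, " s")
println("Float32 matmul: ", t32, " s")
```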

Copying from Slack: PDE stencils are convolutional layers, so if you use them directly you get fast GPU goodness; that’s the way to go. For discovering PDEs from data by discovering the discretization, see Section 2.3 of https://arxiv.org/pdf/2001.04385.pdf
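To illustrate the "stencils are convolutions" point, here is a plain-Julia sketch (no Flux; the names are hypothetical). It applies the 5-point Laplacian as a 3×3 kernel slid over a grid and checks the result on u = x² + y², whose Laplacian is exactly 4. In Flux, the same 3×3 kernel could in principle serve as the fixed or learnable weight of a `Conv((3, 3), 1 => 1)` layer, but that wiring is not shown here:

```julia
# Kernel of the standard 5-point Laplacian stencil (to be divided by h^2).
const LAPLACE_5PT = [0.0  1.0 0.0;
                     1.0 -4.0 1.0;
                     0.0  1.0 0.0]

# "Valid" 2-D cross-correlation of u with a 3×3 kernel K, scaled by 1/h^2.
# For a symmetric kernel like this one, correlation and convolution coincide.
function apply_stencil(u::AbstractMatrix{T}, K::AbstractMatrix{T}, h::T) where {T<:AbstractFloat}
    m, n = size(u)
    out = zeros(T, m - 2, n - 2)
    for j in 1:n-2, i in 1:m-2
        acc = zero(T)
        for b in 1:3, a in 1:3
            acc += K[a, b] * u[i + a - 1, j + b - 1]
        end
        out[i, j] = acc / h^2
    end
    return out
end

h = 0.1
xs = 0:h:1
u = [x^2 + y^2 for x in xs, y in xs]   # ∇²u = 4 everywhere
L = apply_stencil(u, LAPLACE_5PT, h)
println("max |L - 4| = ", maximum(abs.(L .- 4)))
```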

My question here is not about solving PDEs with ML tools (the Slack question was different); I only explained that part to show why I need Float64 precision. My main question is quite general: can training performance be improved by progressively switching from Float16 to Float32 and then to Float64, instead of using Float64 from the start of training to the end?

Oh sorry, I forgot why I had this tab open. Yes, changing precision in Julia is as trivial as Float32.(x), so abuse it to make multi-precision algorithms. Nick Higham gives a great talk on it:
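A classic example of this kind of multi-precision algorithm is mixed-precision iterative refinement for linear systems. The expensive factorization is done once in Float32, while residuals are computed in Float64, and the cheap Float32 factors are reused to correct the solution. The sketch below assumes a well-conditioned matrix; the test matrix, tolerances and function name are illustrative choices:

```julia
using LinearAlgebra, Random

# Mixed-precision iterative refinement: factorize once in Float32 (the O(n^3) part),
# then compute residuals in Float64 and reuse the Float32 factors to correct x.
function refine_solve(A::Matrix{Float64}, b::Vector{Float64}; maxiter = 10, rtol = 1e-14)
    F = lu(Float32.(A))                    # low-precision factorization
    x = Float64.(F \ Float32.(b))          # initial low-precision solution
    for _ in 1:maxiter
        r = b - A * x                      # residual in high precision
        norm(r) <= rtol * norm(b) && break
        x .+= Float64.(F \ Float32.(r))    # correction from the Float32 factors
    end
    return x
end

rng = MersenneTwister(0)
n = 50
A = randn(rng, n, n) + n * I               # hypothetical well-conditioned test matrix
x_true = randn(rng, n)
b = A * x_true

x0 = Float64.(lu(Float32.(A)) \ Float32.(b))   # plain Float32 solve, for comparison
x = refine_solve(A, b)
println("Float32 solve error: ", norm(x0 - x_true) / norm(x_true))
println("after refinement:    ", norm(x - x_true) / norm(x_true))
```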
