Generic way to change float precision in FluxML

I’ve recently started using FastAI.jl and first want to say how much I appreciate the efforts of everyone involved here - this is really cool!

When going through chapter 5 of the fast.ai book, I came across the point that decreasing float precision can help speed up training (this is of course also mentioned in the Flux docs and additionally encouraged through warnings). To this end, the Python fast.ai package comes with a simple callback:

from fastai.callback.fp16 import *
learn = cnn_learner(dls, resnet50, metrics=error_rate).to_fp16()

I don’t think this is available in FastAI.jl or FluxTraining.jl, probably because it’s easy enough to ensure lower precision further upstream. Still, it would be nice to be able to just take a Learner() and change its floating point precision. Is that possible?
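
To illustrate what I mean by "further upstream": converting the model’s parameters directly, e.g. with Flux’s f16. A minimal sketch, assuming a Flux version that exports f16 alongside f32/f64 (the model here is just a placeholder):

using Flux

model = Chain(Dense(784, 32, relu), Dense(32, 10))
model16 = f16(model)  # recursively converts the Float32 parameters to Float16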

Edit: The error below happens when calling FastAI.showbatch on the data loader, so it is related to plotting only.

Relatedly, I just found that the following

task = BlockTask(
    blocks,
    (
        ProjectiveTransforms(
            (_resize, _resize),
            sharestate=false,
        ),
        ImagePreprocessing(T=Float16),
        OneHot(),
    )
)

leads to an error when calling FastAI.showbatch. The following

batchsize = 3
train_dl, val_dl = taskdataloaders(train_data, val_data, task, batchsize)
showbatch(task, first(train_dl))

throws this error:

ArgumentError: N0f8 is an 8-bit type representing 256 values from 0.0 to 1.0; cannot represent -0.00024414062

Is that expected behaviour?

You were right to suspect something upstream. Calling f16 on the model is probably the easiest way to go. The built-in ToGPU callback uses FluxTraining.jl’s ToDevice under the hood, so you should be able to use that to also convert precision: something like ToDevice(gpu∘f16, gpu∘f16). That could also save you from having to set the precision of the data too early with ImagePreprocessing(T=Float16), though without a full stacktrace it’s hard to tell where the error is coming from.
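
Putting that together, roughly (an untested sketch - tasklearner’s exact signature may differ between FastAI.jl versions, and task/train_data/val_data refer to the names from your snippets above):

using FastAI, Flux
using FluxTraining: ToDevice

learner = tasklearner(
    task, train_data, val_data;
    # convert model and data batches to Float16 and move them to the GPU
    callbacks=[ToDevice(gpu ∘ f16, gpu ∘ f16)],
)
fitonecycle!(learner, 5)  # train for 5 epochs with the one-cycle schedule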

Just what I was looking for - thanks @ToucheSir, and sorry I missed this one in the docs!