Generic way to change float precision in FluxML

You were right to suspect something upstream. Calling `f16` on the model is probably the easiest way to go. The built-in `ToGPU` callback uses FluxTraining.jl's `ToDevice` under the hood, so you should be able to use that to convert precision as well, e.g. `ToDevice(gpu ∘ f16, gpu ∘ f16)`. That could also save you from having to set the precision of the data too early with `ImagePreprocessing(T = Float16)`, though without a full stacktrace it's hard to tell where the error is coming from.
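A minimal sketch of what that could look like, assuming a recent Flux and FluxTraining.jl (the `Chain` below is a hypothetical stand-in for your model):

```julia
using Flux, FluxTraining

# Placeholder model for illustration; substitute your own.
model = Chain(Dense(10 => 32, relu), Dense(32 => 2))

# Option 1: convert the parameters to Float16 up front.
model = f16(model)

# Option 2: let the callback handle it. ToGPU() is just
# ToDevice(gpu, gpu), so composing with f16 converts the
# precision of the model and each batch in the same step
# that moves them to the device.
todevice = ToDevice(gpu ∘ f16, gpu ∘ f16)
```

You'd then pass `todevice` in your `Learner`'s callbacks in place of the default `ToGPU()`.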