You were right to suspect something upstream. Calling `f16` on the model is probably the easiest way to go. The built-in `ToGPU` callback uses FluxTraining.jl's `ToDevice` under the hood, so you should be able to use that to also convert precision, e.g. `ToDevice(gpu∘f16, gpu∘f16)`. That could also save you from having to set the precision of the data too early via `ImagePreprocessing(T=Float16)`, though without a full stacktrace it's hard to tell where the error is coming from.
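For concreteness, a minimal sketch of wiring this up (the model, loss, and `Learner` call here are illustrative assumptions, not taken from your code):

```julia
using Flux, FluxTraining

# Toy model just for illustration; substitute your own.
model = Chain(Dense(784 => 32, relu), Dense(32 => 10))

# ToDevice takes (movedatafn, movemodelfn); ToGPU() is just ToDevice(gpu, gpu),
# so composing with f16 moves things to the GPU *and* casts them to Float16.
todev = ToDevice(gpu ∘ f16, gpu ∘ f16)

learner = Learner(model, Flux.logitcrossentropy; callbacks = [todev])
```

With the cast happening in the callback, the data pipeline can stay in `Float32` until batches reach the device.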