I am toying with large language models and the amazing Transformers.jl. Although I am forced to call PyTorch because some models (falcon) are not supported by Transformers.jl, and due to some external forces, I got interested in what kind of secret sauce HuggingFace has. One of them, according to this paper https://arxiv.org/pdf/1910.02054.pdf, is mixed-precision training, where gradients, activations, and weights are stored in Float16, but the optimization of weights is carried out in Float32. So I was naturally curious whether Transformers.jl would work off the shelf with Float16.
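To make the idea concrete, here is a minimal toy sketch of one mixed-precision step on a small Flux model (my illustration of the scheme, not HuggingFace's or Transformers.jl's code; to16 is a hypothetical helper):

using Flux, Functors

# convert every Float32 array in a model to Float16, leaving the rest alone
to16(m) = fmap(x -> x isa AbstractArray{Float32} ? Float16.(x) : x, m)

model32 = Dense(4 => 2)                  # Float32 master weights
model16 = to16(model32)                  # Float16 working copy
opt_state = Flux.setup(Adam(), model32)

x = Float16.(randn(Float32, 4, 8))
y = Float16.(randn(Float32, 2, 8))

gs16 = gradient(m -> Flux.mse(m(x), y), model16)[1]  # forward/backward in Float16
Flux.update!(opt_state, model32, gs16)               # optimizer step in Float32
model16 = to16(model32)                              # refresh the working copy

The rest of this post is me trying to replicate exactly this cycle on GPT2.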
I started by writing a custom adaptor, which was a piece of cake: I just copy-pasted the adaptor from Flux, which led to
using Adapt, CUDA, Flux, Functors, Random, Zygote
import Adapt: adapt_storage   # we extend adapt_storage with our own rules

struct FluxFloatAdaptor{T} end
function FluxFloatAdaptor(T::DataType)
    !(T <: Real) && error("FluxFloatAdaptor is reserved for floats only")
    FluxFloatAdaptor{T}()
end

# define rules for handling structured arrays
adapt_storage(to::FluxFloatAdaptor{T}, x::AbstractArray{S,N}) where {T,S,N} = adapt(Array{T,N}, x)
adapt_storage(to::FluxFloatAdaptor{T}, x::AbstractRange) where {T} = x
adapt_storage(to::FluxFloatAdaptor{T}, x::Zygote.FillArrays.AbstractFill) where {T} = x
adapt_storage(to::FluxFloatAdaptor{T}, x::CUDA.CUSPARSE.AbstractCuSparseMatrix) where {T} = adapt(Array, x)
adapt_storage(to::FluxFloatAdaptor{T}, x::Zygote.OneElement) where {T} = x
# adapt_storage(to::FluxFloatAdaptor{T}, x::AbstractSparseArray) where {T} = x
adapt_storage(to::FluxFloatAdaptor{T}, x::CUDA.RNG) where {T} = Random.default_rng()
adapt_storage(to::FluxFloatAdaptor{T}, x::AbstractRNG) where {T} = x

# reuse Flux's internal leaf test, so fmap stops at the same leaves as Flux's own adaptors
const _isleaf = Flux._isleaf
cpu_float16(x) = fmap(x -> adapt(FluxFloatAdaptor(Float16), x), x, exclude = _isleaf)
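A quick usage check on a toy Flux model (my example, not from the original benchmarks):

julia> m = Chain(Dense(2 => 3, relu), Dense(3 => 1));

julia> m16 = cpu_float16(m);

julia> eltype(m16[1].weight)
Float16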
Equipped with that, I decided to benchmark inference and gradient computation on GPT2.
textenc = hgf"gpt2:tokenizer"
model_32 = hgf"gpt2:lmheadmodel"
model_16 = cpu_float16(model_32)
gpu_model_32 = gpu(model_32)
gpu_model_16 = gpu(model_16)
tokens = encode(textenc, "Lorem ipsum dolor sit amet, consectetur adipiscing elit. In tristique iaculis arcu. Nullam non purus facilisis, dignissim lorem ut, consectetur nunc. Pellentesque sit amet tortor suscipit odio ultrices egestas at vel tellus. Donec molestie, mauris sed blandit gravida, turpis risus lacinia nulla, ut eleifend nulla justo non justo. Donec finibus dolor non turpis imperdiet dapibus. Integer venenatis ex ut ex cursus venenatis. Interdum et malesuada fames ac ante ipsum primis in faucibus. Aenean varius sapien vel enim molestie aliquet. Maecenas mi leo, dignissim a gravida eget, vestibulum et lorem. Morbi malesuada in metus vel lobortis.")
tokens = gpu(tokens)
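As a side check of the memory claim (my own snippet, assuming Flux.params traverses the Transformers.jl model), summing the byte sizes of the parameters should show the Float16 model at half the size:

nbytes(m) = sum(sizeof, Flux.params(m))   # total bytes held by parameter arrays
nbytes(model_16) / nbytes(model_32)       # ≈ 0.5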
julia> @benchmark CUDA.@sync gpu_model_16(tokens)
BenchmarkTools.Trial: 1008 samples with 1 evaluation.
 Range (min … max):  4.244 ms … 52.717 ms  ┊ GC (min … max): 0.00% … 34.93%
 Time  (median):     4.402 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   4.955 ms ±  4.738 ms  ┊ GC (mean ± σ):  3.50% ± 3.33%

  [histogram omitted]
  4.24 ms        Histogram: log(frequency) by time        11.8 ms <

 Memory estimate: 377.58 KiB, allocs estimate: 6684.
julia> @benchmark CUDA.@sync gpu_model_32(tokens)
BenchmarkTools.Trial: 542 samples with 1 evaluation.
 Range (min … max):  7.961 ms … 44.150 ms  ┊ GC (min … max): 0.00% … 44.35%
 Time  (median):     8.526 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   9.231 ms ±  4.299 ms  ┊ GC (mean ± σ):  3.08% ± 5.12%

  [histogram omitted]
  7.96 ms        Histogram: log(frequency) by time        37.9 ms <

 Memory estimate: 397.30 KiB, allocs estimate: 6901.
According to the benchmarks (I ran each benchmark twice to precompile), inference with Float16 is almost twice as fast as with Float32, which pretty much matches the paper.
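Speed is only half the story; a sanity check I would add here (my snippet, not part of the benchmarks above) is to confirm that the Float16 hidden states stay close to the Float32 ones:

h16 = gpu_model_16(tokens).hidden_state   # CuArray{Float16}
h32 = gpu_model_32(tokens).hidden_state   # CuArray{Float32}
maximum(abs.(Float32.(h16) .- h32))       # expect a small, nonzero drift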
Let's now try the gradient:
julia> @benchmark CUDA.@sync gradient(m -> sum(m(tokens).hidden_state), gpu_model_16)
BenchmarkTools.Trial: 277 samples with 1 evaluation.
 Range (min … max):  14.271 ms … 53.894 ms  ┊ GC (min … max): 0.00% … 28.71%
 Time  (median):     15.378 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   18.039 ms ±  9.122 ms  ┊ GC (mean ± σ):  5.36% ± 6.97%

  [histogram omitted]
  14.3 ms        Histogram: log(frequency) by time        51.8 ms <

 Memory estimate: 1.92 MiB, allocs estimate: 23417.
julia> @benchmark CUDA.@sync gradient(m -> sum(m(tokens).hidden_state), gpu_model_32)
BenchmarkTools.Trial: 157 samples with 1 evaluation.
 Range (min … max):  26.911 ms … 61.454 ms  ┊ GC (min … max): 0.00% … 18.83%
 Time  (median):     28.012 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   31.922 ms ±  8.795 ms  ┊ GC (mean ± σ):  5.24% ± 7.89%

  [histogram omitted]
  26.9 ms        Histogram: log(frequency) by time        56.5 ms <

 Memory estimate: 1.95 MiB, allocs estimate: 23769.
where we again see the same story. Nice. So far so good. With a few lines of code, we do Float16 on the GPU.
Now, the remaining part to work out is the update of the model. The weights of the model used for training are kept in Float32 and then converted to Float16. Let's now try to update the model with Float32 weights using a gradient stored in Float16:
# gs_16 and gs_32 are the gradients computed in the benchmarks above
gs_16 = gradient(m -> sum(m(tokens).hidden_state), gpu_model_16)[1];
gs_32 = gradient(m -> sum(m(tokens).hidden_state), gpu_model_32)[1];
opt_state = Flux.setup(Adam(), gpu_model_32);
julia> @benchmark CUDA.@sync Flux.update!(opt_state, gpu_model_32, gs_16)
BenchmarkTools.Trial: 134 samples with 1 evaluation.
 Range (min … max):  36.914 ms … 52.927 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     36.974 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   37.485 ms ±  2.385 ms  ┊ GC (mean ± σ):  0.00% ± 0.00%

  [histogram omitted]
  36.9 ms        Histogram: log(frequency) by time        51 ms <

 Memory estimate: 2.63 MiB, allocs estimate: 32047.
julia> @benchmark CUDA.@sync Flux.update!(opt_state, gpu_model_32, gs_32)
BenchmarkTools.Trial: 134 samples with 1 evaluation.
 Range (min … max):  36.881 ms … 53.180 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     36.986 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   37.461 ms ±  2.623 ms  ┊ GC (mean ± σ):  0.00% ± 0.00%

  [histogram omitted]
  36.9 ms        Histogram: log(frequency) by time        52.5 ms <

 Memory estimate: 2.69 MiB, allocs estimate: 32893.
That seems to work as well (it does not give errors). There is no performance gain, but no penalty either, so I consider this fine. I am surprised that the update of the model is more expensive than computing the gradient. This does not seem right, since the operation should be effectively parallel. It is possible that I have some inefficiency I am not aware of (help wanted).
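One way to test the hypothesis that per-array kernel launches dominate (my guess and my snippet, not verified): benchmark the update of a single large parameter array, roughly the size of GPT2's token embedding, and compare it with the whole-model update above.

W  = CUDA.rand(Float32, 50257, 768)   # one big Float32 parameter
gW = CUDA.rand(Float32, 50257, 768)   # a fake gradient of the same size
st = Flux.setup(Adam(), W)
@benchmark CUDA.@sync Flux.update!($st, $W, $gW)

If this single update is much faster per element than the full-model update, the cost above is dominated by launching many small kernels, one per parameter array.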
The last remaining thing is to copy the Float32 weights stored in gpu_model_32 to gpu_model_16 to finish the cycle. This I do not know how to do efficiently. The adaptor approach I used above, written for the GPU as
struct CUDAFloatAdaptor{T} end
function CUDAFloatAdaptor(T::DataType)
    !(T <: Real) && error("CUDAFloatAdaptor is reserved for floats only")
    CUDAFloatAdaptor{T}()
end

# move dense arrays to the GPU, converting the element type to T
adapt_storage(to::CUDAFloatAdaptor{T}, x::AbstractArray{S,N}) where {T,S,N} = adapt(CUDA.CuArray{T}, x)
adapt_storage(to::CUDAFloatAdaptor{T}, x::AbstractRange) where {T} = x
adapt_storage(to::CUDAFloatAdaptor{T}, x::Zygote.FillArrays.AbstractFill) where {T} = adapt(CUDA.CuArray{T}, collect(x))
adapt_storage(to::CUDAFloatAdaptor{T}, x::CUDA.CUSPARSE.AbstractCuSparseMatrix) where {T} = x
adapt_storage(to::CUDAFloatAdaptor{T}, x::Zygote.OneElement) where {T} = CUDA.CuArray{T}(collect(x))
# adapt_storage(to::CUDAFloatAdaptor{T}, x::AbstractSparseArray) where {T} = x
adapt_storage(to::CUDAFloatAdaptor{T}, x::CUDA.RNG) where {T} = x
adapt_storage(to::CUDAFloatAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng()
adapt_storage(to::CUDAFloatAdaptor, x::AbstractRNG) =
    error("Cannot map RNG of type $(typeof(x)) to GPU. GPU execution only supports Random.default_rng().")

gpu_float16(x) = fmap(x -> adapt(CUDAFloatAdaptor(Float16), x), x, exclude = _isleaf)
is horribly slow:
julia> @benchmark gpu_float16(gpu_model_32)
BenchmarkTools.Trial: 31 samples with 1 evaluation.
 Range (min … max):  152.372 ms … 180.391 ms  ┊ GC (min … max): 2.60% … 15.52%
 Time  (median):     162.814 ms               ┊ GC (median):    7.14%
 Time  (mean ± σ):   161.825 ms ±   7.902 ms  ┊ GC (mean ± σ):  7.90% ± 4.25%

  [histogram omitted]
  152 ms          Histogram: frequency by time          180 ms <

 Memory estimate: 712.20 MiB, allocs estimate: 2814.
Here, I am effectively stuck. Suggestions and criticism are welcome.
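One direction that might work (an untested sketch of mine; copy_weights! is a hypothetical helper, not an existing API) is to walk both models in lockstep with Functors' multi-argument fmap and overwrite the existing Float16 buffers in place, instead of allocating a whole new model:

# overwrite each Float16 leaf with the converted Float32 values;
# the broadcast d .= s converts the element type on the GPU and
# reuses the existing buffers instead of allocating new ones
function copy_weights!(dst, src)
    fmap(dst, src; exclude = _isleaf) do d, s
        d isa CUDA.CuArray && s isa CUDA.CuArray && (d .= s)
        d
    end
    return dst
end

copy_weights!(gpu_model_16, gpu_model_32)

If that holds up, the cycle would close without the 712 MiB of allocations seen above.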