Flux: Embeddings on GPU

In a regression model I would like to use embeddings for the categorical variables. My current implementation is like the simplified example below. On CPU it works seemingly fine. Any suggestions for improvement are welcome, though. On GPU, however, problems occur (with Flux 0.11.6 and Julia 1.5.3). How would you implement embedding in this case?

using Flux

struct Embed{T}
    w::T
end
Flux.@functor Embed
Embed(in::Integer, out::Integer; initW=Flux.glorot_uniform) = Embed(initW(out, in))
(m::Embed)(x::AbstractVector) = m.w[:,x]

# concat two embedded variables with continuous vars
function create_model(embed1, embed2, main_model)
    return function(x)
        x1 = embed1(x[1])
        x2 = embed2(x[2])
        x3 = cat(x1, x2, x[3], dims=1)
        return main_model(x3)
    end, Flux.params(embed1, embed2, main_model)
end

function train(n, d, p; device = cpu)         
    x = (rand(1:d, n), rand(1:d, n), rand(Float32, p, n)) |> device
    y = rand(Float32, n) |> device
    trdata = Flux.Data.DataLoader((x, y), batchsize=100) |> device
    m, prm = create_model(Embed(d, 3), Embed(d, 3), Dense(p+6, 1)) |> device
    loss(x, y) = Flux.mse(m(x), y)
    @time Flux.train!(loss, prm, trdata, ADAM()) 
end
train(100_000, 100, 50, device = gpu)

┌ Warning: Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`
└ @ GPUArrays ~/.julia/packages/GPUArrays/WV76E/src/host/indexing.jl:43
ERROR: GPU compilation of kernel broadcast_kernel(CUDA.CuKernelContext, CUDA.CuDeviceArray{Float32,2,1}, Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}},typeof(-),Tuple{Base.Broadcast.Extruded{Array{Float32,2},Tuple{Bool,Bool},Tuple{Int64,Int64}},Base.Broadcast.Extruded{CUDA.CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}}, Int64) failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}},typeof(-),Tuple{Base.Broadcast.Extruded{Array{Float32,2},Tuple{Bool,Bool},Tuple{Int64,Int64}},Base.Broadcast.Extruded{CUDA.CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}}, which is not isbits:
  .args is of type Tuple{Base.Broadcast.Extruded{Array{Float32,2},Tuple{Bool,Bool},Tuple{Int64,Int64}},Base.Broadcast.Extruded{CUDA.CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}} which is not isbits.
    .1 is of type Base.Broadcast.Extruded{Array{Float32,2},Tuple{Bool,Bool},Tuple{Int64,Int64}} which is not isbits.
      .x is of type Array{Float32,2} which is not isbits.

Stacktrace:
 [1] check_invocation(::GPUCompiler.CompilerJob, ::LLVM.Function) at /home/johnbb/.julia/packages/GPUCompiler/uTpNx/src/validation.jl:68
 [2] macro expansion at /home/johnbb/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:238 [inlined]
 [3] macro expansion at /home/johnbb/.julia/packages/TimerOutputs/ZmKD7/src/TimerOutput.jl:206 [inlined]
 [4] codegen(::Symbol, ::GPUCompiler.CompilerJob; libraries::Bool, deferred_codegen::Bool, optimize::Bool, strip::Bool, validate::Bool, only_entry::Bool) at /home/johnbb/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:237
 [5] compile(::Symbol, ::GPUCompiler.CompilerJob; libraries::Bool, deferred_codegen::Bool, optimize::Bool, strip::Bool, validate::Bool, only_entry::Bool) at /home/johnbb/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:39
 [6] compile at /home/johnbb/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:35 [inlined]
 [7] cufunction_compile(::GPUCompiler.FunctionSpec; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/johnbb/.julia/packages/CUDA/wTQsK/src/compiler/execution.jl:302
 [8] cufunction_compile(::GPUCompiler.FunctionSpec) at /home/johnbb/.julia/packages/CUDA/wTQsK/src/compiler/execution.jl:297
 [9] check_cache(::Dict{UInt64,Any}, ::Any, ::Any, ::GPUCompiler.FunctionSpec{GPUArrays.var"#broadcast_kernel#12",Tuple{CUDA.CuKernelContext,CUDA.CuDeviceArray{Float32,2,1},Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}},typeof(-),Tuple{Base.Broadcast.Extruded{Array{Float32,2},Tuple{Bool,Bool},Tuple{Int64,Int64}},Base.Broadcast.Extruded{CUDA.CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}},Int64}}, ::UInt64; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/johnbb/.julia/packages/GPUCompiler/uTpNx/src/cache.jl:40
 [10] broadcast_kernel at /home/johnbb/.julia/packages/GPUArrays/WV76E/src/host/broadcast.jl:60 [inlined]
 [11] cached_compilation at /home/johnbb/.julia/packages/GPUCompiler/uTpNx/src/cache.jl:65 [inlined]
 [12] cufunction(::GPUArrays.var"#broadcast_kernel#12", ::Type{Tuple{CUDA.CuKernelContext,CUDA.CuDeviceArray{Float32,2,1},Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}},typeof(-),Tuple{Base.Broadcast.Extruded{Array{Float32,2},Tuple{Bool,Bool},Tuple{Int64,Int64}},Base.Broadcast.Extruded{CUDA.CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}},Int64}}; name::Nothing, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/johnbb/.julia/packages/CUDA/wTQsK/src/compiler/execution.jl:289
 [13] cufunction at /home/johnbb/.julia/packages/CUDA/wTQsK/src/compiler/execution.jl:286 [inlined]
 [14] macro expansion at /home/johnbb/.julia/packages/CUDA/wTQsK/src/compiler/execution.jl:100 [inlined]
 [15] #launch_heuristic#857 at /home/johnbb/.julia/packages/CUDA/wTQsK/src/gpuarrays.jl:17 [inlined]
 [16] launch_heuristic at /home/johnbb/.julia/packages/CUDA/wTQsK/src/gpuarrays.jl:17 [inlined]
 [17] copyto! at /home/johnbb/.julia/packages/GPUArrays/WV76E/src/host/broadcast.jl:66 [inlined]
 [18] copyto! at ./broadcast.jl:886 [inlined]
 [19] copy at ./broadcast.jl:862 [inlined]
 [20] materialize at ./broadcast.jl:837 [inlined]
 [21] adjoint at /home/johnbb/.julia/packages/Zygote/KpME9/src/lib/broadcast.jl:73 [inlined]
 [22] _pullback at /home/johnbb/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [23] #mse#6 at /home/johnbb/.julia/packages/Flux/goUGu/src/losses/functions.jl:17 [inlined]
 [24] _pullback(::Zygote.Context, ::Flux.Losses.var"##mse#6", ::typeof(Statistics.mean), ::typeof(Flux.Losses.mse), ::Array{Float32,2}, ::CUDA.CuArray{Float32,1}) at /home/johnbb/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [25] mse at /home/johnbb/.julia/packages/Flux/goUGu/src/losses/functions.jl:17 [inlined]
 [26] _pullback(::Zygote.Context, ::typeof(Flux.Losses.mse), ::Array{Float32,2}, ::CUDA.CuArray{Float32,1}) at /home/johnbb/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [27] loss at ./REPL[10]:6 [inlined]
 [28] _pullback(::Zygote.Context, ::var"#loss#7"{var"#4#5"{Embed{Array{Float32,2}},Embed{Array{Float32,2}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}}, ::Tuple{CUDA.CuArray{Int64,1},CUDA.CuArray{Int64,1},CUDA.CuArray{Float32,2}}, ::CUDA.CuArray{Float32,1}) at /home/johnbb/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [29] adjoint at /home/johnbb/.julia/packages/Zygote/KpME9/src/lib/lib.jl:188 [inlined]
 [30] _pullback at /home/johnbb/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [31] #15 at /home/johnbb/.julia/packages/Flux/goUGu/src/optimise/train.jl:103 [inlined]
 [32] _pullback(::Zygote.Context, ::Flux.Optimise.var"#15#21"{var"#loss#7"{var"#4#5"{Embed{Array{Float32,2}},Embed{Array{Float32,2}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}},Tuple{Tuple{CUDA.CuArray{Int64,1},CUDA.CuArray{Int64,1},CUDA.CuArray{Float32,2}},CUDA.CuArray{Float32,1}}}) at /home/johnbb/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [33] pullback(::Function, ::Zygote.Params) at /home/johnbb/.julia/packages/Zygote/KpME9/src/compiler/interface.jl:167
 [34] gradient(::Function, ::Zygote.Params) at /home/johnbb/.julia/packages/Zygote/KpME9/src/compiler/interface.jl:48
 [35] macro expansion at /home/johnbb/.julia/packages/Flux/goUGu/src/optimise/train.jl:102 [inlined]
 [36] macro expansion at /home/johnbb/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
 [37] train!(::Function, ::Zygote.Params, ::Flux.Data.DataLoader{Tuple{Tuple{CUDA.CuArray{Int64,1},CUDA.CuArray{Int64,1},CUDA.CuArray{Float32,2}},CUDA.CuArray{Float32,1}}}, ::ADAM; cb::Flux.Optimise.var"#16#22") at /home/johnbb/.julia/packages/Flux/goUGu/src/optimise/train.jl:100
 [38] train!(::Function, ::Zygote.Params, ::Flux.Data.DataLoader{Tuple{Tuple{CUDA.CuArray{Int64,1},CUDA.CuArray{Int64,1},CUDA.CuArray{Float32,2}},CUDA.CuArray{Float32,1}}}, ::ADAM) at /home/johnbb/.julia/packages/Flux/goUGu/src/optimise/train.jl:98
 [39] macro expansion at ./timing.jl:174 [inlined]
 [40] train(::Int64, ::Int64, ::Int64; device::Function) at ./REPL[10]:7
 [41] top-level scope at REPL[12]:1

Have a look at How to implement embeddings in Flux that aren't tragically slow? - #2 by dhairyagandhi96. Embeddings are an interesting case because they’re trivial to implement as a loop on GPU, but extremely difficult to express as a vectorized computation. If you just want something that works, Transformers.jl/embed.jl at master · chengchingwen/Transformers.jl · GitHub has a working implementation and there’s a PR out to add something like it to NNlib.

1 Like

In my experience, your approach to Embed usage did work fine. At least, in my use case, adding such embedding layer to a straightforward MLP structure didn’t result in abnormal slowdown nor warning message about fallback to scalar operation on GPU.

I did a minimal working example from your example. I’m not sure what is the key component that seems to solve the embedding pain point (potentially the type inference hint provided to the MyModel definition wrapped as a functor?), but hopefully it should provide a comparison point to nail it down.

using Flux
using Flux: @functor
struct Embed{T}
    w::T
end
Flux.@functor Embed
Embed(in::Integer, out::Integer; initW=Flux.glorot_uniform) = Embed(initW(out, in))
(m::Embed)(x::AbstractVector) = m.w[:,x]

# concat two embedded variables with continuous vars
struct MyModel{E,D}
    e1::E
    e2::E
    d::D
end

@functor MyModel

function (m::MyModel)(x1::AbstractVector, x2::AbstractVector, x3::AbstractMatrix)
    e1 = m.e1(x1)
    e2 = m.e2(x2)
    agg = cat(e1, e2, x3, dims=1)
    out =m.d(agg)
    return out
end
    
d = 5
p = 4
e = 3
n = 1000
m = MyModel(Embed(d, e), Embed(d, e), Dense(p + 2*e, 1)) |> gpu
ps = params(m)

x = (rand(1:d, n), rand(1:d, n), rand(Float32, p, n)) |> gpu
y = rand(Float32, n) |> gpu

m(x...)

julia> m(x...)
1×1000 CUDA.CuArray{Float32,2}:
 -0.891172  -1.50137  -0.39195  -0.793842  -1.20421  -1.4277  -0.935811  -0.97199  -0.592833  -0.494732  -1.34041  -0.892502  -0.419997  …  -0.649078  -0.386578  -0.777367  -1.13931  -0.505382  -1.66707  -0.258669  -1.19645  -1.23625  -0.968328  -0.33922  -0.764131
2 Likes

Thanks @jeremiedb. I can confirm that your approach work. On a larger model/data set using the GPU was about 4 times faster than on CPU.

Is there any elegant way to make MyModel invariant to the number of variables to be embedded?

Please not that there unfortunately appears to be a bug at the moment with the GPU indexing approach as discussed here.
The ongoing PR for embedding layer in Flux might be the safest route for now: add Embedding layer by CarloLucibello · Pull Request #1516 · FluxML/Flux.jl · GitHub

Thanks. I guess I would have noticed if I had tested on real data. Good to see that there is ongoing work on including embedding in Flux.