In a regression model I would like to use embeddings for the categorical variables. My current implementation is like the simplified example below. On CPU it works seemingly fine. Any suggestions for improvement are welcome, though. On GPU, however, problems occur (with Flux 0.11.6 and Julia 1.5.3). How would you implement embedding in this case?
using Flux
struct Embed{T}
    w::T
end
Flux.@functor Embed
Embed(in::Integer, out::Integer; initW=Flux.glorot_uniform) = Embed(initW(out, in))
(m::Embed)(x::AbstractVector) = m.w[:,x]
# concat two embedded variables with continuous vars
function create_model(embed1, embed2, main_model)
    return function(x)
        x1 = embed1(x[1])
        x2 = embed2(x[2])
        x3 = cat(x1, x2, x[3], dims=1)
        return main_model(x3)
    end, Flux.params(embed1, embed2, main_model)
end
function train(n, d, p; device = cpu)         
    x = (rand(1:d, n), rand(1:d, n), rand(Float32, p, n)) |> device
    y = rand(Float32, n) |> device
    trdata = Flux.Data.DataLoader((x, y), batchsize=100) |> device
    m, prm = create_model(Embed(d, 3), Embed(d, 3), Dense(p+6, 1)) |> device
    loss(x, y) = Flux.mse(m(x), y)
    @time Flux.train!(loss, prm, trdata, ADAM()) 
end
train(100_000, 100, 50, device = gpu)
┌ Warning: Performing scalar operations on GPU arrays: This is very slow, consider disallowing these operations with `allowscalar(false)`
└ @ GPUArrays ~/.julia/packages/GPUArrays/WV76E/src/host/indexing.jl:43
ERROR: GPU compilation of kernel broadcast_kernel(CUDA.CuKernelContext, CUDA.CuDeviceArray{Float32,2,1}, Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}},typeof(-),Tuple{Base.Broadcast.Extruded{Array{Float32,2},Tuple{Bool,Bool},Tuple{Int64,Int64}},Base.Broadcast.Extruded{CUDA.CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}}, Int64) failed
KernelError: passing and using non-bitstype argument
Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}},typeof(-),Tuple{Base.Broadcast.Extruded{Array{Float32,2},Tuple{Bool,Bool},Tuple{Int64,Int64}},Base.Broadcast.Extruded{CUDA.CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}}, which is not isbits:
  .args is of type Tuple{Base.Broadcast.Extruded{Array{Float32,2},Tuple{Bool,Bool},Tuple{Int64,Int64}},Base.Broadcast.Extruded{CUDA.CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}} which is not isbits.
    .1 is of type Base.Broadcast.Extruded{Array{Float32,2},Tuple{Bool,Bool},Tuple{Int64,Int64}} which is not isbits.
      .x is of type Array{Float32,2} which is not isbits.
Stacktrace:
 [1] check_invocation(::GPUCompiler.CompilerJob, ::LLVM.Function) at /home/johnbb/.julia/packages/GPUCompiler/uTpNx/src/validation.jl:68
 [2] macro expansion at /home/johnbb/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:238 [inlined]
 [3] macro expansion at /home/johnbb/.julia/packages/TimerOutputs/ZmKD7/src/TimerOutput.jl:206 [inlined]
 [4] codegen(::Symbol, ::GPUCompiler.CompilerJob; libraries::Bool, deferred_codegen::Bool, optimize::Bool, strip::Bool, validate::Bool, only_entry::Bool) at /home/johnbb/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:237
 [5] compile(::Symbol, ::GPUCompiler.CompilerJob; libraries::Bool, deferred_codegen::Bool, optimize::Bool, strip::Bool, validate::Bool, only_entry::Bool) at /home/johnbb/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:39
 [6] compile at /home/johnbb/.julia/packages/GPUCompiler/uTpNx/src/driver.jl:35 [inlined]
 [7] cufunction_compile(::GPUCompiler.FunctionSpec; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/johnbb/.julia/packages/CUDA/wTQsK/src/compiler/execution.jl:302
 [8] cufunction_compile(::GPUCompiler.FunctionSpec) at /home/johnbb/.julia/packages/CUDA/wTQsK/src/compiler/execution.jl:297
 [9] check_cache(::Dict{UInt64,Any}, ::Any, ::Any, ::GPUCompiler.FunctionSpec{GPUArrays.var"#broadcast_kernel#12",Tuple{CUDA.CuKernelContext,CUDA.CuDeviceArray{Float32,2,1},Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}},typeof(-),Tuple{Base.Broadcast.Extruded{Array{Float32,2},Tuple{Bool,Bool},Tuple{Int64,Int64}},Base.Broadcast.Extruded{CUDA.CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}},Int64}}, ::UInt64; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/johnbb/.julia/packages/GPUCompiler/uTpNx/src/cache.jl:40
 [10] broadcast_kernel at /home/johnbb/.julia/packages/GPUArrays/WV76E/src/host/broadcast.jl:60 [inlined]
 [11] cached_compilation at /home/johnbb/.julia/packages/GPUCompiler/uTpNx/src/cache.jl:65 [inlined]
 [12] cufunction(::GPUArrays.var"#broadcast_kernel#12", ::Type{Tuple{CUDA.CuKernelContext,CUDA.CuDeviceArray{Float32,2,1},Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}},typeof(-),Tuple{Base.Broadcast.Extruded{Array{Float32,2},Tuple{Bool,Bool},Tuple{Int64,Int64}},Base.Broadcast.Extruded{CUDA.CuDeviceArray{Float32,1,1},Tuple{Bool},Tuple{Int64}}}},Int64}}; name::Nothing, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/johnbb/.julia/packages/CUDA/wTQsK/src/compiler/execution.jl:289
 [13] cufunction at /home/johnbb/.julia/packages/CUDA/wTQsK/src/compiler/execution.jl:286 [inlined]
 [14] macro expansion at /home/johnbb/.julia/packages/CUDA/wTQsK/src/compiler/execution.jl:100 [inlined]
 [15] #launch_heuristic#857 at /home/johnbb/.julia/packages/CUDA/wTQsK/src/gpuarrays.jl:17 [inlined]
 [16] launch_heuristic at /home/johnbb/.julia/packages/CUDA/wTQsK/src/gpuarrays.jl:17 [inlined]
 [17] copyto! at /home/johnbb/.julia/packages/GPUArrays/WV76E/src/host/broadcast.jl:66 [inlined]
 [18] copyto! at ./broadcast.jl:886 [inlined]
 [19] copy at ./broadcast.jl:862 [inlined]
 [20] materialize at ./broadcast.jl:837 [inlined]
 [21] adjoint at /home/johnbb/.julia/packages/Zygote/KpME9/src/lib/broadcast.jl:73 [inlined]
 [22] _pullback at /home/johnbb/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [23] #mse#6 at /home/johnbb/.julia/packages/Flux/goUGu/src/losses/functions.jl:17 [inlined]
 [24] _pullback(::Zygote.Context, ::Flux.Losses.var"##mse#6", ::typeof(Statistics.mean), ::typeof(Flux.Losses.mse), ::Array{Float32,2}, ::CUDA.CuArray{Float32,1}) at /home/johnbb/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [25] mse at /home/johnbb/.julia/packages/Flux/goUGu/src/losses/functions.jl:17 [inlined]
 [26] _pullback(::Zygote.Context, ::typeof(Flux.Losses.mse), ::Array{Float32,2}, ::CUDA.CuArray{Float32,1}) at /home/johnbb/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [27] loss at ./REPL[10]:6 [inlined]
 [28] _pullback(::Zygote.Context, ::var"#loss#7"{var"#4#5"{Embed{Array{Float32,2}},Embed{Array{Float32,2}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}}, ::Tuple{CUDA.CuArray{Int64,1},CUDA.CuArray{Int64,1},CUDA.CuArray{Float32,2}}, ::CUDA.CuArray{Float32,1}) at /home/johnbb/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [29] adjoint at /home/johnbb/.julia/packages/Zygote/KpME9/src/lib/lib.jl:188 [inlined]
 [30] _pullback at /home/johnbb/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [31] #15 at /home/johnbb/.julia/packages/Flux/goUGu/src/optimise/train.jl:103 [inlined]
 [32] _pullback(::Zygote.Context, ::Flux.Optimise.var"#15#21"{var"#loss#7"{var"#4#5"{Embed{Array{Float32,2}},Embed{Array{Float32,2}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}},Tuple{Tuple{CUDA.CuArray{Int64,1},CUDA.CuArray{Int64,1},CUDA.CuArray{Float32,2}},CUDA.CuArray{Float32,1}}}) at /home/johnbb/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [33] pullback(::Function, ::Zygote.Params) at /home/johnbb/.julia/packages/Zygote/KpME9/src/compiler/interface.jl:167
 [34] gradient(::Function, ::Zygote.Params) at /home/johnbb/.julia/packages/Zygote/KpME9/src/compiler/interface.jl:48
 [35] macro expansion at /home/johnbb/.julia/packages/Flux/goUGu/src/optimise/train.jl:102 [inlined]
 [36] macro expansion at /home/johnbb/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
 [37] train!(::Function, ::Zygote.Params, ::Flux.Data.DataLoader{Tuple{Tuple{CUDA.CuArray{Int64,1},CUDA.CuArray{Int64,1},CUDA.CuArray{Float32,2}},CUDA.CuArray{Float32,1}}}, ::ADAM; cb::Flux.Optimise.var"#16#22") at /home/johnbb/.julia/packages/Flux/goUGu/src/optimise/train.jl:100
 [38] train!(::Function, ::Zygote.Params, ::Flux.Data.DataLoader{Tuple{Tuple{CUDA.CuArray{Int64,1},CUDA.CuArray{Int64,1},CUDA.CuArray{Float32,2}},CUDA.CuArray{Float32,1}}}, ::ADAM) at /home/johnbb/.julia/packages/Flux/goUGu/src/optimise/train.jl:98
 [39] macro expansion at ./timing.jl:174 [inlined]
 [40] train(::Int64, ::Int64, ::Int64; device::Function) at ./REPL[10]:7
 [41] top-level scope at REPL[12]:1