Here is my train!
function:
# Train
function train!(epoch, train_loader, test_loader)
@info "start training"
for e in 1:epoch
@info "epoch: $e"
i = 1
al::Float64 = 0.0
for batch in train_loader
data, label, mask = todevice(preprocess(batch[1], batch[2]))
(l, p), back = Flux.pullback(ps) do
loss(data, label, train_loader.batchsize; mask=mask)
end
#@show l
a = acc(p, label)
al += a
grad = back((Flux.Zygote.sensitivity(l), nothing))
i += 1
update!(opt, ps, grad)
#@show al / i
end
test()
end
end
But there is an error in the execution:
julia> train!(2, train_loader, test_loader)
[ Info: start training
[ Info: epoch: 1
ERROR: MethodError: no method matching batchedmul(::CuArray{Float64, 3, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}; transB=true)
Closest candidates are:
batchedmul(::AbstractArray{T, 3}, ::AbstractArray{T, 3}; transA, transB) where T at /home/storopoli/.julia/packages/Transformers/V363g/src/fix/batchedmul.jl:5
batchedmul(::AbstractArray{T, N}, ::AbstractArray{T, N}; transA, transB) where {T, N} at /home/storopoli/.julia/packages/Transformers/V363g/src/fix/batchedmul.jl:13
Stacktrace:
[1] (::Transformers.var"#8#12"{CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}})(Δ::CuArray{Float64, 3, CUDA.Mem.DeviceBuffer})
@ Transformers ~/.julia/packages/Transformers/V363g/src/fix/batchedmul.jl:45
[2] (::Transformers.var"#11#back#13"{Transformers.var"#8#12"{CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}}})(Δ::CuArray{Float64, 3, CUDA.Mem.DeviceBuffer})
@ Transformers ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[3] Pullback
@ ~/.julia/packages/Transformers/V363g/src/basic/mh_atten.jl:207 [inlined]
[4] (::typeof(∂(attention)))(Δ::CuArray{Float64, 3, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[5] Pullback
@ ~/.julia/packages/Transformers/V363g/src/basic/mh_atten.jl:102 [inlined]
[6] (::typeof(∂(#_#54)))(Δ::CuArray{Float64, 3, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[7] Pullback
@ ~/.julia/packages/Transformers/V363g/src/basic/mh_atten.jl:80 [inlined]
[8] (::typeof(∂(Any##kw)))(Δ::CuArray{Float64, 3, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[9] Pullback
@ ~/.julia/packages/Transformers/V363g/src/basic/transformer.jl:69 [inlined]
[10] (::typeof(∂(λ)))(Δ::CuArray{Float64, 3, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[11] macro expansion
@ ~/.julia/packages/Transformers/V363g/src/stacks/stack.jl:0 [inlined]
[12] Pullback
@ ~/.julia/packages/Transformers/V363g/src/stacks/stack.jl:17 [inlined]
[13] (::typeof(∂(λ)))(Δ::Tuple{CuArray{Float64, 3, CUDA.Mem.DeviceBuffer}, Nothing})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[14] Pullback
@ ~/.julia/packages/Transformers/V363g/src/bert/bert.jl:55 [inlined]
[15] (::typeof(∂(#_#9)))(Δ::CuArray{Float64, 3, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[16] Pullback
@ ~/.julia/packages/Transformers/V363g/src/bert/bert.jl:50 [inlined]
[17] (::typeof(∂(λ)))(Δ::CuArray{Float64, 3, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[18] Pullback
@ ./REPL[55]:3 [inlined]
[19] (::typeof(∂(#loss#4)))(Δ::Tuple{Float64, Nothing})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[20] Pullback
@ ./REPL[55]:2 [inlined]
[21] (::typeof(∂(loss##kw)))(Δ::Tuple{Float64, Nothing})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[22] Pullback
@ ./REPL[62]:10 [inlined]
[23] (::typeof(∂(λ)))(Δ::Tuple{Float64, Nothing})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[24] (::Zygote.var"#94#95"{Zygote.Params, typeof(∂(λ)), Zygote.Context})(Δ::Tuple{Float64, Nothing})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:348
[25] train!(epoch::Int64, train_loader::DataLoader{Tuple{Vector{String}, Vector{Int64}}, Random._GLOBAL_RNG}, test_loader::DataLoader{Tuple{Vector{String}, Vector{Int64}}, Random._GLOBAL_RNG})
@ Main ./REPL[62]:15
[26] top-level scope
@ REPL[65]:1
[27] top-level scope
@ ~/.julia/packages/CUDA/9T5Sq/src/initialization.jl:66
The full code can be found here: https://github.com/LabCidades/COVID-Classifier/blob/main/src/tweet_classifier_BERT.jl