Bug when training a custom model using Flux

Hello,
I am pretty new to Julia and Flux. I am trying to build a simple neural network that uses an attention layer. I wrote the code below, and it works fine in inference (forward) mode:

using Flux

struct Attention
    W
    v
end

Attention(vehicle_embedding_dim::Integer) = Attention(
    Dense(vehicle_embedding_dim => vehicle_embedding_dim, tanh),
    Dense(vehicle_embedding_dim => 1; bias=false, init=Flux.zeros32)
)

function (a::Attention)(inputs)
    # score each embedded input, squash the scores into (0, 1), and use them as weights
    alphas = [a.v(e) for e in a.W.(inputs)]
    alphas = sigmoid.(alphas)
    output = sum([alpha .* input for (alpha, input) in zip(alphas, inputs)])
    return output
end

Flux.@functor Attention

struct AttentionNet 
    embedding
    attention
    fc_output
    vehicle_num::Integer
    vehicle_dim::Integer
end

AttentionNet(vehicle_num::Integer, vehicle_dim::Integer, embedding_dim::Integer) = AttentionNet(
    Dense(vehicle_dim+1 => embedding_dim, relu),
    Attention(embedding_dim),
    Dense(1+embedding_dim => 1),
    vehicle_num,
    vehicle_dim
)

function (a_net::AttentionNet)(x)
    # row 1 is the time index; the remaining rows are the stacked per-vehicle states
    time_idx = x[[1], :]
    vehicle_states = [x[2+a_net.vehicle_dim*(i-1):2+a_net.vehicle_dim*i-1, :] for i in 1:a_net.vehicle_num]
    vehicle_states = [vcat(time_idx, vehicle_state) for vehicle_state in vehicle_states]

    # embed each vehicle state, pool the embeddings with the attention layer, then map to a scalar
    vehicle_embeddings = a_net.embedding.(vehicle_states)
    attention_output = a_net.attention(vehicle_embeddings)

    x = a_net.fc_output(vcat(time_idx, attention_output))
    return x
end

Flux.@functor AttentionNet
Flux.trainable(a_net::AttentionNet) = (a_net.embedding, a_net.attention, a_net.fc_output,)

fake_inputs = rand(22, 32)   # 22 rows = 1 time index + 3 vehicles × 7 features, batch of 32
fake_outputs = rand(1, 32)
a_net = AttentionNet(3, 7, 64) |> gpu
opt = Adam(0.01)
opt_state = Flux.setup(opt, a_net)

data = Flux.DataLoader((fake_inputs, fake_outputs) |> gpu, batchsize=32, shuffle=true)

Flux.train!(a_net, data, opt_state) do m, x, y
    Flux.mse(m(x), y)
end

But when I try to train it, I get the following warning and error:

┌ Warning: trainable(x) should now return a NamedTuple with the field names, not a Tuple
└ @ Optimisers C:\Users\Herr LU\.julia\packages\Optimisers\SoKJO\src\interface.jl:164
ERROR: MethodError: no method matching +(::Base.RefValue{Any}, ::NamedTuple{(:contents,), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}})
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
  +(::Union{InitialValues.NonspecificInitialValue, InitialValues.SpecificInitialValue{typeof(+)}}, ::Any) at C:\Users\Herr LU\.julia\packages\InitialValues\OWP8V\src\InitialValues.jl:154
  +(::ChainRulesCore.Tangent{P}, ::P) where P at C:\Users\Herr LU\.julia\packages\ChainRulesCore\C73ay\src\tangent_arithmetic.jl:146
  ...
Stacktrace:
  [1] accum(x::Base.RefValue{Any}, y::NamedTuple{(:contents,), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}})
    @ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\lib\lib.jl:17
  [2] accum(x::Base.RefValue{Any}, y::NamedTuple{(:contents,), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, zs::Base.RefValue{Any}) (repeats 2 times)
    @ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\lib\lib.jl:22
  [3] Pullback
    @ e:\Master Thesis\lu_jizhou\toy exmaple\dqn_model.jl:39 [inlined]
  [4] (::typeof(∂(λ)))(Δ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
  [5] Pullback
    @ e:\Master Thesis\lu_jizhou\toy exmaple\dqn_model.jl:62 [inlined]
  [6] #208
    @ C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\lib\lib.jl:206 [inlined]
  [7] #2066#back
    @ C:\Users\Herr LU\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
  [8] Pullback
    @ C:\Users\Herr LU\.julia\packages\Flux\ZdbJr\src\train.jl:102 [inlined]
  [9] (::typeof(∂(λ)))(Δ::Float32)
    @ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [10] (::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float32)
    @ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:45
 [11] withgradient(f::Function, args::AttentionNet)
    @ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:133
 [12] macro expansion
    @ C:\Users\Herr LU\.julia\packages\Flux\ZdbJr\src\train.jl:102 [inlined]
 [13] macro expansion
    @ C:\Users\Herr LU\.julia\packages\ProgressLogging\6KXlp\src\ProgressLogging.jl:328 [inlined]
 [14] train!(loss::Function, model::AttentionNet, data::MLUtils.DataLoader{Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Random._GLOBAL_RNG, Val{nothing}}, opt::NamedTuple{(:embedding, :attention, :fc_output, :vehicle_num, :vehicle_dim), Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, NamedTuple{(:W, :v), Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}, Tuple{}}}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, Tuple{}, Tuple{}}}; cb::Nothing)
    @ Flux.Train C:\Users\Herr LU\.julia\packages\Flux\ZdbJr\src\train.jl:100
 [15] train!(loss::Function, model::AttentionNet, data::MLUtils.DataLoader{Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Random._GLOBAL_RNG, Val{nothing}}, opt::NamedTuple{(:embedding, :attention, :fc_output, :vehicle_num, :vehicle_dim), Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, NamedTuple{(:W, :v), Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}, Tuple{}}}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, Tuple{}, Tuple{}}})
    @ Flux.Train C:\Users\Herr LU\.julia\packages\Flux\ZdbJr\src\train.jl:97
 [16] top-level scope
    @ e:\Master Thesis\lu_jizhou\toy exmaple\dqn_model.jl:61

I followed the instructions from the official tutorial on custom layers, but it doesn't explain how to get a custom layer to train properly. Could someone help me out?

Can you share the following:

  • whether your forward pass runs correctly
  • the error you get when calculating the gradient without using the GPU:

gradient(m -> Flux.mse(m(x), y), m)
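For example, a minimal sketch of that CPU check (assuming the same fake data shapes and model sizes as in your post, with everything kept on the CPU):

using Flux

# build the model and data on the CPU so any gradient error is easier to inspect
m = AttentionNet(3, 7, 64)
x = rand(Float32, 22, 32)
y = rand(Float32, 1, 32)

# gradient of the MSE loss with respect to the model
grads = gradient(m -> Flux.mse(m(x), y), m)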


Thanks for your reply. This problem has already been solved in this GitHub thread: my code runs into a captured-variable problem.
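In case it helps anyone who finds this later: the captured variable is presumably the argument x, which is reassigned at the end of the AttentionNet forward pass while also being captured by the comprehensions above it, so Julia boxes it and Zygote ends up with the Base.RefValue gradient that clashes in the error above. A minimal sketch of the fix, under that assumption, is to give the output its own name instead of reusing x (and, to address the warning, to have trainable return a NamedTuple):

function (a_net::AttentionNet)(x)
    time_idx = x[[1], :]
    vehicle_states = [x[2+a_net.vehicle_dim*(i-1):2+a_net.vehicle_dim*i-1, :] for i in 1:a_net.vehicle_num]
    vehicle_states = [vcat(time_idx, vehicle_state) for vehicle_state in vehicle_states]

    vehicle_embeddings = a_net.embedding.(vehicle_states)
    attention_output = a_net.attention(vehicle_embeddings)

    # use a fresh name here; reassigning the captured argument x is what caused the boxing
    output = a_net.fc_output(vcat(time_idx, attention_output))
    return output
end

Flux.trainable(a_net::AttentionNet) = (; embedding = a_net.embedding, attention = a_net.attention, fc_output = a_net.fc_output)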
