What I’m trying to do is pass a linear combination of images (lets say a1img1 + a2img2 + a3*img3) through a neural network and I want the network to choose the best a1, a2, a3 for classification.
To do this I’ve made a custom first layer which performs the operation a1img1 + a2img2 + a3*img3 and I’ve chained it with the rest of the neural network. However when I try to train the network I get the error that back! was already used.
function mixing_layer(m)
a1 = param(rand())
a2 = param(rand())
a3 = param(rand())
I_m = Matrix{Float64}(I,m,m)
a1_mat = a1*I_m
a2_mat = a2*I_m
a3_mat = a3*I_m;
x -> [I_m a1_mat a2_mat a3_mat]*x
#x -> x
end
function nn_linear(test,train,D,k)
Model_ae = []
for i in 1:D
Encoder = Chain(mixing_layer(m), Dense(m,k))
Decoder = Chain(Dense(k,m),Dense(m,4*m))
Model = Chain(Encoder,Decoder)
dataset = Base.Iterators.repeated((train_conv[:,:,i], train_conv[:,:,i]), 200)
loss(x, y) = Flux.mse(Model(x), y)
parameters = params(Model)
opt = ADAM()
Flux.train!(loss, parameters, dataset, opt)
push!(Model_ae, Model)
end
return Model_ae
end
k = 20
rnk = 20
Model_linear = nn_linear(test_conv,train_conv,D,rnk);
Running the above cell gives the error
`back!` was already used
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] back_(::Flux.Tracker.Call{Missing,Tuple{}}, ::Array{Float64,2}, ::Bool) at /home/dipak/.julia/packages/Flux/8XpDt/src/tracker/back.jl:30
[3] back(::Flux.Tracker.Tracked{Array{Float64,2}}, ::Array{Float64,2}, ::Bool) at /home/dipak/.julia/packages/Flux/8XpDt/src/tracker/back.jl:46
[4] (::getfield(Flux.Tracker, Symbol("##3#4")){Bool})(::Flux.Tracker.Tracked{Array{Float64,2}}, ::Array{Float64,2}) at /home/dipak/.julia/packages/Flux/8XpDt/src/tracker/back.jl:26
[5] foreach(::Function, ::Tuple{Nothing,Flux.Tracker.Tracked{Array{Float64,2}},Flux.Tracker.Tracked{Array{Float64,2}},Flux.Tracker.Tracked{Array{Float64,2}}}, ::NTuple{4,Array{Float64,2}}) at ./abstractarray.jl:1867
[6] back_(::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("##412#415")){Tuple{Array{Float64,2},TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,2}},TrackedArray{…,Array{Float64,2}}}},Tuple{Nothing,Flux.Tracker.Tracked{Array{Float64,2}},Flux.Tracker.Tracked{Array{Float64,2}},Flux.Tracker.Tracked{Array{Float64,2}}}}, ::Array{Float64,2}, ::Bool) at /home/dipak/.julia/packages/Flux/8XpDt/src/tracker/back.jl:26
[7] back(::Flux.Tracker.Tracked{Array{Float64,2}}, ::Array{Float64,2}, ::Bool) at /home/dipak/.julia/packages/Flux/8XpDt/src/tracker/back.jl:46
[8] #3 at /home/dipak/.julia/packages/Flux/8XpDt/src/tracker/back.jl:26 [inlined]
[9] foreach at ./abstractarray.jl:1867 [inlined]
[10] back_(::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("##481#482")){TrackedArray{…,Array{Float64,2}},Array{Float64,2}},Tuple{Flux.Tracker.Tracked{Array{Float64,2}},Nothing}}, ::Array{Float64,2}, ::Bool) at /home/dipak/.julia/packages/Flux/8XpDt/src/tracker/back.jl:26
... (the last 8 lines are repeated 3 more times)
[35] back(::Flux.Tracker.Tracked{Array{Float32,2}}, ::Array{Float64,2}, ::Bool) at /home/dipak/.julia/packages/Flux/8XpDt/src/tracker/back.jl:46
[36] (::getfield(Flux.Tracker, Symbol("##3#4")){Bool})(::Flux.Tracker.Tracked{Array{Float32,2}}, ::Array{Float64,2}) at /home/dipak/.julia/packages/Flux/8XpDt/src/tracker/back.jl:26
[37] foreach(::Function, ::Tuple{Flux.Tracker.Tracked{Array{Float32,2}},Flux.Tracker.Tracked{Array{Float32,1}}}, ::Tuple{Array{Float64,2},Array{Float64,1}}) at ./abstractarray.jl:1867
[38] back_(::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("#back#514")){2,getfield(Base.Broadcast, Symbol("##2#4")){getfield(Base.Broadcast, Symbol("##8#10")){getfield(Base.Broadcast, Symbol("##1#3")),getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##7#9"))}},getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##13#14"))}},getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##17#18"))}},typeof(+)},typeof(identity)},Tuple{TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}},Tuple{Flux.Tracker.Tracked{Array{Float32,2}},Flux.Tracker.Tracked{Array{Float32,1}}}}, ::Array{Float64,2}, ::Bool) at /home/dipak/.julia/packages/Flux/8XpDt/src/tracker/back.jl:26
... (the last 4 lines are repeated 1 more time)
[43] back(::Flux.Tracker.Tracked{Array{Float64,2}}, ::Array{Float64,2}, ::Bool) at /home/dipak/.julia/packages/Flux/8XpDt/src/tracker/back.jl:46
[44] foreach at /home/dipak/.julia/packages/Flux/8XpDt/src/tracker/back.jl:26 [inlined]
[45] back_(::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("##454#455")){TrackedArray{…,Array{Float64,2}}},Tuple{Flux.Tracker.Tracked{Array{Float64,2}}}}, ::Float64, ::Bool) at /home/dipak/.julia/packages/Flux/8XpDt/src/tracker/back.jl:26
[46] back(::Flux.Tracker.Tracked{Float64}, ::Float64, ::Bool) at /home/dipak/.julia/packages/Flux/8XpDt/src/tracker/back.jl:46
[47] #3 at /home/dipak/.julia/packages/Flux/8XpDt/src/tracker/back.jl:26 [inlined]
[48] foreach at ./abstractarray.jl:1867 [inlined]
[49] back_(::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("##273#276")){Rational{Int64}},Tuple{Flux.Tracker.Tracked{Float64},Nothing}}, ::Float64, ::Bool) at /home/dipak/.julia/packages/Flux/8XpDt/src/tracker/back.jl:26
[50] back(::Flux.Tracker.Tracked{Float64}, ::Int64, ::Bool) at /home/dipak/.julia/packages/Flux/8XpDt/src/tracker/back.jl:46
[51] #back!#5 at /home/dipak/.julia/packages/Flux/8XpDt/src/tracker/back.jl:65 [inlined]
[52] #back! at ./none:0 [inlined]
[53] #back!#27 at /home/dipak/.julia/packages/Flux/8XpDt/src/tracker/lib/real.jl:16 [inlined]
[54] back!(::Flux.Tracker.TrackedReal{Float64}) at /home/dipak/.julia/packages/Flux/8XpDt/src/tracker/lib/real.jl:14
[55] macro expansion at /home/dipak/.julia/packages/Flux/8XpDt/src/optimise/train.jl:25 [inlined]
[56] macro expansion at /home/dipak/.julia/packages/Juno/TfNYn/src/progress.jl:133 [inlined]
[57] #train!#12(::getfield(Flux.Optimise, Symbol("##14#18")), ::Function, ::Function, ::Flux.Tracker.Params, ::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{Array{Float64,2},Array{Float64,2}}}}, ::ADAM) at /home/dipak/.julia/packages/Flux/8XpDt/src/optimise/train.jl:72
[58] train!(::Function, ::Flux.Tracker.Params, ::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{Array{Float64,2},Array{Float64,2}}}}, ::ADAM) at /home/dipak/.julia/packages/Flux/8XpDt/src/optimise/train.jl:70
[59] nn_linear(::Array{Float64,3}, ::Array{Float64,3}, ::Int64, ::Int64) at ./In[13]:17