Mixed input layer using Flux gives "ERROR: Mutating arrays is not supported"

I would like to combine input types in the input layer of my Flux model.

For example, using the Embedding example layer (found elsewhere on this forum), I would like to embed the first column of my input data and use a Dense layer for the next 2 columns.

Something like this:

model = Chain(Mixed(Embedding(3,2), Dense(2, 4)),
              Dense(6,3,relu),
              Dense(3,1))

The Embedding layer maps the first input column (a categorical feature with 3 possible values) to 2 output features, and the Dense layer maps the next 2 input columns to 4 output features. Concatenating the two outputs gives the 6 features expected by the following Dense(6, 3, relu) layer.
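To make the shapes concrete (using Flux's convention of features as rows and observations as columns), here is a sketch with plain stand-in arrays instead of the actual layers:

```julia
N = 5                          # batch size (columns)
emb_out   = rand(2, N)         # stand-in for the Embedding(3, 2) output
dense_out = rand(4, N)         # stand-in for the Dense(2, 4) output

# Vertical concatenation stacks the feature blocks:
combined = vcat(emb_out, dense_out)
size(combined)                 # (6, 5) — the 6 features the next layer expects
```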

The implementation for Mixed is here:

struct Mixed{T<:Tuple}
    partitions::T
    Mixed(xs...) = new{typeof(xs)}(xs)
end

@Flux.treelike Mixed

function Base.show(io::IO, l::Mixed)
  print(io, "Mixed(", length(l.partitions))
  print(io, ")")
end


function applypartition(partition, x::AbstractArray, datacol::Int64)
  if partition isa Embedding
      # Embeddings consume a single column of integer codes
      return partition(convert(Array{Int64}, x[datacol, :])), datacol + 1
  else
      nin  = size(partition.W, 2)        # number of inputs this layer expects
      rows = datacol:(datacol + nin - 1)
      return partition(x[rows, :]), datacol + nin
  end
end

function (a::Mixed)(x::AbstractArray)
    datacol = 1  # index of the next input row not yet consumed
    results = map(a.partitions) do p
        out, datacol = applypartition(p, x, datacol)  # thread the index to the next partition
        out
    end
    reduce(vcat, results)
end

My problem seems to be with the reduce call: I get an "ERROR: Mutating arrays is not supported" error. I thought the Mixed type would behave like Flux's own Chain type, and my code never writes into an array, so I don't understand where the mutation comes from.
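For what it's worth, the reduce call can be isolated from the layer machinery; a minimal sketch that, I believe, reproduces the same error on this Zygote version:

```julia
using Zygote

# Differentiating through reduce(vcat, ...) falls back to Base's _typed_vcat,
# which fills a preallocated output array — the mutation Zygote rejects.
f(x) = sum(reduce(vcat, [x[1:1, :], x[2:3, :]]))
Zygote.gradient(f, rand(3, 4))  # throws "Mutating arrays is not supported" on the Zygote version above
```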

The full stacktrace is here:

ERROR: Mutating arrays is not supported
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] (::Zygote.var"#1048#1049")(::Nothing) at /Users/rs990e/.julia/packages/Zygote/YeCEW/src/lib/array.jl:61
 [3] (::Zygote.var"#2775#back#1050"{Zygote.var"#1048#1049"})(::Nothing) at /Users/rs990e/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [4] _typed_vcat at ./abstractarray.jl:1366 [inlined]
 [5] (::typeof(∂(_typed_vcat)))(::Array{Float64,2}) at /Users/rs990e/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [6] reduce at ./abstractarray.jl:1374 [inlined]
 [7] (::typeof(∂(reduce)))(::Array{Float64,2}) at /Users/rs990e/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [8] Mixed at /Users/rs990e/Workspace/PVC/src/layers.jl:61 [inlined]
 [9] (::typeof(∂(λ)))(::Array{Float64,2}) at /Users/rs990e/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [10] applychain at /Users/rs990e/.julia/packages/Flux/Fj3bt/src/layers/basic.jl:36 [inlined]
 [11] (::typeof(∂(applychain)))(::Array{Float64,2}) at /Users/rs990e/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [12] Chain at /Users/rs990e/.julia/packages/Flux/Fj3bt/src/layers/basic.jl:38 [inlined]
 [13] (::typeof(∂(λ)))(::Array{Float64,2}) at /Users/rs990e/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [14] loss at /Users/rs990e/Workspace/Julia_EDA/mixflux.jl:25 [inlined]
 [15] (::typeof(∂(loss)))(::Float64) at /Users/rs990e/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [16] #174 at /Users/rs990e/.julia/packages/Zygote/YeCEW/src/lib/lib.jl:182 [inlined]
 [17] #347#back at /Users/rs990e/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
 [18] #15 at /Users/rs990e/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:89 [inlined]
 [19] (::typeof(∂(λ)))(::Float64) at /Users/rs990e/.julia/packages/Zygote/YeCEW/src/compiler/interface2.jl:0
 [20] (::Zygote.var"#49#50"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float64) at /Users/rs990e/.julia/packages/Zygote/YeCEW/src/compiler/interface.jl:179
 [21] gradient(::Function, ::Zygote.Params) at /Users/rs990e/.julia/packages/Zygote/YeCEW/src/compiler/interface.jl:55
 [22] macro expansion at /Users/rs990e/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:88 [inlined]
 [23] macro expansion at /Users/rs990e/.julia/packages/Juno/tLMZd/src/progress.jl:119 [inlined]
 [24] train!(::typeof(loss), ::Zygote.Params, ::Array{Tuple{Array{Float64,2},Array{Float64,1}},1}, ::ADAM; cb::Flux.Optimise.var"#18#26") at /Users/rs990e/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:81
 [25] train!(::Function, ::Zygote.Params, ::Array{Tuple{Array{Float64,2},Array{Float64,1}},1}, ::ADAM) at /Users/rs990e/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:79
 [26] top-level scope at /Users/rs990e/Workspace/Julia_EDA/mixflux.jl:29

Any help is appreciated.

Did you ever find a solution to this problem? It looks similar to this issue: Gradient of a loss function : struggling to avoid arrays mutation - #3 by ChrisRackauckas
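Not a full answer, but one workaround that has worked for similar cases (untested against your exact code) is to avoid the generic reduce(vcat, ...) path, whose pullback goes through Base's _typed_vcat and its array-filling loop, and splat into vcat instead, which Zygote differentiates without mutation:

```julia
using Zygote

# Same computation as the failing reduce(vcat, ...), but splatted.
g(x) = sum(vcat(x[1:1, :], x[2:3, :]))
Zygote.gradient(g, rand(3, 4))   # works: returns (a 3×4 matrix of ones,)
```

In the Mixed forward pass that would mean writing vcat([r[1] for r in results]...) in place of reduce(vcat, [result[1] for result in results]).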