I am implementing neural architecture search methods in Julia. One class of them uses trainable architecture weights to learn the best option for each layer of the network. Each repeated cell within a network has N layers that each use a M-length array of continuous weights to computed a weighted sum of the operation choices at that layer, and these weights are optimized by gradient descent.

However, I have not been able to find a data structure that allows me to store and train these architectures weights on GPU. On CPU, just using an NxM 2D matrix works (as in the MWE below). I have tried other encodings for GPU, including a 1D array of CuArrays (leading to a similar situation as this issue), some approaches from SliceMap.jl, and other ideas discussed here on the Slack (which includes a snippet of the full network code). Most approaches lead to an KernelError from the gradient running into non-GPU arrays.

I understand that scalar indexing is not optimal for GPU computations, but the `all_αs`

matrices are relatively small, and I cannot think of a better way to structure them than as a slicable matrix.

TLDR: how can I change the two lines with comments below to make the script work on a GPU?

```
using Flux
using Zygote
ReLUConv(channels_in, channels_out, kernel_size, pad) =
Chain(x -> relu.(x), Conv(kernel_size, channels_in => channels_out, pad = pad))
struct MixedOperation
operations::AbstractArray
end
MixedOperation(channels::Int64, kernel_options::AbstractArray) =
MixedOperation([ReLUConv(channels, channels, (i, i), i ÷ 2) for i in kernel_options])
function (m::MixedOperation)(x::AbstractArray, αs::AbstractArray)
mapreduce((op, α) -> α * op(x), +, m.operations, αs)
end
Flux.@functor MixedOperation
struct MWECell
steps::Int64
outstates::Int64
mixedops::Array{MixedOperation,1}
end
function MWECell(
steps::Int64,
outstates::Int64,
channels::Int64,
kernel_options::AbstractArray,
)
mixedops = [MixedOperation(channels, kernel_options) for _ = 1:steps]
MWECell(steps, outstates, mixedops)
end
function (m::MWECell)(x::AbstractArray, all_αs::AbstractArray)
state1 = m.mixedops[1](x, all_αs[1, :])
states = Zygote.Buffer([state1], m.steps)
states[1] = state1
for step = 2:m.steps
states[step] = m.mixedops[step](states[step-1], all_αs[step,:]) #GPU version errors out here
end
states_ = copy(states)
out = cat(states_[m.steps-m.outstates+1:m.steps]..., dims = 3)
out
end
Flux.@functor MWECell
struct MWEModel
all_αs::AbstractArray
cells::Array{MWECell,1}
end
function MWEModel(
kernel_options::AbstractArray;
num_cells = 3,
channels = 3,
steps = 4,
outstates = 2,
)
cells = [
MWECell(steps, outstates, channels * outstates^index, kernel_options)
for index = 0:num_cells-1
]
all_αs = rand(Float32, steps, length(kernel_options)) #offending data structure
MWEModel(all_αs, cells)
end
function (m::MWEModel)(x::AbstractArray)
state = x
αs = softmax(m.all_αs, dims = 2)
for cell in m.cells
state = cell(state, αs)
end
state
end
Flux.@functor MWEModel
using Test
using CUDA
m = MWEModel([1, 3, 5]) |> gpu
test_image = rand(Float32, 32, 32, 3, 1) |> gpu
@test sum(m(test_image)) != 0
grad = gradient(x -> sum(m(x)), test_image)
loss(m, x) = sum(m(x))
gαs = gradient(params(m.all_αs)) do
sum(m(test_image))
end
for αs in params(m.all_αs)
@test gαs[αs] != Nothing
end
gws = gradient(params(m.cells)) do
sum(m(test_image))
end
for ws in params(m.cells)
@test gws[ws] != Nothing
end
```