In the MWE of my use case below, I am trying to output a softmax-weighted sum of several layers applied to the same input. The output needs to be differentiable with respect to both the weights within the layers and the `softmax`-ed weights `αs`.

On the GPU, `softmax` and `mapreduce` each seem to be differentiable when only one of them is included in the model, but not when both are. I have tried replacing `mapreduce` with other parallelizable broadcasting approaches, but everything I tried either errors out in the forward pass, errors out in the gradient evaluation, returns gradients of type `Nothing`, or returns gradients on the CPU. The only thing that seems to work is to (1) use my own adjoint for `softmax` and (2) use the (less than ideal) scalar-indexing approach to performing the weighted sum.
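For comparison, the same weighted sum can be written without `mapreduce` and without scalar indexing by stacking the layer outputs along a new fifth dimension and broadcasting a reshaped weight vector against it. The sketch below is a CPU-only check that this stacked form matches the scalar-indexing form. The helper names `weighted_sum_stacked` and `weighted_sum_indexed` are hypothetical, and whether the stacked form also avoids the GPU gradient problem is an assumption that has not been verified here.

```julia
using Flux  # provides Chain, Conv, relu, softmax and (via Zygote) gradient

# Hypothetical helper: weighted sum via stacking + broadcasting (no mapreduce,
# no scalar indexing into the weight vector)
function weighted_sum_stacked(ops, x, αs)
    w = softmax(αs)                                    # length-K weights
    ys = map(op -> op(x), ops)                         # K arrays of size (W, H, C, N)
    Y = cat(ys...; dims = 5)                           # size (W, H, C, N, K)
    wr = reshape(w, 1, 1, 1, 1, :)                     # size (1, 1, 1, 1, K)
    return dropdims(sum(Y .* wr; dims = 5); dims = 5)  # back to (W, H, C, N)
end

# Hypothetical helper: the scalar-indexing formulation from the question
function weighted_sum_indexed(ops, x, αs)
    w = softmax(αs)
    return sum(w[i] * ops[i](x) for i in 1:length(w))
end

# Same layer setup as in the MWE: odd kernels with padding k ÷ 2 keep 16×16 output
ops = [Chain(x -> relu.(x), Conv((k, k), 3 => 3, pad = k ÷ 2)) for k in (1, 3, 5, 7)]
x = rand(Float32, 16, 16, 3, 1)
αs = rand(Float32, 4)

y1 = weighted_sum_stacked(ops, x, αs)
y2 = weighted_sum_indexed(ops, x, αs)

# Gradient of the stacked form with respect to the raw (pre-softmax) weights
g1 = gradient(a -> sum(weighted_sum_stacked(ops, x, a)), αs)[1]
```

Because the weights pass through `softmax`, the gradient with respect to `αs` should sum to (approximately) zero, which is a cheap sanity check on whatever formulation ends up being used on the GPU.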

Any ideas as to why this issue is occurring, and are there more efficient ways of implementing this?

```julia
using Flux
using Zygote
using CUDA
using Test

# Thin wrapper around softmax so that a custom adjoint can be attached to it
function my_softmax(xs; dims = 1)
    softmax(xs, dims = dims)
end

# Hand-written adjoint for my_softmax, delegating to ∇softmax
Zygote.@adjoint function my_softmax(xs; dims = 1)
    softmax(xs, dims = dims), Δ -> begin
        (∇softmax(Δ, xs, dims = dims),)
    end
end

# ReLU followed by a convolution
ReLUConv(channels_in, channels_out, kernel_size, pad) =
    Chain(x -> relu.(x), Conv(kernel_size, channels_in => channels_out, pad = pad))

# A set of candidate operations whose outputs are mixed by softmax weights
struct MixedOperation
    operations::AbstractArray
end

# One ReLUConv per kernel size; padding i ÷ 2 keeps the spatial size for odd i
MixedOperation(channels::Int64, kernel_options::AbstractArray) =
    MixedOperation([ReLUConv(channels, channels, (i, i), i ÷ 2) for i in kernel_options])

function (m::MixedOperation)(x::AbstractArray, αs::AbstractArray)
    αs = softmax(αs)
    #αs = my_softmax(αs)  # custom-adjoint variant
    mapreduce((op, α) -> α * op(x), +, m.operations, αs) # errors out in gradient of softmax
    #sum(αs[i]*m.operations[i](x) for i in 1:length(αs))  # scalar-indexing variant
end

Flux.@functor MixedOperation

m = MixedOperation(3, [1, 3, 5, 7]) |> gpu
αs = rand(Float32, 4) |> gpu
test_image = rand(Float32, 16, 16, 3, 1) |> gpu

# Forward pass
@test sum(m(test_image, αs)) != 0

# Gradient with respect to both the input and the mixing weights
grad = gradient((x, αs) -> sum(m(x, αs)), test_image, αs)

# Implicit-parameter gradient with respect to αs only
gαs = gradient(Flux.params(αs)) do
    sum(m(test_image, αs))
end

# The gradient for αs is expected to stay on the GPU
for a in Flux.params(αs)
    @show gαs[a]
    @test isa(gαs[a], CuArray)
end
```