`mapreduce` over a softmax-ed array can't be differentiated on GPU

In the MWE of my use case below, I am trying to compute a softmax-weighted sum of the outputs of several layers applied to the same input. The output needs to be differentiable with respect to both the parameters inside the layers and the weights αs that are passed through softmax.
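
For reference, the forward computation being described is y = Σᵢ softmax(αs)ᵢ · opᵢ(x). A minimal sketch in plain Base Julia (no Flux, Zygote or CUDA), where `softmax_vec`, `ops` and `mixed` are hypothetical names and simple elementwise functions stand in for the conv layers:

```julia
# Numerically stable softmax over a vector (hand-written, Base Julia only)
function softmax_vec(xs::AbstractVector{<:Real})
    e = exp.(xs .- maximum(xs))   # subtract the max to avoid overflow in exp
    return e ./ sum(e)
end

# Hypothetical stand-ins for the layers: plain functions applied to the same input
ops = (x -> 2 .* x, x -> x .+ 1, x -> x .^ 2)

# Softmax-weighted sum of every op's output on the same input x
mixed(ops, x, αs) = sum(w .* op(x) for (op, w) in zip(ops, softmax_vec(αs)))

x = [1.0, 2.0, 3.0]
αs = [0.0, 0.0, 0.0]          # equal logits give equal weights of 1/3
y = mixed(ops, x, αs)
```

With equal logits the result is just the mean of the three outputs, [5/3, 11/3, 19/3].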

On GPU, the softmax and the mapreduce each seem to be differentiable when only one of them is included in the model, but not when both are. I have tried replacing mapreduce with other parallelizable broadcasting approaches, but everything I tried either errors out in the forward pass, errors out in the gradient evaluation, returns `nothing` gradients, or returns gradients on the CPU. The only thing that seems to work is to 1) use my own adjoint for softmax and 2) compute the weighted sum with (less than ideal) scalar indexing.
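
For reference, a custom softmax adjoint has to return the vector-Jacobian product of softmax: with y = softmax(x) and upstream gradient Δ, the gradient with respect to x is y .* (Δ .- sum(Δ .* y)), since the Jacobian is Diagonal(y) - y * y'. The sketch below is plain Base Julia, uses hypothetical helper names (`softmax_vec`, `softmax_vjp`, `fd_grad`), and checks the formula against central finite differences. This should be the same quantity NNlib's `∇softmax` computes, but that is an assumption here, not something verified against a specific NNlib version.

```julia
# Numerically stable softmax over a vector
function softmax_vec(xs::AbstractVector{<:Real})
    e = exp.(xs .- maximum(xs))
    return e ./ sum(e)
end

# Vector-Jacobian product of softmax: given y = softmax(x) and an upstream
# gradient Δ, returns J' * Δ where J = Diagonal(y) - y * y'
softmax_vjp(Δ, y) = y .* (Δ .- sum(Δ .* y))

# Central finite-difference estimate of the gradient of a scalar function f at x
function fd_grad(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

x = [0.3, -1.2, 2.0, 0.5]
Δ = [1.0, 0.0, -2.0, 0.5]                             # arbitrary upstream gradient
y = softmax_vec(x)
analytic = softmax_vjp(Δ, y)
numeric = fd_grad(z -> sum(Δ .* softmax_vec(z)), x)   # gradient of Δ ⋅ softmax(x)
```

Since `softmax_vjp` is only broadcasting plus one reduction, the same expression should also run on a `CuArray` without scalar indexing.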

Any ideas why this is happening, and are there more efficient ways to implement this?

```julia
using Flux
using Zygote

# softmax with a hand-written adjoint that calls ∇softmax directly
function my_softmax(xs; dims = 1)
    softmax(xs, dims = dims)
end

Zygote.@adjoint function my_softmax(xs; dims = 1)
    softmax(xs, dims = dims), Δ -> begin
        (∇softmax(Δ, xs, dims = dims),)
    end
end

# ReLU followed by a convolution
ReLUConv(channels_in, channels_out, kernel_size, pad) =
    Chain(x -> relu.(x), Conv(kernel_size, channels_in => channels_out, pad = pad))

# Holds one ReLUConv per candidate kernel size
struct MixedOperation
    operations
end

# Padding i ÷ 2 keeps the spatial size unchanged for odd kernel sizes
MixedOperation(channels::Int64, kernel_options::AbstractArray) =
    MixedOperation([ReLUConv(channels, channels, (i, i), i ÷ 2) for i in kernel_options])

function (m::MixedOperation)(x::AbstractArray, αs::AbstractArray)
    αs = softmax(αs)
    #αs = my_softmax(αs)
    mapreduce((op, α) -> α * op(x), +, m.operations, αs) # errors out in gradient of softmax
    #sum(αs[i]*m.operations[i](x) for i in 1:length(αs))
end

Flux.@functor MixedOperation

using Test
using CUDA

m = MixedOperation(3, [1, 3, 5, 7]) |> gpu
αs = rand(Float32, 4) |> gpu
test_image = rand(Float32, 16, 16, 3, 1) |> gpu
@test sum(m(test_image, αs)) != 0
grad = gradient((x, αs) -> sum(m(x, αs)), test_image, αs)

gαs = gradient(Flux.params(αs)) do
    sum(m(test_image, αs))
end
for a in Flux.params(αs)
    @show gαs[a]
    @test isa(gαs[a], CuArray)
end
```
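
One broadcast-only way to express the weighted sum is to stack the layer outputs along a new trailing dimension and contract that dimension with a reshaped weight vector, which avoids indexing individual elements of the weight array (the source of scalar indexing on a `CuArray`). The sketch below is plain Base Julia with hypothetical stand-in functions for the layers, and it only checks that the mapreduce, indexed-sum, and stacked formulations agree in the forward pass; whether Zygote differentiates the `cat`/`reshape` version on the GPU with the package versions used here is an open assumption.

```julia
# Hypothetical stand-ins for the conv layers; each returns an array the size of x
ops = [x -> 2 .* x, x -> x .+ 1, x -> x .^ 2, x -> -x]

x = reshape(collect(1.0:12.0), 2, 3, 2)   # any N-d input
w = [0.1, 0.2, 0.3, 0.4]                  # weights, e.g. already softmax-ed

# 1) mapreduce over (op, weight) pairs, as in MixedOperation
y_mapreduce = mapreduce((op, α) -> α * op(x), +, ops, w)

# 2) explicit indexing into w (scalar indexing when w lives on the GPU)
y_indexed = sum(w[i] * ops[i](x) for i in eachindex(w))

# 3) stack outputs along dimension N + 1, then contract it with w by
#    broadcasting a reshaped weight array; w is never indexed elementwise
N = ndims(x)
stacked = cat((op(x) for op in ops)...; dims = N + 1)
w_b = reshape(w, ntuple(_ -> 1, N)..., :)  # size (1, 1, 1, 4) for a 3-d x
y_stacked = dropdims(sum(w_b .* stacked; dims = N + 1); dims = N + 1)
```

The stacked version holds every layer output in memory at once, trading memory for a single broadcast and reduction.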