In an MWE of my use case below, I am trying to output a softmax-weighted sum of multiple layers applied to the same input. The output needs to be differentiable with respect to both the weights within the layers and the weighting vector `αs`. On GPU, the `softmax` and the `mapreduce` seem to be differentiable when only one or the other is included in the model, but not when both are. I have tried replacing the `mapreduce` with other parallelizable broadcasting approaches, but everything I tried either errors out in the forward pass, errors out in the gradient evaluation, returns gradients of `Nothing` type, or returns gradients on the CPU. The only thing that seems to work is to 1) use my own adjoint for `softmax` and 2) use the (less than ideal) scalar-indexing approach to perform the weighted sum.
Any ideas as to why this issue is occurring, and are there any more efficient ways of implementing this?
```julia
using Flux
using Zygote

function my_softmax(xs; dims = 1)
    softmax(xs, dims = dims)
end

Zygote.@adjoint function my_softmax(xs; dims = 1)
    softmax(xs, dims = dims), Δ -> begin
        (∇softmax(Δ, xs, dims = dims),)
    end
end

ReLUConv(channels_in, channels_out, kernel_size, pad) =
    Chain(x -> relu.(x), Conv(kernel_size, channels_in => channels_out, pad = pad))

struct MixedOperation
    operations::AbstractArray
end

MixedOperation(channels::Int64, kernel_options::AbstractArray) =
    MixedOperation([ReLUConv(channels, channels, (i, i), i ÷ 2) for i in kernel_options])

function (m::MixedOperation)(x::AbstractArray, αs::AbstractArray)
    αs = softmax(αs)
    #αs = my_softmax(αs)
    mapreduce((op, α) -> α * op(x), +, m.operations, αs)  # errors out in gradient of softmax
    #sum(αs[i] * m.operations[i](x) for i in 1:length(αs))
end

Flux.@functor MixedOperation

using Test
using CUDA

m = MixedOperation(3, [1, 3, 5, 7]) |> gpu
αs = rand(Float32, 4) |> gpu
test_image = rand(Float32, 16, 16, 3, 1) |> gpu

@test sum(m(test_image, αs)) != 0

grad = gradient((x, αs) -> sum(m(x, αs)), test_image, αs)

gαs = gradient(Flux.params(αs)) do
    sum(m(test_image, αs))
end

for a in Flux.params(αs)
    @show gαs[a]
    @test isa(gαs[a], CuArray)
end
```
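For clarity, the only combination that runs for me and returns `CuArray` gradients for `αs` on the GPU is the one using the commented-out lines above, i.e. my custom `my_softmax` adjoint together with the scalar-indexed sum. This is just those pieces assembled into one forward pass, and the scalar indexing is exactly what I would like to avoid:

```julia
# Working but slow variant: custom adjoint for softmax plus the scalar-indexing
# weighted sum (gradients for αs come back as CuArrays, at the cost of scalar indexing).
function (m::MixedOperation)(x::AbstractArray, αs::AbstractArray)
    αs = my_softmax(αs)                                     # custom adjoint instead of softmax
    sum(αs[i] * m.operations[i](x) for i in 1:length(αs))   # scalar indexing over αs
end
```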