Adjoint for threaded map (ThreadsX.map)

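An existing rule already handles `map(f, ::Tuple)`, and Zygote has its own adjoint for `map`. The basic idea is quite simple, but those versions add elaborations to handle several collections, as in `map(f, x, y, z)`, and to be more efficient in some cases. The rule below is a minimal version for `my_map(f, X)` with a single array `X`.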

# Reverse rule for `my_map(f, X)` with a single array `X`.
#
# Forward pass: for every element `x`, call back into the AD system with `rrule_via_ad`.
# This keeps both the element's output `y` and its pullback `back`. It is done through
# `my_map` itself, so a threaded `my_map` also threads this work.
#
# Pullback: run every stored `back(dy)` (again through `my_map`). Each call returns
# `(df_i, dx_i)`, the tangent for `f` and for that element `x`. The `df_i` are summed and
# projected onto `f`; this is non-trivial only when `f` closes over differentiable data.
# The `dx_i` are collected into `dX`.
#
# Returns `(Y, map_pullback)`, where `map_pullback(dY)` returns `(NoTangent(), df, dX)`.
function ChainRulesCore.rrule(
    config::RuleConfig{>:HasReverseMode}, ::typeof(my_map), f, X::AbstractArray
)
    # Array of `(y, back)` tuples, one per element of `X`.
    ybacks = my_map(X) do x
        return rrule_via_ad(config, f, x)
    end
    Y = map(first, ybacks)
    function map_pullback(dY_raw)
        dY = unthunk(dY_raw)
        # NOTE(review): pullbacks run in forward order. Reverse order would be needed if
        # `f` has order-dependent side effects (the original "should really do these in
        # the reverse order" TODO).
        backevals = my_map(ybacks, dY) do (_, back), dy
            return back(dy)  # `(df_i, dx_i)`
        end
        # `sum` throws on an empty collection; an empty `X` contributes no tangent to `f`.
        df_total = isempty(backevals) ? ZeroTangent() : sum(first, backevals)
        df = ProjectTo(f)(df_total)
        dX = map(last, backevals)
        return (NoTangent(), df, dX)
    end
    return Y, map_pullback
end

"""
    my_map(f, Xs...)

Stand-in for a threaded `map` (such as `ThreadsX.map`). Print `f` with `@show`, so it is
visible when this function runs, then apply it elementwise to `Xs` exactly as `Base.map`.
"""
function my_map(f, Xs...)
    shown = @show f  # `@show` prints `f = …` and returns `f` unchanged
    return map(shown, Xs...)
end

# Reference result: gradient through `Base.map`, using the AD system's own rule for `map`
# (presumably Zygote's `gradient`; confirm which AD package is loaded).
gradient(x -> sum(map(inv, x)), [1,2,3.0])
# Same computation through `my_map`, so the `rrule` for `my_map` is used. This exercises
# the per-element `dX` output. Mathematically d/dx_i sum(1 ./ x) = -1/x_i^2, so both
# lines should give approximately [-1.0, -0.25, -0.111].
gradient(x -> sum(my_map(inv, x)), [1,2,3.0])  # dX

# Here `x` is captured by the closure `z -> z/x`, so its gradient flows through the
# tangent of `f` (the `df` output), not through the mapped collection `1:3`.
# sum(z/x for z in 1:3) = 6/x, so the gradient at 4.0 should be -6/16 = -0.375.
gradient(x -> sum(map(z -> z/x, 1:3)), 4.0)
gradient(x -> sum(my_map(z -> z/x, 1:3)), 4.0)  # df
3 Likes