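There is an existing rule for `map(f, ::Tuple)`, and Zygote has its own rule for `map` as well. The basic idea is quite simple, but there are elaborations to deal with multiple arguments, as in `map(f, x, y, z)`, and to be more efficient in some cases. A sketch of the basic version, for a single array, follows.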
# Reverse rule for `my_map(f, X)` over a single array `X`.
#
# Forward pass: differentiate `f` at every element with `rrule_via_ad`, keeping each
# `(y, back)` pair so that the pullback can reuse the per-element pullbacks.
# Pullback: apply each stored `back` to the matching cotangent `dy`. Each call returns
# `(df, dx)`. The `df` parts are summed, because the same `f` is used for every element.
# The `dx` parts form the cotangent for `X`.
# Returns `(NoTangent(), df, dX)` for the arguments `(my_map, f, X)`.
function ChainRulesCore.rrule(
    config::RuleConfig{>:HasReverseMode}, ::typeof(my_map), f, X::AbstractArray
)
    # One `(y, back)` tuple per element of X.
    ybacks = my_map(X) do x
        rrule_via_ad(config, f, x)
    end
    Y = map(first, ybacks)
    function map_pullback(dY_raw)
        dY = unthunk(dY_raw)
        # Should really do these in the reverse order of the forward pass.
        backevals = my_map(ybacks, dY) do (_, back), dy
            back(dy)  # per-element `(df, dx)`
        end
        # `sum` over an empty collection has no `init` and would throw. With no
        # elements, `f` contributed nothing, so its tangent is zero.
        df = if isempty(backevals)
            ZeroTangent()
        else
            ProjectTo(f)(sum(first, backevals))
        end
        # Project so that dX matches the structure/element type of X, as df is projected onto f.
        dX = ProjectTo(X)(map(last, backevals))
        return (NoTangent(), df, dX)
    end
    return Y, map_pullback
end
# Wrapper around `map` that first echoes the mapped function with `@show`, so that
# every call of `my_map` is visible in the output. Otherwise identical to
# `map(f, Xs...)`.
function my_map(f, Xs...)
    shown = @show f  # prints "f = ..." and evaluates to `f` itself
    return map(shown, Xs...)
end
gradient(x -> sum(map(inv, x)), [1,2,3.0]) # reference: Base `map`, gradient w.r.t. the mapped array
gradient(x -> sum(my_map(inv, x)), [1,2,3.0]) # dX path of the my_map rule: x is the array X
gradient(x -> sum(map(z -> z/x, 1:3)), 4.0) # reference: Base `map`, x is captured inside the mapped closure
gradient(x -> sum(my_map(z -> z/x, 1:3)), 4.0) # df path of the my_map rule: the gradient flows through f, which captures x