How to use Zygote.gradient when the function passed to mapslices includes a parameter?

Original post below. I have since found that the documentation for `SliceMap.mapcols` explicitly states this isn’t supported: “Note that if `f` itself contains parameters, their gradients are also not tracked.” So if I need to do `mapslices` and track gradients for parameters inside that function, how might I do that?


Test code below. This returns a gradient for `m` but `nothing` for `k`. Changing the value of `k` does change the result, so shouldn’t there be a gradient? Could it be because of the nested `mapcols`? `data` is deliberately not an argument to `gradient` because it stays fixed.

If I change `func` to just `y = m .* k`, it returns gradients for both `m` and `k`.

Update: Even without the nested `mapcols`, a single `mapcols` gives the same `nothing` gradient for `k`.
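The single-`mapcols` case can be reduced to a few lines. This is a sketch assuming the `SliceMap.mapcols(f, M)` call form used in the test code; the closure `col -> col ./ k` stands in for `op`.

```julia
import SliceMap, Zygote

# k is used only inside the closure passed to mapcols. Per the mapcols
# docstring, gradients of parameters captured by f are not tracked.
m = fill(1.0, 3, 10)
k = [0.7, 0.5, 0.5]
g = Zygote.gradient((m, k) -> sum(SliceMap.mapcols(col -> col ./ k, m)), m, k)
# g[1] is the gradient for m; the post reports that g[2] (for k) is `nothing`.
```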

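```julia
import SliceMap, Zygote

# Divide each entry of `dx` by the matching entry of `k`.
function op(dx, k)
    return dx ./ k
end

function func(m, k, data)
    @assert size(m) == (3, 10)
    @assert size(k) == (3,)
    @assert size(data) == (3, 10)
    # Outer mapcols: one scalar per column of m, so y is 1×10.
    y = SliceMap.mapcols(m) do col
        diff = data .- col    # every column of data minus this column of m
        @assert size(diff) == (3, 10)
        # Inner mapcols: k is captured by this closure.
        s1 = SliceMap.mapcols(diff) do diff_col
            op(diff_col, k)
        end
        @assert size(s1) == (3, 10)
        return sum(s1)
    end
    @assert size(y) == (1, 10)
    return sum(y)
end

function test()
    m = fill(1.0, 3, 10)
    data = fill(2.0, 3, 10)
    k = [0.7, 0.5, 0.5]
    # data is closed over, so only m and k are differentiated.
    grad = Zygote.gradient((m, k) -> func(m, k, data), m, k)
    # Sanity check: the output really does depend on k.
    val1 = func(m, k, data)
    k2 = k .+ 1.0
    val2 = func(m, k2, data)
    @assert val1 != val2
    return grad
end
```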
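For comparison with the observation that plain `m .* k` gives both gradients, the whole of `func` above can also be sketched with broadcasting alone, with no slicing, so every use of `k` is an ordinary broadcast that Zygote differentiates directly. `func_broadcast` is a hypothetical name, and the sketch assumes `m` and `data` each have `length(k)` rows, as in the test code.

```julia
import Zygote

# Same quantity as `func`, computed without mapcols.
function func_broadcast(m, k, data)
    n, p = size(m)        # m is n×p (3×10 in the test code)
    q = size(data, 2)     # data is n×q
    # diff[i, jd, jm] = data[i, jd] - m[i, jm], an n×q×p array
    diff = reshape(data, n, q, 1) .- reshape(m, n, 1, p)
    # k has length n, so it broadcasts along the first dimension,
    # matching what op(diff_col, k) does to each column.
    return sum(diff ./ k)
end

m = fill(1.0, 3, 10)
data = fill(2.0, 3, 10)
k = [0.7, 0.5, 0.5]
gm, gk = Zygote.gradient((m, k) -> func_broadcast(m, k, data), m, k)
```

Here each of the 100 (data column, `m` column) pairs contributes `(2 - 1) / k[i]` per row, so the gradient for `k` is `-100 ./ k .^ 2`. The cost of avoiding slices is the `n×q×p` intermediate array, which is small here but grows with both column counts.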

I’ll leave this embarrassing question up just in case someone else searches for it without understanding the docstrings first.

`SliceMap.slicemap` does track the gradients of parameters inside the function. Its docstring says: “Parameters within the function `f` (if there are any) should be correctly tracked, which is not the case for `mapcols()`.” So that will work.
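A minimal sketch of `func` rewritten with `slicemap`, assuming SliceMap’s `slicemap(f, x; dims)` form with `mapslices`-style `dims` (so `dims=1` hands `f` one column at a time). `func_slicemap` is a hypothetical name. The inner `mapcols` is replaced by the broadcast `diff ./ k`, which computes the same thing as applying `op` column by column.

```julia
import SliceMap, Zygote

# func rewritten so that k, captured by the closure, is differentiated.
function func_slicemap(m, k, data)
    y = SliceMap.slicemap(m; dims=1) do col
        diff = data .- col      # 3×10: every data column minus this column of m
        s1 = diff ./ k          # k (length 3) broadcasts down each column
        return [sum(s1)]        # 1-element vector, so each slice maps to an array
    end
    return sum(y)
end

m = fill(1.0, 3, 10)
data = fill(2.0, 3, 10)
k = [0.7, 0.5, 0.5]
gm, gk = Zygote.gradient((m, k) -> func_slicemap(m, k, data), m, k)
```

If `k` is tracked, `gk` should equal `-100 ./ k .^ 2` for these inputs rather than `nothing`.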