Original post below. I found the documentation for SliceMap.mapcols, which explicitly states this isn't supported: "Note that if f itself contains parameters, their gradients are also not tracked." So… if I need to do mapslices and track gradients for parameters inside that function, how might I do that?
Test code below. This returns a gradient for m but nothing for k. Changing the value of k does change the result, so shouldn't there be a gradient? Could it be because of the nested mapcols? data is deliberately not an argument to gradient because it doesn't change.
If I change the body of func to just y = m .* k, then it returns a gradient for both m and k.
Update: Even without the nested mapcols, a single mapcols results in the same nothing gradient.
import SliceMap, Zygote

# Elementwise scaling by the parameter vector k
function op(dx, k)
    return dx ./ k
end

function func(m, k, data)
    @assert size(m) == (3, 10)
    @assert size(k) == (3,)
    @assert size(data) == (3, 10)
    # Outer mapcols: one scalar per column of m
    y = SliceMap.mapcols(m) do col
        diff = data .- col
        @assert size(diff) == (3, 10)
        # Inner mapcols: k is captured by the closure, not passed as an argument
        s1 = SliceMap.mapcols(diff) do diff_col
            op(diff_col, k)
        end
        @assert size(s1) == (3, 10)
        return sum(s1)
    end
    @assert size(y) == (1, 10)
    return sum(y)
end

function test()
    m = fill(1.0, 3, 10)
    data = fill(2.0, 3, 10)
    k = [0.7, 0.5, 0.5]
    grad = Zygote.gradient((m, k) -> func(m, k, data), m, k)
    # Sanity check: the result does depend on k
    val1 = func(m, k, data)
    k2 = k .+ 1.0
    val2 = func(m, k2, data)
    @assert val1 != val2
    return grad
end
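
For reference, here is a minimal sketch of the single-mapcols case from the update. The name single_mapcols is just something I made up for this reduced example, and it assumes the same SliceMap and Zygote packages as above. Going by the documentation note quoted at the top, I would expect the second entry of the returned tuple (the gradient for k) to be nothing, since k only enters through the closure passed to mapcols.

```julia
import SliceMap, Zygote

# Hypothetical reduced example: one mapcols call, with k captured by the closure
function single_mapcols(m, k)
    y = SliceMap.mapcols(col -> col ./ k, m)
    return sum(y)
end

m = fill(1.0, 3, 10)
k = [0.7, 0.5, 0.5]

# grad[1] is the gradient with respect to m, grad[2] the one with respect to k
grad = Zygote.gradient(single_mapcols, m, k)
```

Calling single_mapcols(m, k .+ 1.0) gives a different value than single_mapcols(m, k), so the output clearly depends on k even in this reduced form.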