I created a custom Lux layer to evaluate a polynomial. N points enter the layer, are transformed, and are fed into a loss function. I wrote a loop to evaluate the polynomial because I do not understand broadcasting well enough to avoid it. But then I realized that my loops mutate arrays, and Zygote will not be able to calculate the gradient of such a function. Here is the code:
function (l::Polylayer)(x::AbstractMatrix, ps, st::NamedTuple)
    c = ps.coeffs
    x1 = reshape(x, l.out_dims, :)  # last dimension is the number of training samples
    N = size(x1, ndims(x1))
    sum = zeros(l.out_dims, N)
    for i in 1:N
        sum[:, i] .= c[:, end]      # initialize each column with the leading coefficients
    end
    for d in l.degree:-1:1          # Horner's rule, one degree at a time
        for i in 1:N
            sum[:, i] .= sum[:, i] .* x1[:, i] .+ c[:, d]
        end
    end
    return sum, st
end
Clearly, the line sum[:, i] .= sum[:, i] .* x1[:, i] .+ c[:, d] is problematic if I want to take a gradient.
So my question is: how can I rewrite the loop using broadcasting? Here is my attempt using map and reduce(hcat, ...), which avoids mutation:
function (l::Polylayer)(x::AbstractMatrix, ps, st::NamedTuple)
    c = ps.coeffs
    x1 = reshape(x, l.out_dims, :)  # last dimension is the number of training samples
    N = size(x1, ndims(x1))
    sum2 = reduce(hcat, map(1:N) do i
        c[:, end]                   # each column starts at the leading coefficients
    end)
    for d in l.degree:-1:1
        sum2 = reduce(hcat, map(1:N) do i
            sum2[:, i] .* x1[:, i] .+ c[:, d]   # builds a new matrix each step, so nothing is mutated
        end)
    end
    return sum2, st
end
Here is the source of evalpoly. It looks like this is a mutating implementation:
function evalpoly(x, p::Tuple)
    if @generated
        N = length(p.parameters::Core.SimpleVector)
        ex = :(p[end])
        for i in N-1:-1:1
            ex = :(muladd(x, $ex, p[$i]))
        end
        ex
    else
        _evalpoly(x, p)
    end
end
Isn’t ex overwritten? Why won’t Zygote have a problem with this?
In my code, I have an array and am mutating its elements, which Zygote does not like.
If I evaluate evalpoly with a vector argument using the broadcast dot on the call (evalpoly., notice the dot), am I not mutating? For example,
y .= evalpoly.(x, Ref((1, 0, 2))) # mutating or non-mutating?
where y is a vector of size 128, for example. Should one replace .= by = to make it non-mutating, at the expense of using more memory? Thanks.
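To make concrete which two forms I am comparing (the coefficient tuple (1, 0, 2) is just a made-up example):

x = randn(128)
y = zeros(128)

y .= evalpoly.(x, Ref((1, 0, 2)))   # in-place: writes the results into the existing y, so y is mutated
z = evalpoly.(x, Ref((1, 0, 2)))    # out-of-place: allocates and binds a fresh vector, nothing is mutated

The Ref protects the tuple from being broadcast over, so only x is iterated.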
Got it. I know I can unroll the polynomials, of course, but I will want to handle the multivariate case. I wonder if I can use ModelingToolkit to help with this task. I also have to read up on why Zygote cannot handle mutation, or perhaps it is only forward differentiation that must be immutable? I have not read much about this yet.
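To check that for myself, here is a small experiment I intend to run, comparing the two modes on a deliberately mutating toy function (h and its squaring loop are just something I made up for the test):

using ForwardDiff, Zygote

function h(x)
    y = similar(x)        # eltype follows x, so ForwardDiff's Dual numbers fit in here
    for i in eachindex(x)
        y[i] = x[i]^2     # explicit elementwise mutation
    end
    return sum(y)
end

x = randn(4)
ForwardDiff.gradient(h, x)   # forward mode pushes Duals through the mutation, so this should work
Zygote.gradient(h, x)        # reverse mode (Zygote) should reject this with a "Mutating arrays is not supported" error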
After reading the documentation and considering different solutions, I have decided to look into a custom derivative rule with ChainRulesCore.jl. We'll see how that goes. If it doesn't work out, I might have to try using Python, even though it is much slower. I have to take development time into consideration. Ughh.
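For anyone following along, here is the rough shape of what I have in mind, untested: pull the Horner evaluation out into a standalone function, let it mutate internally, and give it an rrule so Zygote never looks inside it. The name horner_cols and the coefficient layout (out_dims × (degree + 1) with the constant term in column 1, matching my layer above) are my own choices.

using ChainRulesCore

# Standalone kernel: Horner evaluation of each column of x against the
# coefficient matrix c, with the constant term in c[:, 1].
function horner_cols(c::AbstractMatrix, x::AbstractMatrix)
    y = repeat(c[:, end], 1, size(x, 2))
    for d in (size(c, 2) - 1):-1:1
        y .= y .* x .+ c[:, d]          # mutation is fine here: Zygote never traces this body
    end
    return y
end

function ChainRulesCore.rrule(::typeof(horner_cols), c::AbstractMatrix, x::AbstractMatrix)
    y = horner_cols(c, x)
    function horner_cols_pullback(ȳ)
        degree = size(c, 2) - 1
        # dy/dx is the derivative polynomial p'(x), also evaluated with Horner's rule
        dpdx = degree .* c[:, end] .+ zero(x)
        for d in degree:-1:2
            dpdx = dpdx .* x .+ (d - 1) .* c[:, d]
        end
        x̄ = ȳ .* dpdx
        # dy/dc[:, d] is x^(d-1) elementwise, weighted by ȳ and summed over the samples
        c̄ = reduce(hcat, [vec(sum(ȳ .* x .^ (d - 1); dims=2)) for d in 1:size(c, 2)])
        return NoTangent(), c̄, x̄
    end
    return y, horner_cols_pullback
end

In the layer's forward pass I would then just call horner_cols(ps.coeffs, x1); since Zygote consumes ChainRulesCore rules, it should use this pullback instead of tracing the mutating loop.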
Besides the fact that assignment is not the same as mutation, I should point out that this ex transformation is not runtime code — it’s not what the AD system is analyzing — it’s metaprogramming code that runs at compile time (because this is a @generated function).
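To see this concretely, the generator body can be run as ordinary code (the helper name below is mine): it only builds an Expr, and rebinding ex along the way just changes which expression the local variable refers to; no array or other container is modified.

# Mirror of the generator body, runnable at the REPL:
function unrolled_horner_expr(N)
    ex = :(p[end])
    for i in N-1:-1:1
        ex = :(muladd(x, $ex, p[$i]))
    end
    return ex
end

unrolled_horner_expr(3)   # returns :(muladd(x, muladd(x, p[end], p[2]), p[1]))

That final expression is what the compiler splices in as the method body, so at runtime evalpoly on a tuple is just nested muladd calls, with no loop and no array.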
Yes. In fact, a mask is probably what I need to solve my problem without mutation. But that takes me in a direction I don't have time for at the moment. Thanks for the observation.