Mutating versus non-mutating arrays for Zygote Gradient

I created a custom Lux layer to evaluate a polynomial. N points enter the layer, are transformed, and are then passed to a loss function. I wrote loops to evaluate the polynomial because I do not understand broadcasting well enough to use it. I then realized that my loops mutate an array, so Zygote will not be able to compute the gradient of this function. Here is the code:

function (l::Polylayer)(x::AbstractMatrix, ps, st::NamedTuple)
    c = ps.coeffs
    x1 = reshape(x, l.out_dims, :) # last dimension is the number of training samples

    N = size(x1, length(size(x1)))
    sum = zeros(l.out_dims, N)

    for i in 1:N
        sum[:, i] .= c[:, end]
    end
    for d in l.degree : -1 : 1
        for i in 1:N
            sum[:, i] .= sum[:, i] .* x1[:, i]  .+ c[:, d]
        end
    end
    return sum, st
end

Clearly, the line sum[:, i] .= sum[:, i] .* x1[:, i] .+ c[:, d] is problematic if I want to take a gradient.
So my question is, how can I transform the loop using broadcast?

Thanks,

Just use mapping constructs instead.

function (l::Polylayer)(x::AbstractMatrix, ps, st::NamedTuple)
    c = ps.coeffs
    x1 = reshape(x, l.out_dims, :) # last dimension is the number of training samples

    N = size(x1, length(size(x1)))

    sum2 = reduce(hcat,map(1:N) do i
         c[:, end]
    end)
    for d in l.degree : -1 : 1
        sum2 = reduce(hcat,map(1:N) do i
            sum2[:, i] .* x1[:, i]  .+ c[:, d]
        end)
    end
    return sum2, st
end
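
Since the question asked specifically about broadcasting, here is a minimal sketch of the same Horner recurrence written with whole-array broadcasts instead of a loop over samples. The function name horner_broadcast and the sample data are hypothetical and not part of the original layer; it assumes c has one column per coefficient (degree + 1 columns) and that x already has shape out_dims × N. Each iteration rebinds acc to a freshly allocated array rather than writing into an existing one, which is the pattern Zygote is generally able to differentiate.

```julia
# Hypothetical standalone helper (not from the original post).
# c: out_dims × (degree + 1) coefficient matrix, lowest order in column 1.
# x: out_dims × N matrix of sample points.
function horner_broadcast(c::AbstractMatrix, x::AbstractMatrix)
    degree = size(c, 2) - 1
    # Start from the highest-order coefficient, expanded to the shape of x.
    acc = c[:, end] .+ zero(x)
    for d in degree:-1:1
        # Rebinds `acc` to a new array; no element is written in place.
        acc = acc .* x .+ c[:, d]
    end
    return acc
end

# Two polynomials: 1 + 2x + 3x^2 (row 1) and 0.5 - x^2 (row 2).
c = [1.0 2.0 3.0; 0.5 0.0 -1.0]
x = [0.0 1.0 2.0; 0.0 1.0 2.0]
y = horner_broadcast(c, x)
```

Inside the layer, the body of this helper would replace both loops, with c = ps.coeffs and x = x1.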

I have here the source to evalpoly. It looks like this is a mutating implementation:

function evalpoly(x, p::Tuple)
    if @generated
        N = length(p.parameters::Core.SimpleVector)
        ex = :(p[end])
        for i in N-1:-1:1
            ex = :(muladd(x, $ex, p[$i]))
        end
        ex
    else
        _evalpoly(x, p)
    end
end

Isn’t ex overwritten? Why won’t Zygote have a problem with this?
In my code, I have an array and am mutating its elements, which Zygote does not like.

If I evaluate evalpoly on a vector argument using broadcast notation, evalpoly.(...) (notice the dot after the function name), am I not mutating? For example,

y .= evalpoly.(x, Ref((1,0,2)))    # Mutating or non-mutating?

where y is a vector of size 128, for example. Should one replace .= by = to make it non-mutating, at the expense of using more memory? Thanks.
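
For reference, a small sketch of the two forms side by side. The variable names and the three-element input are made up for illustration. Broadcasting evalpoly itself does not mutate anything; what matters is whether the result is written into an existing array (.=) or bound to a new one (=).

```julia
x = [0.0, 1.0, 2.0]
coeffs = (1, 0, 2)                   # represents 1 + 0x + 2x^2

# Non-mutating: the broadcast allocates a new array and `y_new` names it.
y_new = evalpoly.(x, Ref(coeffs))

# Mutating: the broadcast result is written into the existing buffer.
y_buf = zeros(length(x))
y_buf .= evalpoly.(x, Ref(coeffs))
```

Both forms produce the same values; only the second writes into memory that existed before the call, which is the kind of operation Zygote rejects.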

It’s overwriting, not mutating.
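
A short sketch of the distinction, using made-up variable names: assignment with = makes a name point at a different value, while .= changes the contents of an existing array, which is visible through every other name that refers to it.

```julia
a = [1.0, 2.0, 3.0]
alias = a             # `alias` and `a` refer to the same array

# Overwriting (rebinding): `a` now names a new array.
a = a .* 2
rebinding_left_alias_alone = (alias == [1.0, 2.0, 3.0])

# Mutation: the array that `alias` refers to is changed in place.
b = alias
b .= b .* 2
mutation_visible_through_alias = (alias == [2.0, 4.0, 6.0])
```

The ex variable in the evalpoly source is reassigned in the first sense only.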

Got it. I know I can unroll the polynomials, of course, but I will want to handle the multivariate case. I wonder if I can use ModelingToolkit to help with this task. I also need to read up on why Zygote cannot handle mutation, or perhaps it is only forward-mode differentiation that requires immutability? I have not read much about this yet.

After reading the documentation and considering different solutions, I have decided to look into a custom derivative rule with ChainRulesCore.jl. We’ll see how that goes. If it doesn’t work, I might have to try Python, even though it is much slower. I have to take development time into consideration :) . Ughh.

Merry Christmas!

Besides the fact that assignment is not the same as mutation, I should point out that this ex transformation is not runtime code — it’s not what the AD system is analyzing — it’s metaprogramming code that runs at compile time (because this is a @generated function).
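
To make that concrete, here is a sketch that runs the same expression-building loop as an ordinary function, so the result can be inspected. The name horner_expr is hypothetical; the loop body is copied from the evalpoly source quoted above.

```julia
# Hypothetical helper mirroring the @generated branch of evalpoly.
# It builds an expression; it does not evaluate any polynomial.
function horner_expr(N)
    ex = :(p[end])
    for i in N-1:-1:1
        ex = :(muladd(x, $ex, p[$i]))   # rebinds `ex` to a larger expression
    end
    return ex
end

ex3 = horner_expr(3)
# For a 3-tuple the generated body is a single nested call:
#   muladd(x, muladd(x, p[end], p[2]), p[1])

# Evaluating that nested call by hand matches evalpoly.
p = (1, 0, 2)
by_hand = muladd(2.0, muladd(2.0, p[end], p[2]), p[1])
```

The runtime code is one nested muladd expression with no loop and no reassignment left in it.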

Yes. In fact, a mask is probably what I need to solve my problem without mutation. But that takes me in a direction I don’t have time for at the moment. Thanks for the observation.

Cheers,