Zygote error - Mutating arrays is not supported

Hi all, I am converting a function originally written in Python to Julia, but I have some problems with Zygote's autodiff when I try to take the gradient of the function (I am able to do this seamlessly with JAX and PyTorch).
I understand that the lines where I fill in the arrays P_log and P_norm are the problem, but I am at a loss as to what to do, because P[kvals[i]:kvals[i + 1]-1] has a dynamic shape and sometimes produces a jagged array (i.e. for each value of i, a differently sized array might be produced, hence vcat is not an option).
I would appreciate some pointers on how to tackle this problem (still trying to understand the inner workings of Julia).

function compute_total_obj(P::AbstractVector, 
    F::AbstractArray, 
    Z::AbstractArray, 
    Zerr_Re::AbstractArray,
    Zerr_Im::AbstractArray, 
    LB::AbstractVector, 
    UB::AbstractVector, 
    smf::AbstractVector, 
    func::Function, 
    num_params::Integer, 
    num_eis::Integer, 
    kvals::AbstractVector, 
    d2m::AbstractMatrix)

    # The lines causing the error
    P_log = zeros(num_params, num_eis)
    P_norm = zeros(num_params, num_eis)
    for i = 1:num_params
        P_log[i, :] .= (P[kvals[i]:kvals[i + 1]-1])

        P_norm[i, :] .= ((LB[kvals[i]:kvals[i + 1]-1] .+ (10 .^ P[kvals[i]:kvals[i + 1]-1])) ./ (1 .+ (10 .^ P[kvals[i]:kvals[i + 1]-1]) ./ UB[kvals[i]:kvals[i + 1]-1]))
    end
    smf_1 = ifelse.(isinf.(smf), 0.0, smf)
    chi_smf = sum(sum(((d2m * transpose(P_log)) .* (d2m * transpose(P_log))), dims = 1) .* smf_1)
    wrss_tot = compute_wrss.(eachcol(P_norm), eachcol(F), eachcol(Z), eachcol(Zerr_Re), eachcol(Zerr_Im), func)
    return (sum(wrss_tot) + chi_smf)
end

# Taking the gradient
function g!(G, p)
    G .= Zygote.gradient(p -> compute_total_obj(p, F, Z, Zerr_Re, Zerr_Im, lb_vec, ub_vec, smf, func, num_params, num_eis, kvals, d2m), p)[1]
end
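For context, g! is the in-place gradient that gets handed to Optim.jl (the stacktrace below goes through Optim's optimize interface). A hypothetical call, with p0 standing in for the initial parameter vector, looks something like:

using Optim, Zygote

# p0 is a placeholder for the initial guess; the other variables are defined as above
obj = p -> compute_total_obj(p, F, Z, Zerr_Re, Zerr_Im, lb_vec, ub_vec,
                             smf, func, num_params, num_eis, kvals, d2m)
res = optimize(obj, g!, p0, LBFGS())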

# How kvals is formed (if any element of smf is Inf, jagged intervals are produced)
function get_kvals(smf::AbstractVector, num_eis::Integer)
    kvals = cumsum(insert!(ifelse.(isinf.(smf), 1, num_eis), 1, 1))
    return kvals
end
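To make the jagged intervals concrete, here is a small made-up example (three parameters, the first and last held constant via Inf):

smf = [Inf, 100.0, Inf]                   # made-up smoothing factors
num_eis = 5
kvals = get_kvals(smf, num_eis)           # == [1, 2, 7, 8]
[kvals[i]:kvals[i + 1] - 1 for i = 1:3]   # == [1:1, 2:6, 7:7] -- lengths 1, 5, 1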

# Error
# Mutating arrays is not supported -- called copyto!(SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}, ...)
# This error occurs when you ask Zygote to differentiate operations that change
# the elements of arrays in place (e.g. setting values with x .= ...)

# Possible fixes:
# - avoid mutating operations (preferred)
# - or read the documentation and solutions for this error
#   https://fluxml.ai/Zygote.jl/latest/limitations


# Stacktrace:
#   [1] error(s::String)
#     @ Base .\error.jl:35
#   [2] _throw_mutation_error(f::Function, args::SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true})
#     @ Zygote C:\Users\richinex\.julia\packages\Zygote\dABKa\src\lib\array.jl:68
#   [3] (::Zygote.var"#389#390"{SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}})(#unused#::Nothing)
#     @ Zygote C:\Users\richinex\.julia\packages\Zygote\dABKa\src\lib\array.jl:83
#   [4] (::Zygote.var"#2474#back#391"{Zygote.var"#389#390"{SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}}})(Δ::Nothing)
#     @ Zygote C:\Users\richinex\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
#   [5] Pullback
#     @ .\broadcast.jl:871 [inlined]
#   [6] Pullback
#     @ .\broadcast.jl:868 [inlined]
#   [7] Pullback
#     @ .\broadcast.jl:864 [inlined]
# ...
#     @ Optim C:\Users\richinex\.julia\packages\Optim\Zq1jM\src\multivariate\optimize\interface.jl:150
#  [22] fit_deterministic(P::Vector{Float64}, freq::Vector{Float64}, Z::Matrix{ComplexF64}, bounds::Vector{Vector{Float64}}, smf::Vector{Float64}, func::Function; weight::Type)
#     @ Main d:\julia_projects\error_calc_comparison\multi_2.ipynb:43
#  [23] top-level scope
#     @ d:\julia_projects\error_calc_comparison\multi_2.ipynb:1

Given this wouldn’t work in JAX either (it might not error, but I’m pretty sure it’s at best undefined behaviour), this Julia code is presumably not a direct translation? If you can post a working Python example, we may be able to give some advice on a better one.

Edit: does the forward pass even work in Julia? The statement in the original post that the slices come out differently sized for different i is also true for the loop, because you can’t broadcast a length M array into a length N > M slice. So please do share a working Python example :slight_smile:
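For example, this is the failure mode in plain Julia, outside of any autodiff:

x = zeros(4)
x[1:4] .= [1.0, 2.0]   # throws DimensionMismatch: a length-2 array cannot be broadcast into a length-4 slice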

Hi ToucheSir, the code is pretty long and that is why I did not post it earlier. Here is the link to the complete code written in JAX, which also contains the test data. I have just run it in Google Colab, so I guess it's working. Kindly go through it and let me know how I can best reproduce compute_total_obj in Julia. Thanks.

Doesn’t look like the notebook is publicly accessible. Can you make it so, toss it in a GitHub gist, or attach it here?

I’ve made it public.

Gist: jax_optim.ipynb · GitHub

To make the comparison easier, I have also attached my Julia example notebook as a gist: julia_optim.ipynb · GitHub

Thanks! Looking at a line like

P_log = P_log.at[i, :].set(P[self.kvals[i]:self.kvals[i + 1]])

Is it correct to say that len(P[self.kvals[i]:self.kvals[i + 1]]) == self.kvals[i + 1] - self.kvals[i] == self.num_eis? If so, that should make the Julia code far easier to write. If not, you’ll have to explain to me how this code interacts with 🔪 JAX - The Sharp Bits 🔪 — JAX documentation :smile:
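A quick way to check that on your data would be a one-liner like:

all(kvals[i + 1] - kvals[i] == num_eis for i = 1:num_params)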

Is it correct to say that len(P[self.kvals[i]:self.kvals[i + 1]]) == self.kvals[i + 1] - self.kvals[i] == self.num_eis?
This is only true when no value in smf is Inf. When a value in smf is set to Inf, the intervals produced by the values in kvals are not all equal, so I assign a scalar to P_log.at[i,:], for instance something like P_log.at[i,:].set(0.999). I found a workaround using Julia's ternary operator, but it might be suboptimal. Maybe this gives you more hints on what I am trying to achieve and how to make it better. What I do is this: when the interval is just one value, I repeat that value out to length num_eis.

function compute_total_obj(P::AbstractVector, 
    F::AbstractArray, 
    Z::AbstractArray, 
    Zerr_Re::AbstractArray,
    Zerr_Im::AbstractArray, 
    LB::AbstractVector, 
    UB::AbstractVector, 
    smf::AbstractVector, 
    func::Function, 
    num_params::Integer, 
    num_eis::Integer, 
    kvals::AbstractVector, 
    d2m::AbstractMatrix)

    P_log = vcat(transpose.([length(P[kvals[i]:kvals[i + 1]-1]) == 1 ? repeat(P[kvals[i]:kvals[i + 1]-1], num_eis) : P[kvals[i]:kvals[i + 1]-1]  for i = 1:num_params])...)
    P_norm = vcat(transpose.([length(P[kvals[i]:kvals[i + 1]-1]) == 1 ? repeat(((LB[kvals[i]:kvals[i + 1]-1] .+ (10 .^ P[kvals[i]:kvals[i + 1]-1])) ./ (1 .+ (10 .^ P[kvals[i]:kvals[i + 1]-1]) ./ UB[kvals[i]:kvals[i + 1]-1])), num_eis) : ((LB[kvals[i]:kvals[i + 1]-1] .+ (10 .^ P[kvals[i]:kvals[i + 1]-1])) ./ (1 .+ (10 .^ P[kvals[i]:kvals[i + 1]-1]) ./ UB[kvals[i]:kvals[i + 1]-1])) for i = 1:num_params])...)
    smf_1 = ifelse.(isinf.(smf), 0.0, smf)
    chi_smf = sum(sum(((d2m * transpose(P_log)) .* (d2m * transpose(P_log))), dims = 1) .* smf_1)
    wrss_tot = compute_wrss.(eachcol(P_norm), eachcol(F), eachcol(Z), eachcol(Zerr_Re), eachcol(Zerr_Im), func)
    return (sum(wrss_tot) + chi_smf)
end
The big idea is that I want to be able to keep a parameter constant during optimization by setting its value in smf to Inf.

I had to test this locally, but I think I understand now. JAX lets you do the equivalent of array[:] = [x], so len(P[self.kvals[i]:self.kvals[i + 1]]) may also be 1.
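Julia's broadcasting has the same length-1 special case, which is why the original loop's forward pass actually runs (every slice has length num_eis or length 1):

x = zeros(4)
x[1:4] .= [7.0]   # fine: a length-1 array broadcasts into any slice
x                 # == [7.0, 7.0, 7.0, 7.0]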

Your Julia implementation looks pretty reasonable. Here is my attempt at cleaning it up a bit:

P_log = reduce(hcat, [
    let istart=kvals[i], istop=kvals[i + 1] - 1, ps=@view P[istart:istop]
        istop - istart == 1 ? repeat(ps, num_eis) : ps
    end
    for i = 1:num_params
])
# Alternative way of doing this using map
# Not sure which is faster, feel free to try both
P_norm = reduce(hcat, map(1:num_params) do i
    istart, istop = kvals[i], kvals[i + 1] - 1
    # @view and @views reduce memory allocations when you're not differentiating this function with Zygote
    ps = @view P[istart:istop]
    ps_norm = @views (LB[istart:istop] .+ (10 .^ ps)) ./ (1 .+ (10 .^ ps) ./ UB[istart:istop])
    ub - lb == 1 ? repeat(ps, num_eis) : ps
end)

smf_1 = ifelse.(isinf.(smf), 0.0, smf)
chi_smf = sum(sum((d2m * transpose(P_log)).^2, dims = 1) .* smf_1)
wrss_tot = compute_wrss.(eachcol(P_norm), eachcol(F), eachcol(Z), eachcol(Zerr_Re), eachcol(Zerr_Im), func)
return sum(wrss_tot) + chi_smf

Thanks. Your cleanup works, but it takes much longer (about 12 minutes) to compile than the version I posted earlier (using vcat), which takes less than 2 minutes. I wonder why.

It works now; I spotted the error. ub - lb == 1 ? repeat(ps, num_eis) : ps should be istop - istart == 1 ? repeat(ps_norm, num_eis) : ps_norm
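For anyone copying this later, here is the corrected P_norm block in full. One detail worth spelling out: with istart = kvals[i] and istop = kvals[i + 1] - 1, a one-element slice satisfies istart == istop, so that is the length-1 check written below (the P_log comprehension needs the same check):

P_norm = reduce(hcat, map(1:num_params) do i
    istart, istop = kvals[i], kvals[i + 1] - 1
    ps = @view P[istart:istop]
    ps_norm = @views (LB[istart:istop] .+ (10 .^ ps)) ./ (1 .+ (10 .^ ps) ./ UB[istart:istop])
    # a one-element slice (a held-constant parameter) is repeated across all num_eis columns
    istart == istop ? repeat(ps_norm, num_eis) : ps_norm
end)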
