Zygote in Turing: Mutating arrays is not supported

I am trying to use Zygote in Turing. Here’s my MWE:

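```julia
using Turing, Zygote
Turing.setadbackend(:zygote)   # use Zygote for gradients (Turing API at the time of this post)

# lsexp is defined further down in this thread
@model function mwe(B, ::Type{T} = Vector{Float64}) where {T}
    μ = mapslices(lsexp, B; dims=[1])[1,:]
end
mwe_model = mwe([1.0 2; 3 4])
mwe_ch = sample(mwe_model, NUTS(0.65), 10)
```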

and the error is:

```
Mutating arrays is not supported

Stacktrace:
 [1] error(::String) at .\error.jl:33
 [2] (::Zygote.var"#1052#1053")(::Nothing) at C:\Users\Manoj\.julia\packages\Zygote\YeCEW\src\lib\array.jl:64
 [3] (::Zygote.var"#2785#back#1054"{Zygote.var"#1052#1053"})(::Nothing) at C:\Users\Manoj\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:49
 [4] materialize! at .\broadcast.jl:823 [inlined]
 [5] concatenate_setindex! at .\abstractarray.jl:2058 [inlined]
 [6] (::typeof(∂(concatenate_setindex!)))(::Nothing) at C:\Users\Manoj\.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:0
 [7] (::Zygote.var"#174#175"{typeof(∂(concatenate_setindex!)),Tuple{Tuple{Nothing,Nothing},Int64}})(::Nothing) at C:\Users\Manoj\.julia\packages\Zygote\YeCEW\src\lib\lib.jl:182
 [8] (::Zygote.var"#347#back#176"{Zygote.var"#174#175"{typeof(∂(concatenate_setindex!)),Tuple{Tuple{Nothing,Nothing},Int64}}})(::Nothing) at C:\Users\Manoj\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:49
 [9] inner_mapslices! at .\abstractarray.jl:2039 [inlined]
 [10] (::typeof(∂(inner_mapslices!)))(::Nothing) at C:\Users\Manoj\.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:0
 [11] #mapslices#115 at .\abstractarray.jl:2029 [inlined]
 [12] (::typeof(∂(#mapslices#115)))(::Nothing) at C:\Users\Manoj\.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:0 (repeats 2 times)
 [13] macro expansion at .\In[19]:2 [inlined]
```

I read about this error and found a long discussion at https://github.com/FluxML/Zygote.jl/issues/377, but I was not sure what the fix is.

I’m hoping someone can tell me how to fix this. Note: mapslices does almost nothing useful here because this is a minimal example, but my actual code needs something very much like it. Non-mutating alternatives would be fine if you can suggest some. I am still fairly new to Julia and not sure which operations mutate and which don’t.


I realized my example leaves out the definition of lsexp. It is logsumexp, written as:

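```julia
function lsexp(w)
    N = length(w)
    offset, maxind = findmax(w)
    w .= exp.(w .- offset)        # shift by the max, exponentiate, write the result back into w
    Σ = _sum_all_but(w, maxind)
    log1p(Σ) + offset
end

function _sum_all_but(w, i)
    w[i] -= 1                     # w[i] == exp(0) == 1 here, so this temporarily zeroes it
    s = sum(w)                    # sum of all the other terms
    w[i] += 1                     # restore w[i]
    s
end
```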
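The trick in `lsexp` is the usual max-shift identity for log-sum-exp. Writing $m = w_k = \max_j w_j$ for the value and index returned by `findmax`, the shifted term at index $k$ is exactly $e^0 = 1$. That is why `_sum_all_but` subtracts 1 before summing, and why the result goes through `log1p`:

$$
\log\sum_{j=1}^{N} e^{w_j}
= m + \log\sum_{j=1}^{N} e^{w_j - m}
= m + \log\Bigl(1 + \sum_{j \ne k} e^{w_j - m}\Bigr)
= m + \operatorname{log1p}(\Sigma),
\qquad \Sigma = \sum_{j \ne k} e^{w_j - m}.
$$

Computing $\Sigma$ directly, rather than $1 + \Sigma$, keeps precision when the other terms are tiny.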

You may like SliceMap.jl, which has Zygote-compatible replacements for mapslices.
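If you would rather not add a package, a comprehension over `eachcol` is one non-mutating replacement for `mapslices(f, B; dims=[1])[1,:]`. This is a sketch under two assumptions: the per-column function must not write into its argument (`eachcol` yields views, so a mutating function like the original `lsexp` would overwrite `B` itself), and Zygote can differentiate through `eachcol` in your version, which I have not checked. `lse_col` and `colwise` are hypothetical names.

```julia
# Hypothetical non-mutating log-sum-exp of one column (stand-in for lsexp).
function lse_col(c)
    m = maximum(c)                     # shift by the max for numerical stability
    return m + log(sum(exp.(c .- m)))  # exp.(c .- m) allocates a new array; c is untouched
end

# Column-wise map: same values as mapslices(f, B; dims=[1])[1, :], but built
# from a comprehension instead of mapslices' internal setindex! calls.
colwise(f, B::AbstractMatrix) = [f(c) for c in eachcol(B)]

B = [1.0 2; 3 4]
colwise(lse_col, B)
```

The result is a plain `Vector`, so the trailing `[1,:]` from the MWE is no longer needed.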

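But these parts of lsexp also mutate w, which you will need to avoid: `w .= exp.(w .- offset)` overwrites the input, and `_sum_all_but` does `w[i] -= 1` and then `w[i] += 1`.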
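One way to remove both writes is to build the exponentiated values as a new array and drop the max term with a mask instead of temporarily editing it. This is a sketch I have not run against the Zygote/Turing versions in the stack trace, and `lsexp_nomut` is a hypothetical name. The mask keeps the `log1p` accuracy of the original; the simpler `sum(e) - 1` would round the other terms away when they are tiny.

```julia
# Non-mutating log-sum-exp: same math as the original lsexp, but w is never written to.
function lsexp_nomut(w::AbstractVector)
    offset, maxind = findmax(w)
    e = exp.(w .- offset)            # new array; e[maxind] == exp(0) == 1
    mask = eachindex(e) .!= maxind   # false only at the max entry
    Σ = sum(e .* mask)               # sum of all terms except the max one
    return log1p(Σ) + offset
end
```

For example, `lsexp_nomut([0.0, -50.0])` returns a small positive number (about `exp(-50)`), where the `sum(e) - 1` version would return exactly `0.0`.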

Another approach would be to provide a gradient for lsexp, so that Zygote does not need to look inside it. This must have been worked out somewhere; perhaps https://github.com/FluxML/Zygote.jl/issues/2 has what you want?
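For log-sum-exp the gradient has a closed form: the derivative with respect to `w[j]` is `exp(w[j] - lse(w))`, i.e. the softmax of `w`. Below is a dependency-free sketch of the value-plus-pullback pair that a custom rule would return. To hand it to Zygote you would wrap the same pullback in a `ChainRulesCore.rrule` (recent Zygote versions pick those up) or a `ZygoteRules.@adjoint`; that registration step is only shown as a comment, since it needs those packages loaded. `lse` and `lse_with_pullback` are hypothetical names.

```julia
# Reference log-sum-exp (non-mutating, max-shifted for stability).
function lse(w)
    m = maximum(w)
    return m + log(sum(exp.(w .- m)))
end

# Returns lse(w) and a pullback mapping the output cotangent ȳ to the input
# cotangent ȳ .* softmax(w). This mirrors the (value, pullback) shape of an rrule;
# with ChainRulesCore loaded, the rule could look roughly like:
#   function ChainRulesCore.rrule(::typeof(lse), w)
#       y = lse(w)
#       return y, ȳ -> (NoTangent(), ȳ .* exp.(w .- y))
#   end
function lse_with_pullback(w)
    y = lse(w)
    pullback(ȳ) = ȳ .* exp.(w .- y)   # softmax(w), scaled by ȳ
    return y, pullback
end
```

Because the pullback never touches the internals of `lse`, Zygote would not need to differentiate through any array writes inside it.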

In fact, you should probably write a function lsexp(B; dims) that works directly on the whole array. This is very likely to be more efficient than making slices and acting on each of them.
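A minimal sketch of such a whole-array `lsexp(B; dims)`, assuming plain max-shifting is accurate enough (it drops the `log1p` refinement of the original). It only uses reductions and broadcasting, so nothing is written in place, though I haven't checked it under Zygote. The `isfinite` guard is my addition, so that a slice that is entirely `-Inf` gives `-Inf` rather than `NaN`. With the default `dims=:` it also covers the single-vector case. (LogExpFunctions.jl also provides a `logsumexp` with a `dims` keyword.)

```julia
# Log-sum-exp over the given dims of B, built only from reductions and broadcasts.
# dims=: (the default) reduces over everything and returns a scalar.
function lsexp(B::AbstractArray; dims=:)
    m = maximum(B; dims=dims)                    # per-slice max (size 1 along dims)
    offset = ifelse.(isfinite.(m), m, zero.(m))  # avoid -Inf - (-Inf) = NaN
    return log.(sum(exp.(B .- offset); dims=dims)) .+ offset
end

B = [1.0 2; 3 4]
vec(lsexp(B; dims=1))   # same values (up to rounding) as mapslices(lsexp, B; dims=[1])[1, :]
```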
