Getting around `Zygote` mutating array issue

I’m running into the “Mutating arrays is not supported” error message and am looking for ideas/help. In particular, I am trying to compute gradients of the following metric, which measures diversity across a multi-dimensional array $\mathbf{X}_{(M \times N \times D)}$ via determinantal point processes:

$$\text{ddp\_diversity}(\mathbf{X}) = \det(K)$$

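where $K_{i,j} = \left(1 + \operatorname{dist}(x_i, x_j)\right)^{-1}$, $x_i$ is the $i$-th slice of $\mathbf{X}$ along its third dimension, and $K$ is $D \times D$.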
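As a quick sanity check of why $\det(K)$ rewards diversity, take the smallest case $D = 2$ with the two slices at distance $r = \operatorname{dist}(x_1, x_2)$, assuming $\operatorname{dist}(x, x) = 0$ so the diagonal entries are $(1 + 0)^{-1} = 1$:

$$K = \begin{pmatrix} 1 & (1+r)^{-1} \\ (1+r)^{-1} & 1 \end{pmatrix}, \qquad \det(K) = 1 - (1+r)^{-2}$$

So the metric is 0 when the two slices coincide and approaches 1 as they move apart; for example, $r = 3$ gives $1 - 1/16 = 0.9375$.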

The code below implements the metric:

using LinearAlgebra
function ddp_diversity(X::AbstractArray)
    D = ndims(X)
    @assert D == 3
    n_samples = size(X,D)
    K = zeros(n_samples, n_samples)
    for i ∈ 1:n_samples
        for j ∈ 1:n_samples
            sᵢ = selectdim(X,D,i)
            sⱼ = selectdim(X,D,j)
            K[i,j] += 1/(1 + norm(sᵢ .- sⱼ))  # setindex! on K: this in-place write is what Zygote rejects
        end
    end
    return det(K)
end

Trying to compute the gradient with respect to $\mathbf{X}$ causes the error:

julia> using Zygote: gradient

julia> function f_errors(X)
           @assert ndims(X)==3
           term1 = norm(X)
           term2 = ddp_diversity(X)
           return term1 + term2
       end
f_errors (generic function with 1 method)

julia> gradient(() -> f_errors(X), Flux.params(X))[X]
ERROR: Mutating arrays is not supported -- called setindex!(::Matrix{Float64}, _...)

I’ve seen here that the @ignore macro can be used to ignore parts that involve array mutation. But I’m still struggling to get that to work. I can do the following:

using Zygote: gradient, @ignore
function f_ignore_term2(X)
    @assert ndims(X)==3
    term1 = norm(X)
    term2 = @ignore ddp_diversity(X)
    return term1 + term2
end

gs_ignored = gradient(() -> f_ignore_term2(X), Flux.params(X))[X]

But that corresponds to just not including ddp_diversity at all:

julia> function f_without_term2(X)
           @assert ndims(X)==3
           term1 = norm(X)
           return term1
       end
f_without_term2 (generic function with 1 method)

julia> gs_without = gradient(() -> f_without_term2(X), Flux.params(X))[X];

julia> all(gs_ignored .== gs_without)
true

Any ideas which part exactly I can ignore and still compute a gradient with respect to the diversity metric? Or any ideas on how to construct the measure without array mutation?

Thanks!

It looks like a relatively simple function in the grand scheme of things. Any hope of deriving the custom adjoint rule by hand? The backprop for each element of K looks fairly separable, which often makes the adjoints for the interior computation of K easy to write, as long as you either call out to, or copy/paste, the existing adjoint for det(K) from https://github.com/JuliaDiff/ChainRules.jl/blob/f13e0a45d10bb13f48d6208e9c9d5b4a52b96732/src/rulesets/LinearAlgebra/dense.jl#L129-L136
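A hand derivation along those lines might look as follows. This is a sketch assuming $\operatorname{dist}$ is the Euclidean (Frobenius) norm used in the code and that all slices are distinct, so $r_{ij} > 0$ for $i \ne j$. With $\Omega = \det(K)$, upstream gradient $\bar{\Omega}$, $r_{ij} = \lVert x_i - x_j \rVert$ and $K_{ij} = (1 + r_{ij})^{-1}$:

$$\bar{K} = \bar{\Omega}\,\det(K)\,K^{-\top} \qquad \text{(Jacobi's formula)}$$

$$\frac{\partial K_{ij}}{\partial r_{ij}} = -K_{ij}^2, \qquad \frac{\partial r_{ij}}{\partial x_i} = \frac{x_i - x_j}{r_{ij}}$$

$$\bar{x}_i = -\sum_{j \ne i} \left(\bar{K}_{ij} + \bar{K}_{ji}\right) K_{ij}^2\, \frac{x_i - x_j}{r_{ij}}$$

The $\bar{K}_{ji}$ term appears because $x_i$ enters both row $i$ and column $i$ of $K$; the diagonal entries are constant ($K_{ii} = 1$) and contribute nothing.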

If performance is important, you may want to rewrite the det(K) part anyway, so that you can store a matrix factorization and reuse it in the backprop rule.

Basically, any time I have hit mutation with Zygote, it has made sense to write a custom rule. And when speed is required, this isn’t just a Zygote thing: you would probably want to do the same with other AD systems.
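A minimal sketch of such a rule, assuming ChainRulesCore.jl (Zygote picks up its `rrule`s) and that all slices along the third dimension are distinct so that K is nonsingular. `dpp_kernel` and `ddp_diversity_rrule` are hypothetical names. The pullback applies Jacobi's formula, ∂det(K)/∂K = det(K) * inv(K)', reusing a stored LU factorization, and then chains through K[i,j] = 1/(1 + norm(xᵢ - xⱼ)). Mutating the gradient buffer inside the pullback is fine, because Zygote does not differentiate through the pullback itself (unless you take second derivatives).

```julia
using LinearAlgebra
using ChainRulesCore

# Kernel matrix K[i, j] = 1 / (1 + norm(xᵢ - xⱼ)), xᵢ = i-th slice along dim 3.
function dpp_kernel(X::AbstractArray{<:Real,3})
    xs = eachslice(X; dims = 3)
    return [1 / (1 + norm(x .- y)) for x in xs, y in xs]
end

# Hypothetical rule-backed version of the metric from the question.
ddp_diversity_rrule(X::AbstractArray{<:Real,3}) = det(dpp_kernel(X))

function ChainRulesCore.rrule(::typeof(ddp_diversity_rrule), X::AbstractArray{<:Real,3})
    K = dpp_kernel(X)
    F = lu(K)          # factorize once; assumes K is nonsingular (distinct slices)
    Ω = det(F)
    function ddp_diversity_pullback(ΔΩ)
        Δ = unthunk(ΔΩ)
        # Jacobi's formula: ∂det(K)/∂K = det(K) * inv(K)'
        Kbar = (Δ * Ω) * inv(F)'
        Xbar = zeros(eltype(Kbar), size(X))
        n = size(X, 3)
        for i in 1:n, j in 1:n
            i == j && continue  # diagonal entries are constant (K[i, i] == 1)
            d = selectdim(X, 3, i) .- selectdim(X, 3, j)
            r = norm(d)
            # dK[i,j]/dr = -K[i,j]^2, dr/dxᵢ = d / r; K[j,i] adds the same term
            c = -(Kbar[i, j] + Kbar[j, i]) * K[i, j]^2 / r
            @views Xbar[:, :, i] .+= c .* d
        end
        return NoTangent(), Xbar
    end
    return Ω, ddp_diversity_pullback
end
```

With Zygote loaded, `gradient(ddp_diversity_rrule, X)` should dispatch to this rule; comparing it against central finite differences on a small random `X` is a cheap sanity check.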


Using an array comprehension works for me, i.e.,

function ddp_diversity(X::AbstractArray)
    D = ndims(X)
    @assert D == 3
    n_samples = size(X,D)
    K = [1/(1 + norm(selectdim(X,D,i) .- selectdim(X,D,j)))
         for i ∈ 1:n_samples, j ∈ 1:n_samples]
    return det(K)
end

I find that more readable anyway.


Using eachslice will often be more efficient than explicitly indexing:

julia> function ddp_diversity_3(X::AbstractArray{<:Real, 3})
           xs = eachslice(X, dims = ndims(X))
           K = [1/(1 + norm(x .- y)) for x in xs, y in xs]
           return det(K)
       end
ddp_diversity_3 (generic function with 1 method)

julia> x10 = rand(10,10,10);

julia> ddp_diversity(x10)  # from @bertschi above
0.3929850820872172

julia> ddp_diversity_3(x10)
0.3929850820872172

julia> @btime gradient(ddp_diversity, $x10);
  27.176 ms (61786 allocations: 6.18 MiB)

julia> @btime gradient(ddp_diversity_3, $x10);
  171.708 μs (1415 allocations: 542.86 KiB)

Wow! It’s been less than a day and not only did you provide me with a working solution, you also optimised performance – love this community. Thanks all 🤗

In case it’s of interest, this will actually help me add another methodology to a package I’ll present at JuliaCon this month: CounterfactualExplanations.jl. I’ll link this thread in the code comments and post the final implementation here once it’s done.


Woohoo 🎉 it does what it should

(image: dice_intro)
