I'm running into Zygote's "Mutating arrays is not supported" error and am looking for ideas/help. In particular, I am trying to compute gradients of the following metric, which measures diversity across a multi-dimensional array \mathbf{X}_{(M \times N \times D)} via determinantal point processes:
\text{ddp\_diversity}(\mathbf{X}) = \det(K)
where K_{i,j}=(1+dist(x_i,x_j))^{-1}, x_i is the i-th slice of \mathbf{X} along its last dimension, and K is {(D\times D)}.
The code below implements the metric:
using LinearAlgebra

function ddp_diversity(X::AbstractArray)
    D = ndims(X)
    @assert D == 3
    n_samples = size(X,D)
    K = zeros(n_samples, n_samples)
    for i ∈ 1:n_samples
        for j ∈ 1:n_samples
            sᵢ = selectdim(X,D,i)
            sⱼ = selectdim(X,D,j)
            K[i,j] += 1/(1 + norm(sᵢ .- sⱼ))
        end
    end
    return det(K)
end
Trying to compute the gradient with respect to \mathbf{X} causes the error:
julia> using Flux

julia> using Zygote: gradient

julia> function f_errors(X)
           @assert ndims(X)==3
           term1 = norm(X)
           term2 = ddp_diversity(X)
           return term1 + term2
       end
f_errors (generic function with 1 method)
julia> gradient(() -> f_errors(X), Flux.params(X))[X]
ERROR: Mutating arrays is not supported -- called setindex!(::Matrix{Float64}, _...)
I've seen that the @ignore macro can be used to exclude parts of the code that involve array mutation from differentiation, but I'm still struggling to get that to work. I can do the following:
using Zygote: gradient, @ignore

function f_ignore_term2(X)
    @assert ndims(X)==3
    term1 = norm(X)
    term2 = @ignore ddp_diversity(X)
    return term1 + term2
end
gs_ignored = gradient(() -> f_ignore_term2(X), Flux.params(X))[X]
But that is equivalent to not including ddp_diversity at all:
julia> function f_without_term2(X)
           @assert ndims(X)==3
           term1 = norm(X)
           return term1
       end
f_without_term2 (generic function with 1 method)
julia> gs_without = gradient(() -> f_without_term2(X), Flux.params(X))[X];
julia> all(gs_ignored .== gs_without)
true
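That is what one would expect if the ignored term enters the objective as a constant c (this is my reading of @ignore, not something I have checked in the Zygote source). In that case the diversity term drops out of the gradient entirely:

```latex
\nabla_{\mathbf{X}} \left( \lVert \mathbf{X} \rVert + c \right)
  = \nabla_{\mathbf{X}} \lVert \mathbf{X} \rVert
  = \frac{\mathbf{X}}{\lVert \mathbf{X} \rVert},
  \qquad \mathbf{X} \neq 0 .
```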
Any ideas which part exactly I can ignore while still getting a gradient with respect to the diversity metric? Or any ideas how to construct the measure without array mutation?
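To make the second question concrete, here is a minimal sketch of what a mutation-free version might look like, with K built by an array comprehension instead of writes into a preallocated matrix. The name ddp_diversity_nomut is made up for illustration. It should return the same determinant as ddp_diversity; whether Zygote differentiates it correctly or efficiently is an assumption I have not verified.

```julia
using LinearAlgebra

# Hypothetical non-mutating variant of ddp_diversity (illustrative name).
# K is created in one go by a comprehension, so no setindex! call is made.
function ddp_diversity_nomut(X::AbstractArray)
    @assert ndims(X) == 3
    n_samples = size(X, 3)
    K = [1 / (1 + norm(selectdim(X, 3, i) .- selectdim(X, 3, j)))
         for i in 1:n_samples, j in 1:n_samples]
    return det(K)
end

# Toy input: two 1×1 slices at distance 1 from each other,
# so K = [1 0.5; 0.5 1] and det(K) = 0.75.
X_toy = zeros(1, 1, 2)
X_toy[1, 1, 2] = 1.0
d_toy = ddp_diversity_nomut(X_toy)
```

The comprehension only changes how K is allocated and filled; the kernel entries and the final det(K) are the same as in the loop version above.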
Thanks!