I am trying to add a gradient-norm penalty to a loss function, similar to WGAN-GP ([1704.00028] Improved Training of Wasserstein GANs, Alg. 1, lines 7-9). However, I am running into problems with mutation. Is there another way to do this, or does something have to change in Zygote to accommodate this use case?
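For reference, the penalty term from the paper is λ (‖∇_x̂ D(x̂)‖₂ − 1)², i.e. the norm of the critic's gradient with respect to its input, which must itself be differentiated with respect to the critic's parameters during training. A minimal sketch of what I would like to write (gradient_penalty, critic, and λ are illustrative names of mine, not from the paper):

using Flux, LinearAlgebra

function gradient_penalty(critic, x̂; λ = 10f0)
    # Inner, first-order gradient of the critic's output w.r.t. its input.
    gx = gradient(x -> sum(critic(x)), x̂)[1]
    # This penalty then needs to be differentiated w.r.t. the critic's
    # parameters, which is the second-order step that fails for me.
    return λ * (norm(gx) - 1)^2
end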
The following code snippet,
using Flux, LinearAlgebra

m = Chain(Dense(100, 50, relu), Dense(50, 2), softmax);
opt = Descent(0.01);
data, labels = rand(Float32, 100, 100), rand(0:1, 100);
labels = permutedims(hcat(labels, 1 .- labels));  # 2×100 targets (reshape would scramble the entries)
loss(m, x, y) = sum(Flux.crossentropy(m(x), y));

# Outer gradient w.r.t. the model parameters, taken through the norm of the inner gradient.
function get_grad(m, data, labels, ps)
    gs = gradient(ps) do
        l = sum(LinearAlgebra.norm, get_grad_inner(m, data, labels, Flux.params(data)))
    end
end

# Inner gradient: the loss gradient w.r.t. the input data, as in WGAN-GP.
function get_grad_inner(m, data, labels, ps)
    gs = gradient(ps) do
        l = loss(m, data, labels)
    end
end

g1 = get_grad(m, data, labels, Flux.params(m))
results in the following error:
ERROR: Mutating arrays is not supported -- called copyto!(::Matrix{Float32}, _...)
Stacktrace:
...
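One direction I have considered, though I do not know whether it avoids the nested-AD limitation, is taking the inner gradient with respect to the input as an explicit argument instead of through Flux.params(data). A sketch, reusing loss from above (loss_with_penalty and λ are placeholder names of mine):

function loss_with_penalty(m, x, y; λ = 1f0)
    # Inner gradient w.r.t. the input, taken with an explicit argument
    # rather than implicit Params.
    gx = gradient(z -> loss(m, z, y), x)[1]
    return loss(m, x, y) + λ * (norm(gx) - 1)^2
end

gs = gradient(() -> loss_with_penalty(m, data, labels), Flux.params(m))

Whether Zygote can take the outer gradient through that inner call is exactly what I am unsure about.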