I am trying to implement a Wasserstein GAN with gradient penalty, following Gulrajani et al., "Improved Training of Wasserstein GANs". Part of the code is below. The discriminator (critic) loss contains a term built from the gradient of the discriminator with respect to its inputs, and this nested differentiation creates a problem when differentiating the discriminator loss itself that I cannot solve.
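For reference, the critic loss from the paper (their eq. 3) is

$$ L = \mathbb{E}_{\tilde{x}\sim\mathbb{P}_g}[D(\tilde{x})] - \mathbb{E}_{x\sim\mathbb{P}_r}[D(x)] + \lambda\,\mathbb{E}_{\hat{x}\sim\mathbb{P}_{\hat{x}}}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big] $$

where x̂ is sampled along straight lines between real and generated points; the last term is the gradient penalty, which is what forces me to differentiate D with respect to its inputs inside the loss.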
using CUDA, Flux, Distributions, CSV, DataFrames, ForwardDiff, StatsBase, GLM, Random, Statistics
using Flux.Losses: logitbinarycrossentropy
using Parameters: @with_kw
using Flux.Optimise: update!
using Base.Iterators: partition
using Flux: params
using ReverseDiff
using Distances
const ε = Float32(1e-6)   # mixing weight for the gradient-penalty interpolation (the paper draws ε ~ U[0,1] per sample)
# Critic: maps a 14-feature input to a scalar score
function Discriminator()
    return Chain(
        Dense(14, 21, elu),
        Dense(21, 1, sigmoid))
end
# Gradient of the discriminator output with respect to the inputs,
# computed column by column over the batch Z
function Grad_discriminator_x(dscr, Z)
    f = z -> dscr(z)[1]
    g = z -> ReverseDiff.gradient(f, z)
    return mapslices(g, Z; dims=1)   # one 14-vector gradient per column
end
function discriminator_loss(dscr, real_input, fake_input)
    λ = 10.0
    real_loss = mean(dscr(real_input))
    fake_loss = mean(dscr(fake_input))
    # interpolation between real and fake samples
    mix = ε * real_input + (1 - ε) * fake_input
    x = Grad_discriminator_x(dscr, mix)          # 14×batch matrix of input-gradients
    norms = colwise(Euclidean(), x, zeros(14))   # ‖∇x̂ D(x̂)‖₂ per column
    penalty = λ * mean((norms .- 1) .^ 2)        # gradient penalty
    return fake_loss - real_loss + penalty
end
function Grad_discriminator_loss(dscr, real_input, fake_input)
    gradient(() -> discriminator_loss(dscr, real_input, fake_input), Flux.params(dscr))
end
real_input = randn(Float64, 14, 5);
fake_input = randn(Float64, 14, 5);
Grad_discriminator_x(Discriminator(), randn(Float64, 14, 5))      # works
discriminator_loss(Discriminator(), real_input, fake_input)       # works
Grad_discriminator_loss(Discriminator(), real_input, fake_input)  # fails, see below
Calculating the gradient of the discriminator with respect to the inputs, or evaluating the discriminator loss itself, is not a problem. However, when I try to evaluate Grad_discriminator_loss I get the following error. Any help would be greatly appreciated.
Grad_discriminator_loss(Discriminator(), real_input, fake_input)
ERROR: Mutating arrays is not supported
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] (::Zygote.var"#403#404")(#unused#::Nothing)
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/lib/array.jl:58
[3] (::Zygote.var"#2259#back#405"{Zygote.var"#403#404"})(Δ::Nothing)
@ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[4] Pullback
@ ~/.julia/packages/Distances/gnt89/src/generic.jl:83 [inlined]
[5] (::typeof(∂(colwise!)))(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
[6] Pullback
@ ~/.julia/packages/Distances/gnt89/src/generic.jl:163 [inlined]
[7] (::typeof(∂(colwise)))(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
[8] Pullback
@ ./REPL[11]:8 [inlined]
[9] (::typeof(∂(discriminator_loss)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
[10] Pullback
@ ./REPL[12]:2 [inlined]
[11] (::typeof(∂(λ)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
[12] (::Zygote.var"#69#70"{Zygote.Params, typeof(∂(λ)), Zygote.Context})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface.jl:255
[13] gradient(f::Function, args::Zygote.Params)
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface.jl:59
[14] Grad_discriminator_loss(dscr::Chain{Tuple{Dense{typeof(elu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(σ), Matrix{Float32}, Vector{Float32}}}}, real_input::Matrix{Float64}, fake_input::Matrix{Float64})
@ Main ./REPL[12]:2
[15] top-level scope
@ REPL[18]:1