I am trying to build a GAN in which the generator is defined in a fully parametric way.
using Flux, StatsBase
For simplicity, the discriminator is a very small neural network:
function Discriminator()
return Chain(
Dense(10 , 1, sigmoid)
)
end
The parameters are a mean μ (a 10-dimensional vector) and a standard deviation σ. They are stored in a dictionary:
p = Dict(:μ => ones(Float32,10), :σ => 1);
I define a struct that holds the parameters:
struct B
μ
σ
end
and define the model z -> (z+μ)σ for an input z:
z=randn(Float32,10)
m(μ,σ,z)= (z+μ)σ
(b::B)(z)=m(b.μ,b.σ,z)
G=B(ones(Float32,10),1)
I can now generate 10-dimensional vectors (z+μ)σ:
G(randn(Float32,10))
If I define a function of G such as sum(G(z).^2), Flux can compute its gradient:
dG = gradient(G -> sum(G(z).^2), G)[1]
or equivalently I could do the following:
g = G -> sum(G(z).^2)
gradient(g,G)
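As a check on what `gradient` returns here: differentiating with respect to the struct `G` gives a `NamedTuple` with one entry per field, `μ` and `σ`. For f = Σᵢ((zᵢ+μᵢ)σ)², the closed form is ∂f/∂μ = 2σ²(z+μ) and ∂f/∂σ = 2σ Σᵢ(zᵢ+μᵢ)². The sketch below compares the two. It assumes Flux's re-exported `gradient` (from Zygote) and sets σ to the Float32 `1f0` instead of the integer `1`, a choice made only for this sketch so that all parameters share one element type.

```julia
using Flux  # re-exports Zygote's `gradient`

struct B
    μ
    σ
end
m(μ, σ, z) = (z + μ)σ
(b::B)(z) = m(b.μ, b.σ, z)

z = randn(Float32, 10)
G = B(ones(Float32, 10), 1f0)  # σ as Float32 (assumption for this sketch)

# Explicit-mode gradient with respect to the struct: a NamedTuple (μ = ..., σ = ...)
dG = gradient(G -> sum(G(z).^2), G)[1]

# Closed-form gradients of f(μ, σ) = sum(((z + μ)σ).^2)
dμ_exact = 2 * G.σ^2 .* (z .+ G.μ)
dσ_exact = 2 * G.σ * sum((z .+ G.μ).^2)

println(isapprox(dG.μ, dμ_exact; rtol = 1e-4))  # true
println(isapprox(dG.σ, dσ_exact; rtol = 1e-4))  # true
```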
I generate a random 10×100 matrix Z, which I want to use to generate a fake sample of 100 vectors of dimension 10, arranged as a 10×100 matrix,
Z=randn(Float32,10,100);
something like this:
fake_data = hcat(G.(eachcol(Z))...)
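Splatting 100 columns into `hcat` works, but because adding the length-10 vector μ to a 10×100 matrix broadcasts it across the columns, the same batch can also be produced with a single broadcast. A minimal sketch; `generate_batch` is a hypothetical helper introduced here, not a Flux function:

```julia
struct B
    μ
    σ
end
m(μ, σ, z) = (z + μ)σ
(b::B)(z) = m(b.μ, b.σ, z)

G = B(ones(Float32, 10), 1f0)
Z = randn(Float32, 10, 100)

# Column-by-column construction: apply G to each column, then splat into hcat
fake_splat = hcat(G.(eachcol(Z))...)

# Hypothetical batched helper: μ (length 10) broadcasts across the 100 columns of Z
generate_batch(b::B, Z::AbstractMatrix) = (Z .+ b.μ) .* b.σ
fake_batch = generate_batch(G, Z)

println(size(fake_batch))         # (10, 100)
println(fake_splat ≈ fake_batch)  # true
```

The batched form performs the same elementwise operations as the column-wise version, so the two matrices agree; it simply avoids splatting a 100-element collection into `hcat`.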
If I take a function of this generated matrix, such as its mean, the gradient is well defined:
mean(hcat(G.(eachcol(Z))...))
dG = gradient(G -> mean(hcat(G.(eachcol(Z))...)), G)[1]
or alternatively
g = G -> mean(hcat(G.(eachcol(Z))...))
gradient(g,G)
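Each entry of the 10×100 output is (Zᵢⱼ+μᵢ)σ, so the gradient of its mean has a simple closed form: each μᵢ appears in 100 of the 1000 averaged entries, giving ∂/∂μᵢ = σ/10, and ∂/∂σ = mean(Z .+ μ). A sketch checking Zygote's result against this, assuming Flux's re-exported `gradient` and σ = `1f0` (a Float32 chosen for this sketch):

```julia
using Flux        # provides `gradient` (from Zygote)
using Statistics  # mean

struct B
    μ
    σ
end
m(μ, σ, z) = (z + μ)σ
(b::B)(z) = m(b.μ, b.σ, z)

G = B(ones(Float32, 10), 1f0)
Z = randn(Float32, 10, 100)

# Same formulation as in the post: generate column by column, then average
dG = gradient(G -> mean(hcat(G.(eachcol(Z))...)), G)[1]

# Closed form: each μᵢ enters 100 of the 1000 averaged entries, scaled by σ
dμ_exact = fill(G.σ / 10, 10)
dσ_exact = mean(Z .+ G.μ)

println(isapprox(dG.μ, dμ_exact; rtol = 1e-4))  # true
println(isapprox(dG.σ, dσ_exact; rtol = 1e-4))  # true
```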
All of the above works well. However, suppose that I feed my 100 observations to the discriminator and then take the mean. This function is also well defined:
mean(Discriminator()(hcat(G.(eachcol(Z))...)))
The gradient of such a function should exist (this step is done routinely when training GANs, with G replaced by a neural network).
However, computing the gradient throws an error:
dG = gradient(G -> mean(Discriminator()(hcat(G.(eachcol(Z))...))), G)[1]
The same error occurs if I use the equivalent formulation
g = G -> mean(Discriminator()(hcat(G.(eachcol(Z))...)))
gradient(g,G)
The error I get is:
julia> dG = gradient(G -> mean(Discriminator()(hcat(G.(eachcol(Z))...))), G)[1]
ERROR: LoadError: Mutating arrays is not supported -- called setindex!(::Vector{Float32}, _...)
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] (::Zygote.var"#437#438"{Vector{Float32}})(#unused#::Nothing)
@ Zygote ~/.julia/packages/Zygote/kUefI/src/lib/array.jl:71
[3] (::Zygote.var"#2320#back#439"{Zygote.var"#437#438"{Vector{Float32}}})(Δ::Nothing)
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[4] Pullback
@ ./array.jl:335 [inlined]
[5] (::typeof(∂(fill!)))(Δ::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/kUefI/src/compiler/interface2.jl:0
[6] Pullback
@ ~/.julia/packages/Flux/BPPNj/src/utils.jl:385 [inlined]
[7] (::typeof(∂(create_bias)))(Δ::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/kUefI/src/compiler/interface2.jl:0
[8] Pullback
@ ~/.julia/packages/Flux/BPPNj/src/layers/basic.jl:128 [inlined]
[9] (::typeof(∂(Dense)))(Δ::NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, Vector{Float32}, Nothing}})
@ Zygote ~/.julia/packages/Zygote/kUefI/src/compiler/interface2.jl:0
[10] Pullback
@ ~/.julia/packages/Flux/BPPNj/src/layers/basic.jl:151 [inlined]
[11] (::typeof(∂(#Dense#154)))(Δ::NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, Vector{Float32}, Nothing}})
@ Zygote ~/.julia/packages/Zygote/kUefI/src/compiler/interface2.jl:0
[12] Pullback
@ ~/.julia/packages/Flux/BPPNj/src/layers/basic.jl:137 [inlined]
[13] Pullback
@ ./Untitled-1:5 [inlined]
[14] (::typeof(∂(Discriminator)))(Δ::NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, Vector{Float32}, Nothing}}}}})
@ Zygote ~/.julia/packages/Zygote/kUefI/src/compiler/interface2.jl:0
[15] Pullback
@ ./Untitled-1:71 [inlined]
[16] (::typeof(∂(#13)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/kUefI/src/compiler/interface2.jl:0
[17] (::Zygote.var"#55#56"{typeof(∂(#13))})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/kUefI/src/compiler/interface.jl:41
[18] gradient(f::Function, args::B)
@ Zygote ~/.julia/packages/Zygote/kUefI/src/compiler/interface.jl:76
[19] top-level scope
@ Untitled-1:71
in expression starting at Untitled-1:71
I suspect the problem is caused by the discriminator. I wonder whether there is a way to work around the array mutation in this situation. I have read about similar problems with mutating arrays, and I understand what causes them when I write the function myself. In this case, however, the discriminator is defined through Flux, which I do not control. Any help or suggestions would be greatly appreciated.
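The stack trace points at the cause: the pullback runs through `∂(Discriminator)`, `∂(Dense)` and `∂(create_bias)`, so Zygote is differentiating the *construction* of the layer, because `Discriminator()` is called inside the function being differentiated. `create_bias` fills a freshly allocated array with `fill!`, which mutates it, and that is what Zygote rejects. A sketch of the usual workaround is to build the discriminator once, outside `gradient`, and only *call* it inside. It uses the `Dense(10 => 1, sigmoid)` pair syntax of recent Flux versions (the post's `Dense(10, 1, sigmoid)` is the older spelling) and σ = `1f0`, both assumptions of this sketch. For an actual GAN step, where the discriminator's weights also need gradients, `D` can be passed as a second argument, since explicit-mode `gradient` returns one result per argument.

```julia
using Flux
using Statistics

struct B
    μ
    σ
end
m(μ, σ, z) = (z + μ)σ
(b::B)(z) = m(b.μ, b.σ, z)

Discriminator() = Chain(Dense(10 => 1, sigmoid))

G = B(ones(Float32, 10), 1f0)
Z = randn(Float32, 10, 100)

# Build the layer once, outside the differentiated function,
# so Zygote never traces Dense's constructor (and its fill!).
D = Discriminator()

# Gradient with respect to the generator parameters only
dG = gradient(G -> mean(D(hcat(G.(eachcol(Z))...))), G)[1]

# Gradients with respect to both generator and discriminator,
# as needed in a GAN training step (one result per argument)
dG2, dD = gradient((G, D) -> mean(D(hcat(G.(eachcol(Z))...))), G, D)

println(size(dG.μ))                   # (10,)
println(size(dD.layers[1].weight))    # (1, 10)
```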