Hi, I am trying to implement the following paper: [1906.07748] Joint Learning of Geometric and Probabilistic Constellation Shaping.
I already have a working version in TensorFlow: https://github.com/Rassibassi/claude/blob/master/examples/tf_AutoEncoderForProbabilisticShapingAndAwgn.ipynb
In Julia, however, the gradient somehow becomes complex.
Here is my implementation and the error message, which I cannot decipher. I cropped the error message, as it was too long to post in full.
using Flux, Zygote
using Flux: @nograd, @adjoint, @epochs, onehotbatch, throttle
using Distributions
using Distributions: Gumbel
using Statistics: mean, var, std
using Plots
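# add_dim: prepend a singleton dimension (a length-M vector becomes a 1×M row matrix)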
add_dim(x::Array) = reshape(x, (1,size(x)...))
function qammod(M)
    r = 1:sqrt(M)
    r = 2 * (r .- mean(r))
    r = [i for i in r, j in r]
    constellation = vcat(complex.(r, r')...)
    norm = sqrt(mean(abs2.(r)))
    constellation / norm
end
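# p_norm: expectation of fun(x) weighted by the probability vector p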
p_norm(p, x, fun) = sum(p .* fun(x))
function straight_through_estimator(x, M)
    min_idx = argmin.(eachcol(x))
    Flux.onehotbatch(min_idx, 1:M)
end
# straight-through estimator: forward pass returns the hard one-hot decision, backward pass is the identity
@adjoint straight_through_estimator(x, M) = straight_through_estimator(x, M), identity
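# hyperparameters: element types, constellation size M, batch size N, Gumbel-softmax temperature, channel SNR in dB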
TR = Float32
TC = ComplexF32
M = 16
constellation_dim = 2
N = 4 * M
temperature = 1.0
SNR = 20
SNRlin = 10^(SNR/10) |> TR
constellation = qammod(M)
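# encoder: constant input -> M logits for the shaping distribution over constellation points
# decoder: 2-D received symbol -> M logits for symbol detection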
encoder = Chain(Dense(1, 128, Flux.relu), Dense(128, M));
decoder = Chain(Dense(constellation_dim, 128, Flux.relu), Dense(128, 128, Flux.relu), Dense(128, M));
function gumbel_sample(M, N)
    g_dist = Gumbel()
    g = rand(g_dist, (M, N))
    return g
end
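# the Gumbel noise does not depend on the network parameters, so exclude it from gradient tracking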
@nograd gumbel_sample
function model()
    # sample from discrete distribution
    s_logits = encoder(ones(1, N))
    g_dist = Gumbel()
    g = gumbel_sample(M, N)
    s_bar = Flux.softmax((g + s_logits) / temperature)
    s = straight_through_estimator(s_bar, M)
    p_s = Flux.softmax(s_logits)
    norm_factor = sqrt(p_norm(p_s, constellation, x -> abs2.(x)))
    norm_constellation = constellation / norm_factor
    x = add_dim(norm_constellation) * s
    # Channel
    š = sqrt(1 / SNRlin) |> TR
    r = x + š * randn(TC, 1, N)
    r = [real(r); imag(r)]
    # decoder
    Y = decoder(r)
    return p_s, s, Y
end
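# loss: logit cross-entropy between decoder output Y and the one-hot symbols s,
# plus sum(p_s .* log2.(p_s)), the negative entropy of the shaping distribution (which is why the total can be negative)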
function loss()
    p_s, s, Y = model()
    loss = Flux.logitcrossentropy(Y, s)
    return loss + p_norm(p_s, p_s, x -> log2.(x))
end
@show loss()
ps = Flux.params(encoder, decoder);
gs = gradient(ps) do
    loss()
end
@show gs[encoder[1].W]
Output:
loss() = -61.019833f0
ERROR: LoadError: Gradient Complex{Float64}[0.17417264688378736 + 0.001395718790090191im 0.17262654417874682 + 0.004565566384023881im 0.17441725737440028 [...] ] should be a tuple
Stacktrace:
[1] gradtuple1(::Array{Complex{Float64},2}) at /home/rasmus/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:14
[2] (::var"#240#back#11"{typeof(identity)})(::Array{Complex{Float64},2}) at /home/rasmus/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[3] model at /home/rasmus/julia/jjaude/ProbShaping.jl:55 [inlined]
[4] [...]
[5] loss at /home/rasmus/julia/jjaude/ProbShaping.jl:72 [inlined]
[6] [...]
[7] #19 at /home/rasmus/julia/jjaude/ProbShaping.jl:81 [inlined]
[8] [...]
[9] gradient(::Function, ::Params) at /home/rasmus/.julia/packages/Zygote/8dVxG/src/compiler/interface.jl:47
[10] top-level scope at /home/rasmus/julia/jjaude/ProbShaping.jl:80
in expression starting at /home/rasmus/julia/jjaude/ProbShaping.jl:80
Thanks,
Rasmus