# Probabilistic shaping with flux (Telecommunication)

Hi, trying to implement the following paper: https://arxiv.org/abs/1906.07748

already got a working version in TF: https://github.com/Rassibassi/claude/blob/master/examples/tf_AutoEncoderForProbabilisticShapingAndAwgn.ipynb

In Julia, somehow the gradient becomes complex

Here my implementation and the error message I cannot decipher. I cropped the error message, as it was too long to post it here.

``````using Flux, Zygote
using Distributions
using Distributions: Gumbel
using Statistics: mean, var, std
using Plots

function qammod(M)
r = 1:sqrt(M)
r = 2 * (r .- mean(r))
r = [i for i in r, j in r]
constellation = vcat(complex.(r, r')...)
norm = sqrt(mean(abs2.(r)))
constellation / norm
end

p_norm(p, x, fun) = sum(p .* fun(x))

function straight_through_estimator(x, M)
min_idx = argmin.(eachcol(x))
Flux.onehotbatch(min_idx,1:M)
end
# forward function, backward identity
@adjoint straight_through_estimator(x, M) = straight_through_estimator(x, M), identity

TR = Float32
TC = ComplexF32

M = 16
constellation_dim = 2
N = 4 * M
temperature = 1.
SNR = 20
SNRlin = 10^(SNR/10) |> TR

constellation = qammod(M)
encoder = Chain(Dense(1, 128, Flux.relu), Dense(128, M));
decoder = Chain(Dense(constellation_dim, 128, Flux.relu), Dense(128, 128, Flux.relu), Dense(128, M));

function gumbel_sample(M, N)
g_dist = Gumbel()
g = rand(g_dist, (M, N))
return g
end

function model()
# sample from discrete distribution
s_logits = encoder(ones(1, N))
g_dist = Gumbel()
g = gumbel_sample(M, N)
s_bar = Flux.softmax((g + s_logits) / temperature)
s = straight_through_estimator(s_bar, M)
p_s = Flux.softmax(s_logits)
norm_factor = sqrt(p_norm(p_s, constellation, x -> abs2.(x)))
norm_constellation = constellation / norm_factor

# Channel
š = sqrt(1 / SNRlin) |> TR
r = x + š * randn(TC, 1, N)
r = [real(r); imag(r)]

# decoder
Y = decoder(r)
return p_s, s, Y
end

function loss()
p_s, s, Y = model()
loss = Flux.logitcrossentropy(Y, s)
return loss + p_norm(p_s, p_s, x -> log2.(x))
end

@show loss()

ps = Flux.params(encoder, decoder);
loss()
end

@show gs[encoder[1].W]
``````

Output:

``````loss() = -61.019833f0
ERROR: LoadError: Gradient Complex{Float64}[0.17417264688378736 + 0.001395718790090191im 0.17262654417874682 + 0.004565566384023881im 0.17441725737440028 [...] ] should be a tuple
Stacktrace:
[3] model at /home/rasmus/julia/jjaude/ProbShaping.jl:55 [inlined]
[4] [...]
[5] loss at /home/rasmus/julia/jjaude/ProbShaping.jl:72 [inlined]
[6] [...]
[7] #19 at /home/rasmus/julia/jjaude/ProbShaping.jl:81 [inlined]
[8] [...]
[10] top-level scope at /home/rasmus/julia/jjaude/ProbShaping.jl:80
in expression starting at /home/rasmus/julia/jjaude/ProbShaping.jl:80
``````

Thanks,
Rasmus

Iām a step closer, it runs now. But it seems like the encoder is not optimizing

``````using Flux, Zygote
using Distributions
using Distributions: Gumbel
using Statistics: mean
using Plots

function qammod(M)
r = 1:sqrt(M)
r = 2 * (r .- mean(r))
r = [i for i in r, j in r]
constellation = vcat(complex.(r, r')...)
norm = sqrt(mean(abs2.(r)))
constellation / norm
end

p_norm(p, x, fun) = sum(p .* fun(x))

function straight_through_estimator(x)
M = size(x)[1]
min_idx = argmin.(eachcol(x))
Flux.onehotbatch(min_idx, 1:M)
end
# forward function, backward identity
@adjoint straight_through_estimator(x) = straight_through_estimator(x), x -> (x,)

TR = Float32
TC = ComplexF32

M = 64
constellation_dim = 2
N = 32 * M
temperature = 1.
SNR = 15
SNRlin = 10^(SNR/10) |> TR

nHidden = 128

constellation = qammod(M)
g_dist = Gumbel()

encoder = Chain(Dense(1, nHidden, Flux.relu),
Dense(nHidden, M));
decoder = Chain(Dense(constellation_dim, nHidden, Flux.relu),
Dense(nHidden, nHidden, Flux.relu),
Dense(nHidden, M));

function gumbel_sample(M, N)
rand(g_dist, (M, N))
end

function model(X)
# sample from discrete distribution
s_logits = encoder(X)
g = gumbel_sample(M, N)
s_bar = Flux.softmax((g .+ s_logits) / temperature)
s = straight_through_estimator(s_bar)
p_s = Flux.softmax(s_logits)

# modulation
norm_factor = sqrt(p_norm(p_s, constellation, x -> abs2.(x)))
norm_constellation = constellation / complex.(norm_factor, 0)
x = add_dim(norm_constellation) * complex.(s, 0)

# Channel
š = sqrt(1 / SNRlin) |> TR
r = x + š * randn(TC, 1, N)
r = [real(r); imag(r)]

# decoder
Y = decoder(r)
return p_s, s, Y, x, norm_constellation
end

function loss(X)
p_s, s, Y, x, norm_constellation = model(X)
entropy_x = -p_norm(p_s, p_s, x -> log2.(x))
logit_loss - entropy_x
end

X = ones(1, 1)

@show loss(X)

ps = Flux.params(encoder, decoder);

#     loss(X)
# end
#
# @show gs[encoder[1].W]

data = [[X]]
evalcb() = @show(loss(X));
@epochs 500 Flux.train!(loss, ps, data, opt, cb = throttle(evalcb, 5));

p_s, s, Y, x, norm_constellation = model(X)

scatter(real(norm_constellation), imag(norm_constellation), aspect_ratio = :equal,
markershape = :circle,
markersize = 500*p_s,
markerstrokealpha = 0)
lim_ = 1.5
ylims!((-lim_,lim_))
xlims!((-lim_,lim_))

@show p_s

``````