Probabilistic shaping with flux (Telecommunication)

Hi, trying to implement the following paper: https://arxiv.org/abs/1906.07748

already got a working version in TF: https://github.com/Rassibassi/claude/blob/master/examples/tf_AutoEncoderForProbabilisticShapingAndAwgn.ipynb

In Julia, somehow the gradient becomes complex :frowning:

Here my implementation and the error message I cannot decipher. I cropped the error message, as it was too long to post it here.

using Flux, Zygote
using Flux: @nograd, @adjoint, @epochs, onehotbatch, throttle
using Distributions
using Distributions: Gumbel
using Statistics: mean, var, std
using Plots

add_dim(x::Array) = reshape(x, (1,size(x)...))

function qammod(M)
    r = 1:sqrt(M)
    r = 2 * (r .- mean(r))
    r = [i for i in r, j in r]
    constellation = vcat(complex.(r, r')...)
    norm = sqrt(mean(abs2.(r)))
    constellation / norm
end

p_norm(p, x, fun) = sum(p .* fun(x))

function straight_through_estimator(x, M)
    min_idx = argmin.(eachcol(x))
    Flux.onehotbatch(min_idx,1:M)
end
# forward function, backward identity
@adjoint straight_through_estimator(x, M) = straight_through_estimator(x, M), identity

TR = Float32
TC = ComplexF32

M = 16
constellation_dim = 2
N = 4 * M
temperature = 1.
SNR = 20
SNRlin = 10^(SNR/10) |> TR

constellation = qammod(M)
encoder = Chain(Dense(1, 128, Flux.relu), Dense(128, M));
decoder = Chain(Dense(constellation_dim, 128, Flux.relu), Dense(128, 128, Flux.relu), Dense(128, M));

function gumbel_sample(M, N)
    g_dist = Gumbel()
    g = rand(g_dist, (M, N))
    return g
end
@nograd gumbel_sample

function model()
    # sample from discrete distribution
    s_logits = encoder(ones(1, N))
    g_dist = Gumbel()
    g = gumbel_sample(M, N)
    s_bar = Flux.softmax((g + s_logits) / temperature)
    s = straight_through_estimator(s_bar, M)
    p_s = Flux.softmax(s_logits)
    norm_factor = sqrt(p_norm(p_s, constellation, x -> abs2.(x)))
    norm_constellation = constellation / norm_factor
    x = add_dim(norm_constellation) * s

    # Channel
    šœŽ = sqrt(1 / SNRlin) |> TR
    r = x + šœŽ * randn(TC, 1, N)
    r = [real(r); imag(r)]

    # decoder
    Y = decoder(r)
    return p_s, s, Y
end

function loss()
    p_s, s, Y = model()
    loss = Flux.logitcrossentropy(Y, s)
    return loss + p_norm(p_s, p_s, x -> log2.(x))
end

@show loss()

ps = Flux.params(encoder, decoder);
gs = gradient(ps) do
    loss()
end

@show gs[encoder[1].W]

Output:

loss() = -61.019833f0
ERROR: LoadError: Gradient Complex{Float64}[0.17417264688378736 + 0.001395718790090191im 0.17262654417874682 + 0.004565566384023881im 0.17441725737440028 [...] ] should be a tuple
Stacktrace:
 [1] gradtuple1(::Array{Complex{Float64},2}) at /home/rasmus/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:14
 [2] (::var"#240#back#11"{typeof(identity)})(::Array{Complex{Float64},2}) at /home/rasmus/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [3] model at /home/rasmus/julia/jjaude/ProbShaping.jl:55 [inlined]
 [4] [...]
 [5] loss at /home/rasmus/julia/jjaude/ProbShaping.jl:72 [inlined]
 [6] [...]
 [7] #19 at /home/rasmus/julia/jjaude/ProbShaping.jl:81 [inlined]
 [8] [...]
 [9] gradient(::Function, ::Params) at /home/rasmus/.julia/packages/Zygote/8dVxG/src/compiler/interface.jl:47
 [10] top-level scope at /home/rasmus/julia/jjaude/ProbShaping.jl:80
in expression starting at /home/rasmus/julia/jjaude/ProbShaping.jl:80

Thanks,
Rasmus

Iā€™m a step closer, it runs now. But it seems like the encoder is not optimizing :frowning:

using Flux, Zygote
using Flux: @nograd, @adjoint, @epochs, onehotbatch, throttle
using Distributions
using Distributions: Gumbel
using Statistics: mean
using Plots

add_dim(x::Array) = reshape(x, (1,size(x)...))

function qammod(M)
    r = 1:sqrt(M)
    r = 2 * (r .- mean(r))
    r = [i for i in r, j in r]
    constellation = vcat(complex.(r, r')...)
    norm = sqrt(mean(abs2.(r)))
    constellation / norm
end

p_norm(p, x, fun) = sum(p .* fun(x))

function straight_through_estimator(x)
    M = size(x)[1]
    min_idx = argmin.(eachcol(x))
    Flux.onehotbatch(min_idx, 1:M)
end
# forward function, backward identity
@adjoint straight_through_estimator(x) = straight_through_estimator(x), x -> (x,)

TR = Float32
TC = ComplexF32

M = 64
constellation_dim = 2
N = 32 * M
temperature = 1.
SNR = 15
SNRlin = 10^(SNR/10) |> TR

nHidden = 128

constellation = qammod(M)
g_dist = Gumbel()

encoder = Chain(Dense(1, nHidden, Flux.relu),
                Dense(nHidden, M));
decoder = Chain(Dense(constellation_dim, nHidden, Flux.relu),
                Dense(nHidden, nHidden, Flux.relu),
                Dense(nHidden, M));

function gumbel_sample(M, N)
    rand(g_dist, (M, N))
end

function model(X)
    # sample from discrete distribution
    s_logits = encoder(X)
    g = gumbel_sample(M, N)
    s_bar = Flux.softmax((g .+ s_logits) / temperature)
    s = straight_through_estimator(s_bar)
    p_s = Flux.softmax(s_logits)

    # modulation
    norm_factor = sqrt(p_norm(p_s, constellation, x -> abs2.(x)))
    norm_constellation = constellation / complex.(norm_factor, 0)
    x = add_dim(norm_constellation) * complex.(s, 0)

    # Channel
    šœŽ = sqrt(1 / SNRlin) |> TR
    r = x + šœŽ * randn(TC, 1, N)
    r = [real(r); imag(r)]

    # decoder
    Y = decoder(r)
    return p_s, s, Y, x, norm_constellation
end

stop_gradient(x) = x
@nograd stop_gradient

function loss(X)
    p_s, s, Y, x, norm_constellation = model(X)
    logit_loss = Flux.logitcrossentropy(Y, stop_gradient(s))
    entropy_x = -p_norm(p_s, p_s, x -> log2.(x))
    logit_loss - entropy_x
end

X = ones(1, 1)

@show loss(X)

ps = Flux.params(encoder, decoder);

# gs = gradient(ps) do
#     loss(X)
# end
#
# @show gs[encoder[1].W]

opt = ADAM(0.001);
data = [[X]]
evalcb() = @show(loss(X));
@epochs 500 Flux.train!(loss, ps, data, opt, cb = throttle(evalcb, 5));

p_s, s, Y, x, norm_constellation = model(X)

scatter(real(norm_constellation), imag(norm_constellation), aspect_ratio = :equal,
                                                            markershape = :circle,
                                                            markersize = 500*p_s,
                                                            markerstrokealpha = 0)
lim_ = 1.5
ylims!((-lim_,lim_))
xlims!((-lim_,lim_))

@show p_s