OK, I think I managed to reproduce both of these errors, and I can make the code run. I changed the code a little to make it easier to debug, but I do not understand VAEs well enough to know whether the functions I defined make sense in the context of the problem.
Without `using DistributionsAD`, I get

```
ERROR: MethodError: no method matching Irrational{:log2π}(::Int64)
```

Inspired by this, I added `using DistributionsAD`, and now I see

```
ERROR: UndefRefError: access to undefined reference
```

I think this might be a bug, but I don't understand what is happening yet.
The other puzzling thing about this code is the need for all the `Float64` conversions. They were due to `Normal(0, 1)` being defined in terms of `Float64`, because it is constructed from integer literals.
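To make the promotion visible, here is a quick REPL check of mine (not from the original code):

```julia
julia> using Distributions

julia> Normal(0, 1)                  # integer literals promote to Float64
Normal{Float64}(μ=0.0, σ=1.0)

julia> logpdf(Normal(0, 1), 1f0)     # Float64 result despite the Float32 input
-1.4189385332046727

julia> logpdf(Normal(0f0, 1f0), 1f0) # stays in Float32
-1.4189385f0
```

That is why the code below builds the distributions with `zero(T)` and `one(T)` instead of literals.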
After sorting out the few hard-coded `Float64`s, here is something that seems to work.
```julia
module FLUXVAE
using Flux
using Flux: @epochs, binarycrossentropy
using DistributionsAD
using Distributions
# dummy data: five overlapping mini-batches of five all-ones samples each
function dummy_data()
    d = ones(Float32, 796, 512, 1, 10)
    [d[:, :, :, i:i+4] for i in 1:5]
end
struct Reshape
    shape
end
Reshape(args...) = Reshape(args)
(r::Reshape)(x) = reshape(x, r.shape)
Flux.@functor Reshape ()
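# Reshape stores the target shape and reshapes its input on call; the empty
# tuple in @functor marks it as having no trainable parameters. For example
# (my own illustration), Reshape(280, :)(rand(Float32, 10, 7, 4, 5)) has
# size (280, 5).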
# convolutional encoder
function encoder()
    conv1 = Conv((14, 10), 1 => 4, relu, stride = (10, 10), pad = 4)
    pool1 = MaxPool((8, 8), stride = (4, 4), pad = 2)
    conv2 = Conv((4, 3), 4 => 4, stride = (2, 2), pad = 1)
    res = Reshape(280, :)
    Chain(conv1, pool1, conv2, res)
end
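# Shape trace for encoder() on a (796, 512, 1, N) batch (my own calculation):
# conv1 -> (80, 52, 4, N), pool1 -> (20, 13, 4, N), conv2 -> (10, 7, 4, N),
# res -> (280, N), which matches the Dense(280, 4) layers used in train!().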
# decoder; I am using the one with transposed convolutions
function decoder(; dense_decoder = false)
    if dense_decoder
        dec = Dense(4, 796 * 512, sigmoid)
        # a Chain with Reshape, so both branches build dec1 the same way
        dec1 = Chain(dec, Reshape(796, 512, 1, :))
    else
        interaction1 = Dense(4, 280) # specific to my setup
        res = Reshape(10, 7, 4, :)
        tc1 = ConvTranspose((4, 3), 4 => 4, relu, stride = (2, 2), pad = 1)
        tc2 = ConvTranspose((8, 8), 4 => 4, relu, stride = (4, 4), pad = 2)
        tc3 = ConvTranspose((14, 10), 4 => 1, sigmoid, stride = (10, 10), pad = 4)
        dec = Chain(interaction1, tc1, tc2, tc3) # the parameter-carrying layers only
        dec1 = Chain(interaction1, res, tc1, tc2, tc3)
    end
    return (dec, dec1)
end
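# Shape trace for the transposed-convolution branch on a (4, N) latent batch
# (my own calculation): interaction1 -> (280, N), res -> (10, 7, 4, N),
# tc1 -> (20, 13, 4, N), tc2 -> (80, 52, 4, N), tc3 -> (796, 512, 1, N),
# so this decoder exactly mirrors encoder() above.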
# sample from z-distribution
z(μ::T, logσ) where {T} = μ + exp(logσ) * randn(T)
z(μ, logσ, eps) = μ + exp(logσ) * eps
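# The two-argument method draws its own ϵ; the three-argument method takes ϵ
# explicitly (the reparameterisation trick), so the loss can reuse one draw
# for both the sample z and the log q(z|x) term.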
# log(p(x|z)), log(p(z)), log(q(z|x))
logp_x_z1(X, z, dec1) = -sum(binarycrossentropy.(dec1(z), X))
logp_z(z::AbstractArray{T}) where {T} = sum(logpdf.(Normal(zero(T), one(T)), z))
log_q_z_x(ϵ, logσ) = logpdf(Normal(zero(ϵ), one(ϵ)), ϵ) - logσ
# vae loss estimator
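# The closure below is a single-sample Monte Carlo estimate of the negative
# ELBO, -(log p(x|z) + log p(z) - log q(z|x)); the 1//5 factor presumably
# averages over the five-sample batches from dummy_data().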
function vae_loss(enc1, dec1, μ1, logσ1)
    function loss(X)
        h = enc1(X) # encode once and reuse
        μ, logσ = μ1(h), logσ1(h)
        ϵ = randn(eltype(X), size(logσ)) # one draw, shared by z and log q(z|x)
        z_ = z.(μ, logσ, ϵ)
        -(logp_x_z1(X, z_, dec1) + logp_z(z_) - sum(log_q_z_x.(ϵ, logσ))) * 1//5
    end
    return loss
end
# train vae1
function train!()
    enc1 = encoder()
    dec, dec1 = decoder()
    # mean and log-standard-deviation of vae1's z-variable/latent space
    μ1 = Dense(280, 4)
    logσ1 = Dense(280, 4)
    L1 = vae_loss(enc1, dec1, μ1, logσ1)
    ps1 = Flux.params(enc1, μ1, logσ1, dec1)
    batches = dummy_data()
    # zip(batches) yields 1-tuples, so Flux.train! calls L1(batch) for each batch
    @epochs 3 Flux.train!(L1, ps1, zip(batches), ADAM())
end
end # module
```
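For reference, this is how I run it (the file name below is made up; use whatever you saved the module as):

```julia
include("fluxvae.jl") # hypothetical file name for the module above
FLUXVAE.train!()      # three epochs of ADAM on the dummy batches
```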