Flux.jl: Convolutional VAE throws error after upgrading to 1.4.1/10.4

OK, I think I managed to reproduce both of these errors, and I can make the code run. I changed the code a bit to make it easier to debug, but I do not understand VAEs well enough to know whether the functions I defined make sense in the context of the problem.

Without `using DistributionsAD`, I get

ERROR: MethodError: no method matching Irrational{:log2π}(::Int64)

Inspired by this, I added `using DistributionsAD`, and now I see

ERROR: UndefRefError: access to undefined reference

I think this might be a bug, but I don’t understand what is happening yet.
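For what it's worth, my guess at a minimal reproduction of the first error is differentiating the logpdf of a Normal through Zygote, since the normal log-density involves the Irrational constant log2π. This is an untested sketch of my suspicion, not a verified bisection:

using Flux          # re-exports Zygote's gradient
using Distributions

# without DistributionsAD's adjoints, I suspect the pullback ends up trying to
# construct Irrational{:log2π} from an Int, which has no such constructor
Flux.gradient(x -> logpdf(Normal(0.0, 1.0), x), 0.5)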

The other puzzling thing about this code is the need for all the Float64 conversions. They turned out to come from `Normal(0, 1)`: the integer literals make it a `Normal{Float64}`, which then drags every Float32 computation up to Float64. A quick REPL check shows the promotion:
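
using Distributions

typeof(Normal(0, 1))        # Normal{Float64}: integer parameters promote to Float64
typeof(Normal(0f0, 1f0))    # Normal{Float32}
logpdf(Normal(0, 1), 0.5f0) # a Float64, even though the input is Float32

This is why the code builds its distributions with zero(T) and one(T). After sorting out the few hard-coded Float64s, here is something that seems to work.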

module FLUXVAE                                                                   
                                                                                 
using Flux                                                                       
using Flux: @epochs, binarycrossentropy                                          
using DistributionsAD                                                            
using Distributions                                                              
                                                                                 
# dummy data                                                                     
function dummy_data()                                                            
    d = ones(Float32, 796, 512, 1, 10)  # WHCN layout: width, height, channels, batch
    batches = [reshape(d[:, :, :, i:i+4], (796, 512, 1, 5)) for i in 1:5]
end                                                                              
                                                                                 
# a callable layer that reshapes its input, so it can sit inside a Chain
struct Reshape
    shape
end

Reshape(args...) = Reshape(args)

(r::Reshape)(x) = reshape(x, r.shape)

Flux.@functor Reshape ()  # empty tuple: no trainable parameters
                                                                                 
                                                                                 
# convolutional encoder                                                          
function encoder()                                                               
    conv1 = Conv((14, 10), 1 => 4, relu, stride = (10, 10), pad = 4)             
    pool1 = MaxPool((8, 8), stride = (4, 4), pad = 2)                            
    conv2 = Conv((4, 3), 4 => 4, stride = (2, 2), pad = 1)                       
    res = Reshape(280, :)                                                        
    # enc1(X) = reshape(conv2(pool1(conv1(X))), (280, :))                        
    # Chain(res, conv2, pool1, conv1)                                            
    Chain(conv1, pool1, conv2, res)                                              
end                                                                              
                                                                                 
# decoder, I am using the one with transposed convolutions                       
function decoder(;dense_decoder = false)                                         
    if dense_decoder                                                             
        dec = Dense(4, 796*512, sigmoid)                                         
        dec1 = X -> reshape(dec(X), (796, 512, 1, :))  # anonymous fn, so both branches assign dec1
    else                                                                         
        interaction1 = Dense(4, 280)  # specific to my setup                     
        res = Reshape(10, 7, 4, :)                                               
        # int1(X) = reshape(interaction1(X), (10, 7, 4, :))                      
        tc1 = ConvTranspose((4, 3), 4 => 4, relu, stride = (2, 2), pad = 1)         
        tc2 = ConvTranspose((8, 8), 4 => 4, relu, stride = (4, 4), pad = 2)         
        tc3 = ConvTranspose((14, 10), 4 => 1, sigmoid, stride = (10, 10), pad = 4)
        dec = Chain(interaction1, tc1, tc2, tc3)        # for params
        dec1 = Chain(interaction1, res, tc1, tc2, tc3)  # the actual forward pass
    end                                                                          
    return (dec, dec1)                                                           
end                                                                              
                                                                                 
# reparameterization trick: sample z = μ + σ·ϵ with ϵ ~ N(0, 1)
z(μ::T, logσ) where {T} = μ + exp(logσ) * randn(T)
z(μ, logσ, eps) = μ + exp(logσ) * eps
                                                                                 
# log(p(x|z)), log(p(z)), log(q(z|x))                                            
logp_x_z1(X, z, dec1) = -sum(binarycrossentropy.(dec1(z), X))                    
logp_z(z::AbstractArray{T}) where {T} = sum((logpdf.(Normal(zero(T), one(T)), z)))
log_q_z_x(ϵ, log_sigma) = logpdf(Normal(zero(ϵ), one(ϵ)), ϵ) - log_sigma         
                                                                                 
# vae loss estimator                                                             
function vae_loss(enc1, dec1, μ1, logσ1)
    function loss(X)
        h = enc1(X)
        μ, logσ = μ1(h), logσ1(h)
        # draw ϵ once per call, so the same sample feeds all three terms;
        # re-evaluating a closure like e(X) would re-sample each time
        ϵ = randn(eltype(X), size(logσ))
        z_ = z.(μ, logσ, ϵ)
        -(logp_x_z1(X, z_, dec1) + logp_z(z_) - sum(log_q_z_x.(ϵ, logσ))) * 1//5  # 1//5: batch size
    end
    return loss
end
                                                                                 
# train vae1                                                                     
function train!()                                                                
    enc1 = encoder()                                                             
    dec, dec1 = decoder()                                                        
    # mean and log-variance of vae1's z-variable/latent space                    
    μ1 = Dense(280, 4)                                                           
    logσ1 = Dense(280, 4)                                                        
    L1 = vae_loss(enc1, dec1, μ1, logσ1)                                         
    ps1 = Flux.params(enc1, μ1, logσ1, dec1)                                     
    batches = dummy_data()                                                       
    # zip(batches) wraps each batch in a 1-tuple, so train! splats it as L1(X)
    @epochs 3 Flux.train!(L1, ps1, zip(batches), ADAM())
end                                                                              
                                                                                 
end
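
With the module saved in a file, say fluxvae.jl (the filename is just for illustration), the smoke test is:

include("fluxvae.jl")
FLUXVAE.train!()  # 3 epochs of ADAM over the dummy batches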