Flux.jl: Convolutional VAE throws error after upgrading to Julia 1.4.1 / Flux 0.10.4

Hello everyone,

I am fairly new to Julia and even newer to this forum. I hope this is the right subforum for my problem. Currently, I am working on a project that involves a convolutional VAE, which I implemented using Flux.

Everything ran smoothly until I upgraded from Julia 1.2.0 (Flux 0.9.0) to Julia 1.4.1 (Flux 0.10.4).

I tried to extract a minimal example for the error from my current model, which you can see below (it's still quite a lot of code though, so please excuse that):

  using Flux
  using Flux: @epochs, binarycrossentropy
  using Distributions

  # dummy data
  d = Array{Float64}(zeros((796, 512, 1, 10))) .+ 1
  batches = [reshape(d[:,:,:, i:i+4], (796, 512, 1, 5)) for i in 1:5]

  # convolutional encoder
  conv1 = Conv((14, 10), 1 => 4, relu, stride = (10, 10), pad = 4)
  pool1 = MaxPool((8, 8), stride = (4, 4), pad = 2)
  conv2 = Conv((4, 3), 4 => 4, stride = (2, 2), pad = 1)
  enc1(X) = reshape(conv2(pool1(conv1(X))), (280, :))

  # mean and log-variance of vae1's z-variable/latent space
  μ1 = Dense(280, 4)
  logσ1 = Dense(280, 4)
  # sample from z-distribution
  z(μ, logσ) = μ + exp(logσ) * randn(Float64)
  z(μ, logσ, eps) = μ + exp(logσ) * eps

  # decoder, I am using the one with transposed convolutions
  dense_decoder = false  # change accordingly
  if dense_decoder
      dec = Dense(4, 796*512, sigmoid)
      dec1(X) = reshape(dec(X),  (796, 512, 1, :))
  else
      interaction1 = Dense(4, 280)  # specific to my setup
      int1(X) = reshape(interaction1(X), (10, 7, 4, :))
      tc1 = ConvTranspose((4, 3), 4 => 4, relu, stride = (2, 2), pad = 1)
      tc2 = ConvTranspose((8, 8), 4 => 4, relu, stride = (4, 4), pad = 2)
      tc3 = ConvTranspose((14, 10), 4 => 1, sigmoid, stride = (10, 10), pad = 4)
      dec = Chain(interaction1, tc1, tc2, tc3) # for params
      dec1 = Chain(int1, tc1, tc2, tc3)
  end

  # log(p(x|z)), log(p(z)), log(q(z|x))
  logp_x_z1(X, z) = -sum(binarycrossentropy.(dec1(z), X))
  logp_z(z) = sum(Float64.(logpdf.(Normal(0, 1), z)))
  log_q_z_x(ϵ, log_sigma) = Float64(logpdf(Normal(0, 1), ϵ) - log_sigma)

  # vae loss estimator
  function L1(X)
      output_enc = enc1(X)
      mu, l = μ1(output_enc), logσ1(output_enc)
      e = randn(Float64, size(l)) # latentdim1
      z_ = z.(mu, l, e)
      return -(logp_x_z1(X, z_) + logp_z(z_) - sum(log_q_z_x.(e, l))) * 1//5
  end

  # train vae1
  ps1 = Flux.params(enc1, μ1, logσ1, dec)
  @epochs 3 Flux.train!(L1, ps1, zip(batches), ADAM())

On Julia 1.2.0 this runs perfectly. After upgrading to 1.4.1, the code yields the following error:

UndefRefError: access to undefined reference
in top-level scope at Juno/f8hj2/src/progress.jl:119
in macro expansion at Flux/Fj3bt/src/optimise/train.jl:122
in train! at Flux/Fj3bt/src/optimise/train.jl:79
in #train!#12 at Flux/Fj3bt/src/optimise/train.jl:81
in macro expansion at Juno/f8hj2/src/progress.jl:119 
in macro expansion at Flux/Fj3bt/src/optimise/train.jl:88 
in gradient at Zygote/YeCEW/src/compiler/interface.jl:55
in  at Zygote/YeCEW/src/compiler/interface.jl:179
in  at Zygote/YeCEW/src/compiler/interface2.jl
in #15 at Flux/Fj3bt/src/optimise/train.jl:89 
in #347#back at ZygoteRules/6nssF/src/adjoint.jl:49 
in #174 at Zygote/YeCEW/src/lib/lib.jl:182 
in  at Zygote/YeCEW/src/compiler/interface2.jl
in L1 at hello.jl:48 
in  at Zygote/YeCEW/src/compiler/interface2.jl
in logp_x_z1 at hello.jl:38 
in  at Zygote/YeCEW/src/compiler/interface2.jl
in Chain at Flux/Fj3bt/src/layers/basic.jl:38 
in  at Zygote/YeCEW/src/compiler/interface2.jl
in applychain at Flux/Fj3bt/src/layers/basic.jl:36 
in  at Zygote/YeCEW/src/compiler/interface2.jl
in applychain at Flux/Fj3bt/src/layers/basic.jl:36 
in  at Zygote/YeCEW/src/compiler/interface2.jl
in applychain at Flux/Fj3bt/src/layers/basic.jl:36 
in  at Zygote/YeCEW/src/compiler/interface2.jl
in ConvTranspose at Flux/Fj3bt/src/layers/conv.jl:148 
in  at ZygoteRules/6nssF/src/adjoint.jl:49
in #1837 at Zygote/YeCEW/src/lib/nnlib.jl:41 
in conv at NNlib/FAI3o/src/conv.jl:114 
in #conv#89 at NNlib/FAI3o/src/conv.jl:116
in conv! at NNlib/FAI3o/src/conv.jl:70 
in #conv!#48 at NNlib/FAI3o/src/conv.jl:70
in conv! at NNlib/FAI3o/src/conv.jl:97
in #conv!#83 at NNlib/FAI3o/src/conv.jl:99
in conv_direct! at NNlib/FAI3o/src/impl/conv_direct.jl:51 
in #conv_direct!#149 at NNlib/FAI3o/src/impl/conv_direct.jl:98
in getindex at base/array.jl:789 

So far, I’ve narrowed the problem down to the decoder-part of the network. In the code, I have included a bit that lets you switch the convolutional decoder to a simple dense layer. With that, the training works as expected. Additionally, I have tried simpler networks with only one layer of transposed convolution in the decoder and lower-dimensional data which also ran smoothly. Being a noob in Julia, I failed miserably trying to find what causes the error in the package code.

Having spent quite a few days on this, I can’t figure out what the problem might possibly be. Any help or feedback would be highly appreciated! :pray:

Best,
Flo

I don’t exactly have an answer, but maybe this can help: Flux 0.10 made quite a lot of changes, it is now using Zygote for computing gradients, so you’d have to update your usage based on the documentation for Zygote.

If your code already works with Flux 0.9, it might not yet be worth upgrading Flux to 0.10. Just call pkg"add Flux@v0.9" to downgrade.


Thank you very much for your reply!

I am aware that they switched from Tracker to Zygote in 0.10 and you’re right, that’s probably part of the problem here. I should have noted though that when I tried downgrading to Flux 0.9, other compatibility issues appeared, so that is not an option. I should also mention that I cannot stay on Julia 1.2.0 with this project because the code will be merged into another project that runs the current version. Hence the headache with finding a new workaround :sweat_smile:

I can’t reproduce the problem. When running the code you posted above I’m getting:

julia> @epochs 3 Flux.train!(L1, ps1, zip(batches), ADAM())
[ Info: Epoch 1
ERROR: MethodError: no method matching Irrational{:log2π}(::Int64)
Closest candidates are:
  Irrational{:log2π}(::T) where T<:Number at boot.jl:715
  Irrational{:log2π}() where sym at irrationals.jl:18
  Irrational{:log2π}(::Complex) where T<:Real at complex.jl:37
  ...

Are you sure the code above reproduces your problem? I just copied and pasted into a new julia session…

I’m using

  [31c24e10] Distributions v0.23.2
  [587475ba] Flux v0.10.4
  [e88e6eb3] Zygote v0.4.20

Thank you very much for your reply. I was running the code from the Atom/Juno IDE, where everything runs as described above (just checked again). From the message you posted I gather that you pasted the code sample into the standard Julia shell, which I have now tested as well. In the shell I get the same error message as you. I am not sure what causes this new problem, as both the shell and Atom use the same Julia binary. Installed packages are:

  [c52e3926] Atom v0.12.10
  [3895d2a7] CUDAapi v4.0.0
  [3a865a2d] CuArrays v2.2.0
  [31c24e10] Distributions v0.23.2
  [587475ba] Flux v0.10.4
  [f67ccb44] HDF5 v0.13.2
  [916415d5] Images v0.22.2
  [c8e1da08] IterTools v1.3.0
  [033835bb] JLD2 v0.1.13
  [e5e0dc1b] Juno v0.8.1
  [eb30cadb] MLDatasets v0.5.2
  [d96e819e] Parameters v0.12.1
  [e88e6eb3] Zygote v0.4.20
  [9a3f8284] Random 

Currently on a run, but I’ll try to figure out the problem by tomorrow and provide an example that also works outside of Atom. Sorry for the inconvenience, guess I didn’t expect things to work differently outside of my IDE.

To be honest, I’m a bit puzzled why there’s an Int64 appearing in the error message. The second part I’m not understanding is why the method isn’t resolved. The first candidate is

Irrational{:log2π}(::T) where T<:Number at boot.jl:715

Int64 <: Number is true, so this is weird. Sorry, I’m throwing out more questions than answers :thinking:


Same for me actually, seems like this is another weird behaviour that updating Julia is introducing. On 1.2.0/0.9 I don’t have problems running the code from the shell. Thanks for trying to help though, guess I’ll have no choice but to stick to my old setup and keep my fingers crossed for future updates :frowning_face:

You don’t have to give up so quickly.
I would suggest trying to get help at slack, and you can file a bug report. Since you have a way to reproduce the problem, the experts should be able to figure this out hopefully quickly.


OK, I think I managed to get both of these errors, and I can make the code run. I changed the code some to make it a little easier to debug, but I do not understand VAE’s well enough to know if the functions I defined make sense in the context of the problem.

Without using DistributionsAD, I get

 ERROR: MethodError: no method matching Irrational{:log2π}(::Int64)

Inspired by this, I added using DistributionsAD, and now I see

ERROR: UndefRefError: access to undefined reference

I think this might be a bug, but I don’t understand what is happening yet.

The other puzzling thing about this code is the need for all the Float64 conversions. They were due to Normal(0, 1) being defined in terms of Float64, because it is constructed from integer literals, which are converted to Float64. After sorting out the few hard-coded Float64s, here is something that seems to work.

module FLUXVAE                                                                   
                                                                                 
using Flux                                                                       
using Flux: @epochs, binarycrossentropy                                          
using DistributionsAD                                                            
using Distributions                                                              
                                                                                 
# dummy data                                                                     
function dummy_data()                                                            
    d = Array{Float32}(zeros((796, 512, 1, 10))) .+ 1                            
    batches = [reshape(d[:,:,:, i:i+4], (796, 512, 1, 5)) for i in 1:5]          
end                                                                              
                                                                                 
struct Reshape                                                                   
    shape                                                                        
end                                                                              
                                                                                 
Reshape(args...) = Reshape(args)                                                 
                                                                                 
(r::Reshape)(x) = reshape(x, r.shape)                                            
                                                                                 
Flux.@functor Reshape ()                                                         
                                                                                 
                                                                                 
# convolutional encoder                                                          
function encoder()                                                               
    conv1 = Conv((14, 10), 1 => 4, relu, stride = (10, 10), pad = 4)             
    pool1 = MaxPool((8, 8), stride = (4, 4), pad = 2)                            
    conv2 = Conv((4, 3), 4 => 4, stride = (2, 2), pad = 1)                       
    res = Reshape(280, :)                                                        
    # enc1(X) = reshape(conv2(pool1(conv1(X))), (280, :))                        
    # Chain(res, conv2, pool1, conv1)                                            
    Chain(conv1, pool1, conv2, res)                                              
end                                                                              
                                                                                 
# decoder, I am using the one with transposed convolutions                       
function decoder(;dense_decoder = false)                                         
    if dense_decoder                                                             
        dec = Dense(4, 796*512, sigmoid)                                         
        dec1(X) = reshape(dec(X),  (796, 512, 1, :))                             
    else                                                                         
        interaction1 = Dense(4, 280)  # specific to my setup                     
        res = Reshape(10, 7, 4, :)                                               
        # int1(X) = reshape(interaction1(X), (10, 7, 4, :))                      
        tc1 = ConvTranspose((4, 3), 4 => 4, relu, stride = (2, 2), pad = 1)         
        tc2 = ConvTranspose((8, 8), 4 => 4, relu, stride = (4, 4), pad = 2)         
        tc3 = ConvTranspose((14, 10), 4 => 1, sigmoid, stride = (10, 10), pad = 4)
        dec = Chain(interaction1, tc1, tc2, tc3) # for params                    
        dec1 = Chain(interaction1, res, tc1, tc2, tc3)                           
    end                                                                          
    return (dec, dec1)                                                           
end                                                                              
                                                                                 
# sample from z-distribution                                                     
z(μ::T, logσ) where {T} = μ + exp(logσ) * randn(T)                               
z(μ, logσ, eps) = μ + exp(logσ) * eps                                            
                                                                                 
# log(p(x|z)), log(p(z)), log(q(z|x))                                            
logp_x_z1(X, z, dec1) = -sum(binarycrossentropy.(dec1(z), X))                    
logp_z(z::AbstractArray{T}) where {T} = sum((logpdf.(Normal(zero(T), one(T)), z)))
log_q_z_x(ϵ, log_sigma) = logpdf(Normal(zero(ϵ), one(ϵ)), ϵ) - log_sigma         
                                                                                 
# vae loss estimator                                                             
function vae_loss(enc1, dec1, μ1, logσ1)                                         
    mu(X) = μ1(enc1(X))                                                          
    l(X) =  logσ1(enc1(X))                                                       
    e(X) = randn(eltype(X), size(l(X))) # latentdim1                             
    z_(X) = z.(mu(X), l(X), e(X))                                                
    return X->-(logp_x_z1(X, z_(X), dec1) + logp_z(z_(X)) - sum(log_q_z_x.(e(X), l(X)))) * 1//5
end                                                                              
                                                                                 
# train vae1                                                                     
function train!()                                                                
    enc1 = encoder()                                                             
    dec, dec1 = decoder()                                                        
    # mean and log-variance of vae1's z-variable/latent space                    
    μ1 = Dense(280, 4)                                                           
    logσ1 = Dense(280, 4)                                                        
    L1 = vae_loss(enc1, dec1, μ1, logσ1)                                         
    ps1 = Flux.params(enc1, μ1, logσ1, dec1)                                     
    batches = dummy_data()                                                       
    @epochs 3 Flux.train!(L1, ps1, zip(batches), ADAM())                         
end                                                                              
                                                                                 
end

Thank you both again for trying to help!

@contradict guess we should have kept on googling a bit more. That should solve our second problem. Also thank you for refactoring the code (I just learned new Julia syntax :wink: ). Also, you’re right about the Float64 conversions, your version handles that much more elegantly. Actually, I added those only after upgrading to 1.4.1 because of some other type-issues.

Coming to the most important part of your reply, I can confirm that this solution also works on my machine! I think I’ve made this clear by now, but this helps A LOT :partying_face:
I went through your version and everything seems to be implemented correctly from an algorithmic perspective as well. I’m still not sure which of your alterations actually did the trick, but I am fine with that for now. Thank you again.

Do you think this might still be worth filing a report for?

Sure, why not.
If it’s invalid, it’s easy to close, and if it’s a bug, you don’t want other people to run into it.

The changes I made to the Reshape stuff were to get rid of the dec/dec1 distinction, I didn’t finish cleaning that up but it also turned out not to be a problem. If you don’t like that syntax, there is no need to keep those.

Converting everything to functions and not using global variables is what I meant by making it easier to debug, those are worth keeping as kindness to your future self.

The changes I made that actually made the program work were the ones that eliminated all the Float64 types. Specifically, making all the terms of the loss function copy their input types to their output by declaring all constants to be the same type as the inputs and then just feeding the correct input type from dummy_data.
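
As a rough illustration of that point (a minimal sketch, not taken from the code above): the helper names sample_z, sample_z_f64 and stdnormal_logpdf are made up for this example, and the normal log-density is written out by hand so that it runs without Distributions. It only shows how a single hard-coded Float64 promotes a Float32 computation, and how taking the type from the input avoids that.

```julia
# Standard-normal log-density that returns the same float type it receives.
# Hand-written stand-in (hypothetical helper), so no packages are needed.
stdnormal_logpdf(x::T) where {T<:AbstractFloat} = -(x^2 + log(T(2π))) / 2

# Reparameterised sample: the noise is drawn in the type of μ, so nothing is promoted.
sample_z(μ::T, logσ::T) where {T<:AbstractFloat} = μ + exp(logσ) * randn(T)

# Hard-coded Float64 noise, as in the first post: Float32 inputs come out as Float64.
sample_z_f64(μ, logσ) = μ + exp(logσ) * randn(Float64)

println(typeof(sample_z(0.5f0, 0.1f0)))      # Float32
println(typeof(sample_z_f64(0.5f0, 0.1f0)))  # Float64
println(typeof(stdnormal_logpdf(0.3f0)))     # Float32
```

With Float64 inputs the same generic functions return Float64, so the element type is decided in one place only: the data.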

You could submit a bug report. I was trying to reduce the size of the example before doing that, but it might already be obvious to one of the Flux developers what has gone wrong. In the new code I wrote, if you switch the data type in dummy_data to Float64, the ERROR: UndefRefError: access to undefined reference comes back. Following the stack trace, this seems to be a slow path (conv_direct! in NNlib, which Flux calls) that uses a non-optimized convolution that can operate on any array type. When you pass a Float32, there is an optimized routine for that data type that does work.
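
If the data arrives as Float64, converting it once before training is probably the simplest way to stay on the Float32 path. A minimal sketch, assuming the layers hold Float32 weights (which, as far as I know, is what Flux's default initialisers produce); the array here is much smaller than the real input and the variable names are only for illustration:

```julia
# Dummy batch with the Float64 element type used in the first post (smaller here).
d64 = ones(Float64, 8, 8, 1, 5)

# Convert once, up front, so the data matches Float32 layer weights.
d32 = Float32.(d64)

# Or allocate in the right type from the start.
d32_direct = ones(Float32, 8, 8, 1, 5)

println(eltype(d64))         # Float64
println(eltype(d32))         # Float32
println(d32 == d32_direct)   # true
```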


I reported it here.


Thank you for the clarification! I managed to apply your solution to my project, and after another weekend of debugging and headache it finally runs on 1.4.1/0.10.4 as well.

In case anybody else is interested in the code above, there is a small algorithmic error in the new loss function: e(X) is computed twice, once to sample the latent variable in z_(X) and again to compute log(q(z|x)) in log_q_z_x.(e(X), l(X)). The correct behavior is to sample once and use the same values for both computations. Of course, this is easy to fix by nesting it in another function and thereby keeping the functional approach.
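
One possible shape of that fix, as a sketch only: the names vae_loss_once, stdnormal_logpdf and bce are invented for this example, and the log-density and cross-entropy are written out by hand so that the block runs without Flux or Distributions. The encoder, decoder and the two dense heads are passed in as plain callables; I have not tested this exact version under Zygote.

```julia
# Hand-written stand-ins (hypothetical helpers) so the sketch needs no packages.
stdnormal_logpdf(x::T) where {T<:AbstractFloat} = -(x^2 + log(T(2π))) / 2
bce(yhat::T, y::T) where {T<:AbstractFloat} =
    -y * log(yhat + eps(T)) - (1 - y) * log(1 - yhat + eps(T))

# Returns a loss closure in which every intermediate value is computed once per call.
function vae_loss_once(enc, dec, μ, logσ)
    return function (X)
        h  = enc(X)                      # encode once
        m  = μ(h)                        # mean of q(z|x)
        l  = logσ(h)                     # log standard deviation of q(z|x)
        e  = randn(eltype(l), size(l))   # sample the noise exactly once
        z_ = m .+ exp.(l) .* e           # same e is used for the latent sample ...
        logp_x_z = -sum(bce.(dec(z_), X))
        logp_z   = sum(stdnormal_logpdf.(z_))
        logq_z_x = sum(stdnormal_logpdf.(e) .- l)   # ... and for log q(z|x)
        return -(logp_x_z + logp_z - logq_z_x) * 1//5   # batch size 5, as above
    end
end

# Toy usage with stand-in "layers" (identity encoder, sigmoid decoder).
sigmoid_dec(z) = 1 ./ (1 .+ exp.(-z))
loss = vae_loss_once(identity, sigmoid_dec, identity, h -> zero(h))
println(loss(ones(Float32, 2, 5)) isa Float32)   # true
```

Besides fixing the double sampling, this also avoids running the encoder several times per loss evaluation, which the closure-per-term version above did.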

Cheers again and have a great Sunday!