Flux.jl: Error thrown during gradient calculation in conv VAE

Hi,

I’m just getting into Flux (and Julia more generally), but I’m finding it impossible to get rid of a very particular error in my convolutional VAE model. The code runs fine until the back(1f0) line, where it throws ERROR: LoadError: UndefRefError: access to undefined reference.

I suspect this might be the same core problem of Float32 being converted to Float64 (as documented here). I just cannot seem to resolve the issue, even though my arrays all seem to be Float32 when I print them out.
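
For reference, this is roughly how I have been checking the element types (a quick sketch with a stand-in layer, not part of the script below):

using Flux

layer = Dense(4, 2)  # stand-in for any layer in my model
# Collect the element type of every trainable array; I expect only Float32
println([eltype(p) for p in Flux.params(layer)])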

Here is my minimal example (Julia: Version 1.4.1 (2020-04-14), Flux: Version 0.10.4):

using Flux
using Flux: logitbinarycrossentropy
using Flux.Data: DataLoader
using ImageFiltering
using MLDatasets: FashionMNIST
using Random
using Zygote

function get_train_loader(batch_size, shuffle::Bool)
    # FashionMNIST is made up of 60k 28 by 28 greyscale images
    train_x, train_y = FashionMNIST.traindata(Float32)
    train_x = reshape(train_x, (28, 28, 1, :))
    train_x = parent(padarray(train_x, Fill(0, (2,2,0,0))))
    return DataLoader(train_x, train_y, batchsize=batch_size, shuffle=shuffle)
end

function vae_loss(encoder_μ, encoder_logσ, decoder, x)
    batch_size = size(x)[end]

    # Forward propagate through encoder
    μ = encoder_μ(x)
    logσ = encoder_logσ(x)
    # Apply reparameterisation trick to sample latent
    z = μ + randn(Float32, size(logσ)) .* exp.(logσ)
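    # (The noise is drawn outside the parameter graph, so the sample stays
    # differentiable with respect to μ and logσ)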
    # Reconstruct from latent sample
    x̂ = decoder(z)
    # Reconstruction term E_q[log p(x|z)], i.e. the negative binary cross-entropy
    logp_x_z = -sum(logitbinarycrossentropy.(x̂, x)) / batch_size
    # KL(qᵩ(z|x) || p(z)) where p(z) = N(0, I) and qᵩ(z|x) is the encoder distribution;
    # closed form: 0.5 * Σ(σ² + μ² - 2logσ - 1)
    # The @. macro makes all the operations elementwise
    kl_q_p = 0.5f0 * sum(@. (exp(2f0*logσ) + μ^2f0 - 2f0*logσ - 1f0)) / batch_size
    # We want to maximise the evidence lower bound (ELBO)
    β = 1f0
    elbo = logp_x_z - β * kl_q_p
    return -elbo
end

function train()
    # Define the encoder network
    encoder_features = Chain(
        Conv((4, 4), 1 => 32, relu; stride = 2, pad = 1),
        Conv((4, 4), 32 => 32, relu; stride = 2, pad = 1),
        Conv((4, 4), 32 => 32, relu; stride = 2, pad = 1),
        flatten,
        Dense(32 * 4 * 4, 256, relu),
    )
    encoder_μ = Chain(encoder_features, Dense(256, 10))
    encoder_logσ = Chain(encoder_features, Dense(256, 10))
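    # NOTE: encoder_μ and encoder_logσ share encoder_features, so the shared
    # convolutional weights receive gradients from both heads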

    # Define the decoder network
    decoder = Chain(
        Dense(10, 256, relu),
        Dense(256, 32 * 4 * 4, relu),
        x -> reshape(x, (4, 4, 32, :)),  # : instead of 16 so any batch size works
        ConvTranspose((4, 4), 32 => 32, relu; stride = 2, pad = 1),
        ConvTranspose((4, 4), 32 => 32, relu; stride = 2, pad = 1),
        ConvTranspose((4, 4), 32 => 1; stride = 2, pad = 1)
    )

    trainable_params = Flux.params(encoder_μ, encoder_logσ, decoder)

    optimiser = ADAM(0.0001, (0.9, 0.999))
    batch_size = 16
    shuffle_data = true
    dataloader = get_train_loader(batch_size, shuffle_data)

    for epoch_num = 1:3
        for (x_batch, y_batch) in dataloader
            # Zygote.pullback returns the result (loss) and a pullback function (back)
            loss, back = Zygote.pullback(trainable_params) do
                vae_loss(encoder_μ, encoder_logσ, decoder, x_batch)
            end
            # Seed the pullback with 1f0 to obtain the gradients and update the model parameters
            # NOTE: This is where it seems to fail
            gradients = back(1f0)
            Flux.Optimise.update!(optimiser, trainable_params, gradients)
            println(loss)
        end

    end
    println("Training complete!")
end

if abspath(PROGRAM_FILE) == @__FILE__
    train()
end

The stack trace (and preceding warning):

┌ Warning: `haskey(::TargetIterator, name::String)` is deprecated, use `Target(; name = name) !== nothing` instead.
│   caller = llvm_compat(::VersionNumber) at compatibility.jl:176
└ @ CUDAnative ~/.julia/packages/CUDAnative/ierw8/src/compatibility.jl:176
ERROR: LoadError: UndefRefError: access to undefined reference
Stacktrace:
 [1] getindex at ./array.jl:789 [inlined]
 [2] conv_direct!(::Array{AbstractFloat,5}, ::Array{AbstractFloat,5}, ::Array{Float32,5}, ::DenseConvDims{3,(4, 4, 1),32,32,(2, 2, 1),(1, 1, 1, 1, 0, 0),(1, 1, 1),false}; alpha::Float64, beta::Bool) at /home/aleco/.julia/packages/NNlib/FAI3o/src/impl/conv_direct.jl:98
 [3] conv_direct! at /home/aleco/.julia/packages/NNlib/FAI3o/src/impl/conv_direct.jl:51 [inlined]
 [4] conv!(::Array{AbstractFloat,5}, ::Array{AbstractFloat,5}, ::Array{Float32,5}, ::DenseConvDims{3,(4, 4, 1),32,32,(2, 2, 1),(1, 1, 1, 1, 0, 0),(1, 1, 1),false}; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/aleco/.julia/packages/NNlib/FAI3o/src/conv.jl:99
 [5] conv!(::Array{AbstractFloat,5}, ::Array{AbstractFloat,5}, ::Array{Float32,5}, ::DenseConvDims{3,(4, 4, 1),32,32,(2, 2, 1),(1, 1, 1, 1, 0, 0),(1, 1, 1),false}) at /home/aleco/.julia/packages/NNlib/FAI3o/src/conv.jl:97
 [6] conv!(::Array{AbstractFloat,4}, ::Array{AbstractFloat,4}, ::Array{Float32,4}, ::DenseConvDims{2,(4, 4),32,32,(2, 2),(1, 1, 1, 1),(1, 1),false}; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/aleco/.julia/packages/NNlib/FAI3o/src/conv.jl:70
 [7] conv! at /home/aleco/.julia/packages/NNlib/FAI3o/src/conv.jl:70 [inlined]
 [8] conv(::Array{AbstractFloat,4}, ::Array{Float32,4}, ::DenseConvDims{2,(4, 4),32,32,(2, 2),(1, 1, 1, 1),(1, 1),false}; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/aleco/.julia/packages/NNlib/FAI3o/src/conv.jl:116
 [9] conv at /home/aleco/.julia/packages/NNlib/FAI3o/src/conv.jl:114 [inlined]
 [10] #1224 at /home/aleco/.julia/packages/Zygote/1GXzF/src/lib/nnlib.jl:43 [inlined]
 [11] (::Zygote.var"#2720#back#1226"{Zygote.var"#1224#1225"{Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Array{Float32,4},Array{Float32,4},DenseConvDims{2,(4, 4),32,32,(2, 2),(1, 1, 1, 1),(1, 1),false}}})(::Array{AbstractFloat,4}) at /home/aleco/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [12] ConvTranspose at /home/aleco/.julia/packages/Flux/Fj3bt/src/layers/conv.jl:148 [inlined]
 [13] (::typeof(∂(λ)))(::Array{Float64,4}) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [14] applychain at /home/aleco/.julia/packages/Flux/Fj3bt/src/layers/basic.jl:36 [inlined]
 [15] (::typeof(∂(applychain)))(::Array{Float64,4}) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [16] applychain at /home/aleco/.julia/packages/Flux/Fj3bt/src/layers/basic.jl:36 [inlined]
 [17] (::typeof(∂(applychain)))(::Array{Float64,4}) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [18] applychain at /home/aleco/.julia/packages/Flux/Fj3bt/src/layers/basic.jl:36 [inlined]
 [19] (::typeof(∂(applychain)))(::Array{Float64,4}) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [20] applychain at /home/aleco/.julia/packages/Flux/Fj3bt/src/layers/basic.jl:36 [inlined]
 [21] (::typeof(∂(applychain)))(::Array{Float64,4}) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [22] applychain at /home/aleco/.julia/packages/Flux/Fj3bt/src/layers/basic.jl:36 [inlined]
 [23] (::typeof(∂(applychain)))(::Array{Float64,4}) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [24] Chain at /home/aleco/.julia/packages/Flux/Fj3bt/src/layers/basic.jl:38 [inlined]
 [25] (::typeof(∂(λ)))(::Array{Float64,4}) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [26] vae_loss at /home/aleco/Documents/flux-vae/minimal_example/online.jl:26 [inlined]
 [27] (::typeof(∂(vae_loss)))(::Float32) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [28] #4 at /home/aleco/Documents/flux-vae/minimal_example/online.jl:71 [inlined]
 [29] (::typeof(∂(λ)))(::Float32) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface2.jl:0
 [30] (::Zygote.var"#50#51"{Params,Zygote.Context,typeof(∂(λ))})(::Float32) at /home/aleco/.julia/packages/Zygote/1GXzF/src/compiler/interface.jl:177
 [31] train() at /home/aleco/Documents/flux-vae/minimal_example/online.jl:75
 [32] top-level scope at /home/aleco/Documents/flux-vae/minimal_example/online.jl:85
 [33] include(::Module, ::String) at ./Base.jl:377
 [34] exec_options(::Base.JLOptions) at ./client.jl:288
 [35] _start() at ./client.jl:484
in expression starting at /home/aleco/Documents/flux-vae/minimal_example/online.jl:84
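
From the trace it looks like the failure surfaces in the ConvTranspose pullback, where the arrays appear as Array{AbstractFloat,4}. In case it helps, here is a smaller snippet that I would expect to exercise the same code path (just my attempt at isolating the problem, so treat it as a sketch):

using Flux
using Zygote

# Same shape as the decoder's first transposed convolution
layer = ConvTranspose((4, 4), 32 => 32, relu; stride = 2, pad = 1)
x = randn(Float32, 4, 4, 32, 16)
loss, back = Zygote.pullback(Flux.params(layer)) do
    sum(layer(x))
end
back(1f0)  # same call pattern that fails in the full model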

Any help would be greatly appreciated, as I am really having a hard time debugging this.

Kind regards,
Aleco

Hello Aleco,

I ran the above code in my Julia (1.4.1) REPL with my package status looking like this:

  [c52e3926] Atom v0.12.11
  [fbb218c0] BSON v0.2.6
  [3895d2a7] CUDAapi v4.0.0
  [35d6a980] ColorSchemes v3.9.0
  [3a865a2d] CuArrays v2.2.0
  [31c24e10] Distributions v0.23.2
  [ced4e74d] DistributionsAD v0.5.2
  [587475ba] Flux v0.10.4
  [f67ccb44] HDF5 v0.13.2
  [6a3955dd] ImageFiltering v0.6.13
  [916415d5] Images v0.22.2
  [c8e1da08] IterTools v1.3.0
  [033835bb] JLD2 v0.1.13
  [e5e0dc1b] Juno v0.8.2
  [eb30cadb] MLDatasets v0.5.2
  [442fdcdd] Measures v0.3.1
  [a3a9e032] NIfTI v0.4.1
  [d96e819e] Parameters v0.12.1
  [91a5bcdd] Plots v1.3.5
  [e88e6eb3] Zygote v0.4.20
  [9a3f8284] Random 

The algorithm started printing loss values after I called train() and terminated after three epochs without problems. Based on what you have written here, I suspect that something is wrong with your Julia version or your packages. Can you confirm that you have the same package versions as above? Flux and Zygote in particular should match. You can check by typing ] status in the Julia REPL. Also, when you launch julia from the command line, does version 1.4.1 load?
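
If some of those versions differ, you can install specific ones from the Pkg REPL, for example (using the versions from my status above):

(@v1.4) pkg> add Flux@0.10.4 Zygote@0.4.20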

Hi @flo,

Thank you for your reply. You are correct: it does seem to have been a package versioning problem. After I created a new environment and replicated your package versions, I saw the same result as you. It worked :)

My previous environment looked like this:

  [fbb218c0] BSON v0.2.6
  [052768ef] CUDA v0.1.0
  [634d3b9d] DrWatson v1.13.1
  [587475ba] Flux v0.10.4
  [6a3955dd] ImageFiltering v0.6.13
  [916415d5] Images v0.22.3
  [4138dd39] JLD v0.10.0
  [929cbde3] LLVM v1.7.0
  [eb30cadb] MLDatasets v0.5.2
  [872c559c] NNlib v0.6.6
  [d96e819e] Parameters v0.12.1
  [91a5bcdd] Plots v1.4.4
  [92933f4c] ProgressMeter v1.3.1
  [e88e6eb3] Zygote v0.4.22
  [56ddb016] Logging 

I suspect that downgrading Zygote from v0.4.22 to v0.4.20 was the key change. Thanks so much for your help!
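
For anyone who finds this later, this is roughly how I set up the new environment (the directory name is just an example):

(@v1.4) pkg> activate flux-vae
(flux-vae) pkg> add Flux@0.10.4 Zygote@0.4.20 MLDatasets ImageFiltering
(flux-vae) pkg> pin Zygote@0.4.20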

Glad I could help! A solution tick would be highly appreciated :)

Already done :)
