Flux VAE on the iris dataset

I’ve been stuck trying to get a variational autoencoder working in Flux. I tried to adapt the MNIST VAE from Flux’s model zoo (which I could not get to work either) to the Iris dataset.

Here’s my code so far:

using Distributions
using Flux
using Flux: params, logitbinarycrossentropy
using Flux: Data.DataLoader
using Flux: Losses
using RDatasets

X = dataset("datasets", "iris");
select!(X, Not(:Species));
X = transpose(convert(Matrix, X));
loader = DataLoader(X, batchsize=1, shuffle=true);

eH1 = Dense(16, 32)   # head producing μ
eH2 = Dense(16, 32)   # head producing log σ²
encoder = Chain(Dense(4, 16, relu),
                x -> (eH1(x), eH2(x)))

decoder = Chain(Dense(32, 16),
                Dense(16, 4))

function reparameterize(μ, logvar)
    std = exp.(logvar ./ 2)
    eps = rand(MvNormal(vec(fill(0., 32)), vec(std)))
    uu = μ .+ eps .* std
    uu, μ, logvar
end

model = Chain(x -> encoder(x),
              x -> reparameterize(x[1], x[2]),
              x -> (decoder(x[1]), x[2], x[3]))

function loss(x)
    x̂, μ, logvar = model(x)
    reconst_loss = sum(Losses.logitbinarycrossentropy.(x̂, x))
    kl_div = -0.5 * sum(1. .+ logvar .- μ.^2 .- exp.(logvar))
    reconst_loss + kl_div
end

ps = Flux.Params(model) 
opt = ADAM()
Flux.train!(loss, ps, loader, opt)

Here’s the error:

julia> Flux.train!(loss, ps, loader, opt)
ERROR: Mutating arrays is not supported
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] (::Zygote.var"#368#369")(::Nothing) at /Users/rs990e/.julia/packages/Zygote/ggM8Z/src/lib/array.jl:61
 [3] (::Zygote.var"#2255#back#370"{Zygote.var"#368#369"})(::Nothing) at /Users/rs990e/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [4] materialize! at ./broadcast.jl:848 [inlined]
 [5] materialize! at ./broadcast.jl:845 [inlined]
 [6] materialize! at ./broadcast.jl:841 [inlined]
 [7] broadcast! at ./broadcast.jl:814 [inlined]

I’m not sure which array is being modified. Any suggestions?

Here’s my env info:

(@v1.5) pkg> st
Status `~/.julia/environments/v1.5/Project.toml`
  [31c24e10] Distributions v0.24.12
  [587475ba] Flux v0.11.1
  [ce6b1742] RDatasets v0.7.4
  [e88e6eb3] Zygote v0.5.17

Shooting from the hip without having run the code:

IIRC, the reparametrization trick exists precisely to avoid taking the gradient of the random-number generation, which is what happens here: eps = rand(MvNormal(vec(fill(0., 32)), vec(std))). You probably need to tell Zygote not to differentiate that line, e.g. using @nograd.
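Untested, but something along these lines, where the sampling lives in a small helper (sample_eps is just a made-up name) that Zygote is told to skip:

using Zygote

# Draw standard-normal noise; @nograd makes Zygote treat calls to this
# function as constants, so no gradient is taken through the sampling.
sample_eps(x) = randn(Float32, size(x))
Zygote.@nograd sample_eps

function reparameterize(μ, logvar)
    std = exp.(logvar ./ 2)
    # Note eps is standard normal here; drawing from MvNormal with std and
    # then multiplying by std again would scale the noise twice.
    uu = μ .+ sample_eps(μ) .* std
    uu, μ, logvar
end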

Another issue with the model is that those anonymous functions are not functors, so the layers they close over (eH1 and eH2) will not be picked up by params. I’m also unsure whether Flux.Params(model) does the same thing as params(model). Either way, you need to give params every layer whose parameters you want to train; I think params([encoder, eH1, eH2, decoder]) should work.
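That is, something like (untested):

# Hand params every layer explicitly; the anonymous functions in the
# Chains hide eH1 and eH2 from params(model).
ps = params(encoder, eH1, eH2, decoder)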

I can confirm that Flux.Params(model) absolutely does not do the right thing; you should always use lowercase params with Flux. That said, that doesn’t seem to be why the mutation error is thrown. Instead, I think this line is to blame:

reconst_loss = sum(Losses.logitbinarycrossentropy.(x̂, x))

Since Flux 0.11.something, logitbinarycrossentropy and co will auto-reduce using mean. In other words, the broadcast is no longer required and is probably causing your mutation issue. If you look at the semi-recently updated model zoo example, you can see this new API in action.
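With the new API the loss becomes something like this (untested; agg=sum keeps the summed reduction of your original, the default is mean):

function loss(x)
    x̂, μ, logvar = model(x)
    # No broadcast: the loss reduces internally now.
    reconst_loss = Losses.logitbinarycrossentropy(x̂, x; agg=sum)
    kl_div = -0.5 * sum(1 .+ logvar .- μ.^2 .- exp.(logvar))
    reconst_loss + kl_div
end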