I’m trying to get this variational autoencoder working. I started from the VAE implementation in the model-zoo and now have this:
using Flux
include("../src/modules.jl")
using .Tools, Distributions
import Distributions: logpdf
####################### Experiment setup #################################
s2n_ratio = 0.01f0
# Generate data
dir = pwd()
Xs, _ = t_dataset(
    "$dir/../data/exp_raw/t-scaled.jpeg",
    100, s2n_ratio, 10
)
############################# Model #################################
struct VAE{I, G, Z}
    A::I       # inference (encoder) network
    μ::Z       # variational posterior mean head
    logσ::Z    # variational posterior log-std head
    f::G       # generative (decoder) network
end
# Attributes
# The decoder's first layer is Dense(Dz, Dh), so its weight matrix is Dh×Dz.
latdim(m::VAE) = m.f isa Dense ? size(m.f.W, 2) : size(m.f[1].W, 2)
hiddim(m::VAE) = m.f isa Dense ? size(m.f.W, 1) : size(m.f[1].W, 1)
nhid(m::VAE) = m.A isa Dense ? 1 : length(m.A)
Flux.params(m::VAE) = Flux.params(m.A, m.μ, m.logσ, m.f)
# Constructor
function VAE(Di::Integer, Dh::Integer, Dz::Integer,
             nhid::Integer=1, act=tanh)
    # inference network
    if nhid == 1
        A = Dense(Di, Dh, act)
    else
        A = Chain(
            Dense(Di, Dh, act),
            (Dense(Dh, Dh, act) for _ in 2:nhid)...
        )
    end
    # variational posterior
    μ = Dense(Dh, Dz)
    logσ = Dense(Dh, Dz)
    # generative network
    f = Chain(
        Dense(Dz, Dh, act),
        (Dense(Dh, Dh, act) for _ in 2:nhid)...,
        Dense(Dh, Di, σ)
    )
    return VAE(A, μ, logσ, f)
end
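# For instance (illustrative sizes only), VAE(32^2, 256, 64) yields
#   A = Dense(1024, 256, tanh), μ and logσ each a Dense(256, 64),
#   f = Chain(Dense(64, 256, tanh), Dense(256, 1024, σ))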
# Inference functions
g(m::VAE, X) = (h = m.A(X); (m.μ(h), m.logσ(h)))
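# Reparameterisation trick: z = μ + σ·ε with ε ~ N(0, 1), so the sample stays differentiable w.r.t. μ and logσ.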
z(μ, logσ) = μ + exp(logσ) * randn(Float32)
function (m::VAE)(X)
    μ, logσ = g(m, X)
    z0 = z.(μ, logσ)
    return z0
end
# Sample from the model
function sample(m::VAE)
    Dz = latdim(m)
    return sample(m, zeros(Float32, Dz), zeros(Float32, Dz))
end
sample(m::VAE, μ, logσ) = rand.(Bernoulli.(m.f(z.(μ, logσ))))
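# Illustrative use once trained (assumes Di = 32^2): reshape(sample(vae), 32, 32) draws one binary image from the prior.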
# KL-divergence between the approximate posterior and the N(0, 1) prior.
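# Closed form for a diagonal Gaussian posterior: KL(N(μ, σ²) ‖ N(0, 1)) = ½ Σ (σ² + μ² − 1 − 2·logσ).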
kl_q_p(μ, logσ) = 0.5f0 * sum(exp.(2f0 .* logσ) .+ μ.^2 .- 1f0 .- 2f0 .* logσ)
# logp(x|z): conditional log-probability of the data given the latents, assuming Bernoulli-distributed pixels.
logp_x_z(m::VAE, x, z) = sum(logpdf.(Bernoulli.(m.f(z)), x))
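# Hand-written Bernoulli log-density so it broadcasts over Float32 pixels; eps(Float32) guards against log(0).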
logpdf(b::Bernoulli, y::Float32) = y * log(b.p + eps(Float32)) + (1f0 - y) * log(1f0 - b.p + eps(Float32))
# Helper functions, useful for defining new losses that use the VAE's latents upstream.
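# ELBO(x) = E_q(z|x)[log p(x|z)] − KL(q(z|x) ‖ p(z)); the methods below return its negative, averaged over the batch, as a loss to minimise.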
elbo(m::VAE, μ̂ , logσ̂ , z, X) = -(logp_x_z(m, X, z) - kl_q_p(μ̂ , logσ̂ )) / size(X, 2)
elbo(m::VAE, μ̂ , logσ̂ , X) = (z0 = z.(μ̂ , logσ̂ ); elbo(m, μ̂ , logσ̂ , z0, X))
# Monte Carlo estimator of mean ELBO using batch samples.
elbo(m::VAE, X) = (h = g(m, X); z0 = z.(h...); elbo(m, h..., z0, X))
############################# Training #################################
Di, Dh, Dz, C = 32^2, 256, 64, 1
vae = VAE(Di, Dh, Dz, 1, tanh)
loss(X) = elbo(vae, X)
ps = Flux.params(vae)
opt = ADAM()
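# zip(Xs) yields 1-tuples, so train! calls loss(X) once per batch in Xs.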
Flux.train!(loss, ps, zip(Xs), opt)
However, training gives me this error in the backward pass, which I can’t understand, and I don’t know where my code differs from the original.
LoadError: type Float64 has no field partials
Stacktrace:
[1] getproperty(::Any, ::Symbol) at ./sysimg.jl:18
[2] partial(::getfield(Base.Broadcast, Symbol("##2#4")){getfield(Base.Broadcast, Symbol("##8#10")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##1#3"))},getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##7#9"))},getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##13#14"))},getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##17#18"))},UnionAll},typeof(logpdf)}, ::Float64, ::Int64, ::Float32, ::Float32) at /home/milton/.julia/packages/Tracker/RRYy6/src/lib/array.jl:480
[3] _broadcast_getindex at ./broadcast.jl:578 [inlined]
[4] getindex at ./broadcast.jl:511 [inlined]
[5] copy at ./broadcast.jl:787 [inlined]
[6] materialize(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2},Nothing,typeof(Tracker.partial),Tuple{Base.RefValue{getfield(Base.Broadcast, Symbol("##2#4")){getfield(Base.Broadcast, Symbol("##8#10")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##1#3"))},getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##7#9"))},getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##13#14"))},getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##17#18"))},UnionAll},typeof(logpdf)}},Array{Float64,2},Int64,Array{Float32,2},Array{Float32,2}}}) at ./broadcast.jl:753
[7] broadcast(::typeof(Tracker.partial), ::Base.RefValue{getfield(Base.Broadcast, Symbol("##2#4")){getfield(Base.Broadcast, Symbol("##8#10")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##1#3"))},getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##7#9"))},getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##13#14"))},getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##17#18"))},UnionAll},typeof(logpdf)}}, ::Array{Float64,2}, ::Int64, ::Vararg{Any,N} where N) at ./broadcast.jl:707
[8] ∇broadcast(::typeof(Tracker.partial), ::Base.RefValue{getfield(Base.Broadcast, Symbol("##2#4")){getfield(Base.Broadcast, Symbol("##8#10")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##1#3"))},getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##7#9"))},getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##13#14"))},getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##17#18"))},UnionAll},typeof(logpdf)}}, ::Array{Float64,2}, ::Int64, ::TrackedArray{…,Array{Float32,2}}, ::Array{Float32,2}) at /home/milton/.julia/packages/Tracker/RRYy6/src/lib/array.jl:484
....
Thanks in advance!