I am writing a training loop for a VAEAC model with Lux and Reactant. When I use `Optimisers.Adam(lr)`, I get the error below; with a different optimiser, for example `Optimisers.Descent(lr)`, the problem does not occur.
Error message:
```
ERROR: MethodError: no method matching Float32(::Reactant.TracedRNumber{Float32})
The type `Float32` exists, but no method is defined for this combination of argument types when trying to construct it.
Closest candidates are:
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:900
  Float32(::IrrationalConstants.Logπ)
   @ IrrationalConstants ~/.julia/packages/IrrationalConstants/lWTip/src/macro.jl:132
  Float32(::UInt8)
   @ Base float.jl:245
```
My training function:
"""
    train_vaeac(; epochs=20, lr=0.001f0, batch_size=100, rng=Random.default_rng())

Train a `VAEAC` model on binary MNIST using Lux on the Reactant device. Return the final
`Lux.Training.TrainState`, which holds the trained parameters and states.

# Keywords
- `epochs`: number of passes over the data loader.
- `lr`: learning rate given to `Optimisers.Adam`.
- `batch_size`: batch size given to `make_loader`.
- `rng`: RNG used for parameter initialisation (`Lux.setup`) and for the per-batch
  latent noise `ε`. The default is the same global RNG the function always used.

Relies on `input_dim`, `latent_dim`, `hidden_dim`, `VAEAC`, `load_binary_mnist_matrix`,
`make_loader`, `generate_mask`, and `loss_fn`, which are defined outside this function.
"""
function train_vaeac(; epochs=20, lr=0.001f0, batch_size=100, rng=Random.default_rng())
    model = VAEAC(input_dim, latent_dim, hidden_dim)
    ps, st = Lux.setup(rng, model)

    Reactant.set_default_backend("gpu")
    dev = reactant_device()
    ps = ps |> dev
    st = st |> dev

    # Data stays on the host; `DeviceIterator` wraps the loader with `dev`
    # (presumably transferring each batch — confirm against MLDataDevices docs).
    data = load_binary_mnist_matrix()
    loader = make_loader(data; batchsize=batch_size, shuffle=true)
    loader_dev = DeviceIterator(dev, loader)

    # NOTE(review): Adam fails here with `Float32(::TracedRNumber{Float32})` while Descent
    # works. Optimisers' Adam rule converts its hyperparameters with `T(o.eta)`
    # (T = parameter eltype), whereas Descent only multiplies by `o.eta`. The error shows
    # a hyperparameter arriving as a traced number, for which that conversion has no
    # method. This looks like a Lux / Reactant / Optimisers version mismatch; updating
    # those packages together is the likely fix — confirm.
    ts = Lux.Training.TrainState(model, ps, st, Optimisers.Adam(lr))

    for epoch in 1:epochs
        tot = 0f0
        nb = 0
        for xb in loader_dev
            mask = Float32.(generate_mask(size(xb))) |> dev
            ε = randn(rng, Float32, latent_dim, size(xb, 2)) |> dev
            # `ts` carries the up-to-date parameters; the `ps`/`st` locals above are
            # only the initial values and are not updated by training.
            _, loss, _, ts = Lux.Training.single_train_step!(
                Lux.AutoEnzyme(), loss_fn, (xb, mask, ε), ts
            )
            tot += loss
            nb += 1
        end
        @info "epoch=$epoch avg_loss=$(tot/nb)"
    end
    return ts
end
The loss function I use, together with its helpers:
"""
    loss_fn(model, ps, st, data)

Lux training objective. `data` is unpacked as `(x, mask, ε)` and passed to the model as
one tuple. The model's output is unpacked as `(logits, μq, logσq, μp, logσp)`.

Return `(loss, new_state, stats)`:
- `loss` is the masked reconstruction term plus the KL term.
- `new_state` is the updated model state.
- `stats` is `(; recon, kl)`.
"""
function loss_fn(model, ps, st, data)
    x, mask, ε = data
    outputs, st_new = Lux.apply(model, (x, mask, ε), ps, st)
    logits, μq, logσq, μp, logσp = outputs

    recon = bce_with_logits_masked(logits, x, mask)
    kl = kl_diag_gaussians(μq, logσq, μp, logσp)

    return recon + kl, st_new, (; recon, kl)
end
"""
    bce_with_logits_masked(logits, x, mask)

Binary cross-entropy computed from logits, `softplus(logits) - x * logits` per element.
Each element is weighted by `1 - mask`, then all elements are summed and the total is
divided by `size(x, 2)`.

Assumes `mask` is 0/1 with 1 marking entries that should not be scored, and that
columns index the batch — TODO confirm.
"""
function bce_with_logits_masked(logits, x, mask)
    weighted = @. (1f0 - mask) * (softplus(logits) - x * logits)
    return sum(weighted) / size(x, 2)
end
"""
    kl_diag_gaussians(μq, logσq, μp, logσp)

KL divergence KL(q ‖ p) between diagonal Gaussians q = N(μq, σq²) and p = N(μp, σp²),
given means and log standard deviations. The per-element terms are summed over every
entry, multiplied by 1/2, and divided by `size(μq, 2)`.

Presumably columns index the batch, so the result is a per-sample average — verify
against callers.
"""
function kl_diag_gaussians(μq, logσq, μp, logσp)
    batch = size(μq, 2)
    varq = @. exp(2f0 * logσq)
    varp = @. exp(2f0 * logσp)
    # log(σp² / σq²) is written as 2 (logσp - logσq) to avoid an extra log.
    terms = @. (varq + (μq - μp)^2) / varp - 1f0 + 2f0 * (logσp - logσq)
    return 0.5f0 * sum(terms) / batch
end