Hi I have trained Variational Autoencoder from Lux.jl embeded inside a Turing.jl
model. But I am unable to convert a latent variable to float32
to pass it through Lux.jl model to get its ouput. I saw a post here (Issue with Float32 Precision in Turing Model Sampling) saying this is unfortunately not possible.
Just wondering if there are any solutions yet or a hack for my case ?
@model function aggvae_prev_betabin_v0(
vae,
vae_ps,
vae_st_infer,
x_pop_w::Vector{Float64},
M_lo::Matrix{Int64},
M_hi::Matrix{Int64},
n_tested_lo::Vector{Int64},
n_tested_hi::Vector{Int64},
n_positive_lo::Vector{Int64};
jitter::Float64 = 1e-4,
z_dim::Int = 389,
)
# GP realization : Construct through VAE
z ~ MvNormal(zeros(z_dim), I)
f_approx, _ = vae.decoder(reshape(Float32.(z),:,1), vae_ps, vae_st_infer)
# Prevalence model
μp ~ Normal(0, 1) # mean prevalence
logits = @. μp + f_approx
p = logistic.(logits) # prevelence : (n_pts,)
p_w = p .* x_pop_w # Weighted prevelance
p_aggr_lo = M_g(M_lo, p_w) # (9,)
p_aggr_hi = M_g(M_hi, p_w) # (49,)
p_aggr = vcat(p_aggr_lo, p_aggr_hi) #(58,)
n_tested = vcat(n_tested_lo, n_tested_hi)
n_regions_lo = size(n_tested_lo,1) # 9
n_regions_hi = size(n_tested_hi, 1) #49
n_regions = n_regions_lo + n_regions_hi #58
n_positive = vcat(n_positive_lo, fill(missing, n_regions_hi))
# BetaBinomial Parameters
ϕ ~ Gamma(2,1)
for i = 1:n_regions
α = p_aggr[i] * ϕ
β = (1 - p_aggr[i]) * ϕ
n_positive[i] ~ BetaBinomial(n_tested[i],α,β)
end
end
Run Model
in_dim = 3117 ; z_dim = 389
vae = VAE(Random.Xoshiro(), in_dim, z_dim)
ps = trained_ps.decoder
st = trained_st.decoder
st_infer = Lux.testmode(st)
model = aggvae_prev_betabin_v0(
vae,
ps,
st_infer,
x_pop_w,
pol_pts_lo,
pol_pts_hi,
n_tested_lo,
n_tested_hi,
n_positive_lo,
)
prior_prediction = sample(model, Prior(), 100)
# Prior runs fine
prior_df = DataFrame(prior_prediction)
# Error occurs when computing posterior
posterior_prediction = sample(model, NUTS(), 100)
Error : This is the line causing the problems
f_approx, _ = vae.decoder(reshape(Float32.(z),:,1), vae_ps, vae_st_infer)
where you arnt allowed to convert to Float32
ERROR: MethodError: no method matching Float32(::ForwardDiff.Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 12})
The type `Float32` exists, but no method is defined for this combination of argument types when trying to construct it.
Closest candidates are:
(::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
@ Base rounding.jl:265
(::Type{T})(::T) where T<:Number
@ Core boot.jl:965
Float32(::Int128)
@ Base float.jl:320
...
Stacktrace:
[1] _broadcast_getindex_evalf
@ ./broadcast.jl:699 [inlined]
[2] _broadcast_getindex
@ ./broadcast.jl:672 [inlined]
[3] _getindex
@ ./broadcast.jl:620 [inlined]
[4] getindex
@ ./broadcast.jl:616 [inlined]
[5] copy
@ ./broadcast.jl:933 [inlined]
[6] materialize
@ ./broadcast.jl:894 [inlined]
[7] aggvae_prev_betabin_v0(__model__::DynamicPPL.Model{…}, __varinfo__::DynamicPPL.VarInfo{…}, vae::VAE{…}, vae_ps::@NamedTuple{…}, vae_st_infer::@NamedTuple{…}, x_pop_w::Vector{…}, M_lo::Matrix{…}, M_hi::Matrix{…}, n_tested_lo::Vector{…}, n_tested_hi::Vector{…}, n_positive_lo::Vector{…}; jitter::Float64, z_dim::Int64)
@ Main ~/workspace/Lo2Hi-IV-Est-AggVAE/src/jl/gpVAEpreVv3.jl:315
[8] aggvae_prev_betabin_v0
@ ~/workspace/Lo2Hi-IV-Est-AggVAE/src/jl/gpVAEpreVv3.jl:299 [inlined]
[9] _evaluate!!
@ ~/.julia/packages/DynamicPPL/R1CV0/src/model.jl:921 [inlined]
[10] evaluate_threadunsafe!!
@ ~/.julia/packages/DynamicPPL/R1CV0/src/model.jl:887 [inlined]
[11] evaluate!!
@ ~/.julia/packages/DynamicPPL/R1CV0/src/model.jl:872 [inlined]
[12] logdensity_at(x::Vector{…}, model::DynamicPPL.Model{…}, getlogdensity::typeof(DynamicPPL.getlogjoint_internal), varinfo::DynamicPPL.VarInfo{…})
@ DynamicPPL ~/.julia/packages/DynamicPPL/R1CV0/src/logdensityfunction.jl:242
[13] LogDensityAt
@ ~/.julia/packages/DynamicPPL/R1CV0/src/logdensityfunction.jl:262 [inlined]
[14] chunk_mode_gradient!(result::DiffResults.MutableDiffResult{…}, f::DynamicPPL.LogDensityAt{…}, x::Vector{…}, cfg::ForwardDiff.GradientConfig{…})
@ ForwardDiff ~/.julia/packages/ForwardDiff/X74OO/src/gradient.jl:125
[15] gradient!(result::DiffResults.MutableDiffResult{…}, f::DynamicPPL.LogDensityAt{…}, x::Vector{…}, cfg::ForwardDiff.GradientConfig{…}, ::Val{…})
@ ForwardDiff ~/.julia/packages/ForwardDiff/X74OO/src/gradient.jl:41
[16] value_and_gradient(::DynamicPPL.LogDensityAt{…}, ::DifferentiationInterfaceForwardDiffExt.ForwardDiffGradientPrep{…}, ::AutoForwardDiff{…}, ::Vector{…})
@ DifferentiationInterfaceForwardDiffExt ~/.julia/packages/DifferentiationInterface/L0TGS/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl:419
[17] logdensity_and_gradient(f::LogDensityFunction{…}, x::Vector{…})
@ DynamicPPL ~/.julia/packages/DynamicPPL/R1CV0/src/logdensityfunction.jl:289
[18] Fix
@ ./operators.jl:1193 [inlined]
[19] ∂H∂θ(h::AdvancedHMC.Hamiltonian{…}, θ::Vector{…})
@ AdvancedHMC ~/.julia/packages/AdvancedHMC/B1wPY/src/hamiltonian.jl:46
[20] phasepoint(h::AdvancedHMC.Hamiltonian{…}, θ::Vector{…}, r::Vector{…})
@ AdvancedHMC ~/.julia/packages/AdvancedHMC/B1wPY/src/hamiltonian.jl:103
[21] phasepoint
@ ~/.julia/packages/AdvancedHMC/B1wPY/src/hamiltonian.jl:185 [inlined]
[22] find_initial_params(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, varinfo::DynamicPPL.VarInfo{…}, hamiltonian::AdvancedHMC.Hamiltonian{…}; max_attempts::Int64)
@ Turing.Inference ~/.julia/packages/Turing/gtb9I/src/mcmc/hmc.jl:157
[23] find_initial_params(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, varinfo::DynamicPPL.VarInfo{…}, hamiltonian::AdvancedHMC.Hamiltonian{…})
@ Turing.Inference ~/.julia/packages/Turing/gtb9I/src/mcmc/hmc.jl:146
[24] initialstep(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, spl::DynamicPPL.Sampler{…}, vi_original::DynamicPPL.VarInfo{…}; initial_params::Nothing, nadapts::Int64, verbose::Bool, kwargs::@Kwargs{})
@ Turing.Inference ~/.julia/packages/Turing/gtb9I/src/mcmc/hmc.jl:207
[25] step(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, spl::DynamicPPL.Sampler{…}; initial_params::Nothing, kwargs::@Kwargs{…})
@ DynamicPPL ~/.julia/packages/DynamicPPL/R1CV0/src/sampler.jl:133
[26] step
@ ~/.julia/packages/DynamicPPL/R1CV0/src/sampler.jl:116 [inlined]
[27] macro expansion
@ ~/.julia/packages/AbstractMCMC/z4BsN/src/sample.jl:188 [inlined]
[28] macro expansion
@ ~/.julia/packages/AbstractMCMC/z4BsN/src/logging.jl:137 [inlined]
[29] mcmcsample(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, N::Int64; progress::Bool, progressname::String, callback::Nothing, num_warmup::Int64, discard_initial::Int64, thinning::Int64, chain_type::Type, initial_state::Nothing, kwargs::@Kwargs{…})
@ AbstractMCMC ~/.julia/packages/AbstractMCMC/z4BsN/src/sample.jl:168
[30] sample(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, N::Int64; chain_type::Type, resume_from::Nothing, initial_state::Nothing, progress::Bool, nadapts::Int64, discard_adapt::Bool, discard_initial::Int64, kwargs::@Kwargs{})
@ Turing.Inference ~/.julia/packages/Turing/gtb9I/src/mcmc/hmc.jl:117
[31] sample
@ ~/.julia/packages/Turing/gtb9I/src/mcmc/hmc.jl:86 [inlined]
[32] #sample#112
@ ~/.julia/packages/Turing/gtb9I/src/mcmc/abstractmcmc.jl:31 [inlined]
[33] sample
@ ~/.julia/packages/Turing/gtb9I/src/mcmc/abstractmcmc.jl:22 [inlined]
[34] #sample#111
@ ~/.julia/packages/Turing/gtb9I/src/mcmc/abstractmcmc.jl:19 [inlined]
[35] sample(model::DynamicPPL.Model{…}, alg::NUTS{…}, N::Int64)
@ Turing.Inference ~/.julia/packages/Turing/gtb9I/src/mcmc/abstractmcmc.jl:16
[36] top-level scope
@ ~/workspace/Lo2Hi-IV-Est-AggVAE/src/jl/gpVAEpreVv3.jl:366
Some type information was truncated. Use `show(err)` to see complete types.