Any solution for Float32 required by Lux models embedded in Turing.jl

Hi, I have trained a Variational Autoencoder with Lux.jl and embedded its decoder inside a Turing.jl model, but I am unable to convert a latent variable to Float32 in order to pass it through the Lux.jl model and get its output. I saw a post here (Issue with Float32 Precision in Turing Model Sampling) saying this is unfortunately not possible.

Just wondering if there are any solutions yet, or a hack for my case?
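To illustrate the underlying issue, here is a minimal sketch (outside Turing, using only ForwardDiff, which is Turing's default AD backend for NUTS) that reproduces the same failure:

using ForwardDiff

# Inside the gradient call, x is a Vector of ForwardDiff.Dual numbers,
# and broadcasting the Float32 constructor over Duals has no method.
ForwardDiff.gradient(x -> sum(Float32.(x)), randn(3))
# ERROR: MethodError: no method matching Float32(::ForwardDiff.Dual{...})

My model is below.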

@model function aggvae_prev_betabin_v0(
    vae,
    vae_ps,
    vae_st_infer,
    x_pop_w::Vector{Float64},
    M_lo::Matrix{Int64},
    M_hi::Matrix{Int64},
    n_tested_lo::Vector{Int64},
    n_tested_hi::Vector{Int64},
    n_positive_lo::Vector{Int64};
    jitter::Float64 = 1e-4,
    z_dim::Int = 389,
)
    # GP realization: constructed through the VAE
    z ~ MvNormal(zeros(z_dim), I)

    f_approx, _ = vae.decoder(reshape(Float32.(z),:,1), vae_ps, vae_st_infer)

    # Prevalence model
    μp ~ Normal(0, 1)             # mean prevalence
    logits = @. μp + f_approx
    p = logistic.(logits)         # prevalence: (n_pts,)
    p_w = p .* x_pop_w            # weighted prevalence

    p_aggr_lo = M_g(M_lo, p_w)    # (9,)
    p_aggr_hi = M_g(M_hi, p_w)    # (49,)
    p_aggr = vcat(p_aggr_lo, p_aggr_hi) #(58,)

    n_tested = vcat(n_tested_lo, n_tested_hi)
    
    n_regions_lo = size(n_tested_lo,1) # 9
    n_regions_hi = size(n_tested_hi, 1) #49
    n_regions = n_regions_lo + n_regions_hi #58
    
    n_positive = vcat(n_positive_lo, fill(missing, n_regions_hi))

    # BetaBinomial Parameters 
    ϕ ~ Gamma(2,1)
    for i = 1:n_regions 
        α = p_aggr[i] * ϕ
        β = (1 - p_aggr[i]) * ϕ
        n_positive[i] ~ BetaBinomial(n_tested[i],α,β)
    end 
end

Run Model

in_dim = 3117 ; z_dim = 389
vae = VAE(Random.Xoshiro(), in_dim, z_dim)
ps = trained_ps.decoder 
st = trained_st.decoder
st_infer = Lux.testmode(st)

model = aggvae_prev_betabin_v0(
    vae, 
    ps,
    st_infer,
    x_pop_w,
    pol_pts_lo,
    pol_pts_hi, 
    n_tested_lo,
    n_tested_hi,
    n_positive_lo,
)

prior_prediction = sample(model, Prior(), 100)
# Prior runs fine
prior_df = DataFrame(prior_prediction)
# Error occurs when computing posterior
posterior_prediction = sample(model, NUTS(), 100)

Error: this is the line causing the problem,
f_approx, _ = vae.decoder(reshape(Float32.(z),:,1), vae_ps, vae_st_infer)

where you aren't allowed to convert to Float32:

ERROR: MethodError: no method matching Float32(::ForwardDiff.Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 12})
The type `Float32` exists, but no method is defined for this combination of argument types when trying to construct it.

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:265
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:965
  Float32(::Int128)
   @ Base float.jl:320
  ...

Stacktrace:
  [1] _broadcast_getindex_evalf
    @ ./broadcast.jl:699 [inlined]
  [2] _broadcast_getindex
    @ ./broadcast.jl:672 [inlined]
  [3] _getindex
    @ ./broadcast.jl:620 [inlined]
  [4] getindex
    @ ./broadcast.jl:616 [inlined]
  [5] copy
    @ ./broadcast.jl:933 [inlined]
  [6] materialize
    @ ./broadcast.jl:894 [inlined]
  [7] aggvae_prev_betabin_v0(__model__::DynamicPPL.Model{…}, __varinfo__::DynamicPPL.VarInfo{…}, vae::VAE{…}, vae_ps::@NamedTuple{…}, vae_st_infer::@NamedTuple{…}, x_pop_w::Vector{…}, M_lo::Matrix{…}, M_hi::Matrix{…}, n_tested_lo::Vector{…}, n_tested_hi::Vector{…}, n_positive_lo::Vector{…}; jitter::Float64, z_dim::Int64)
    @ Main ~/workspace/Lo2Hi-IV-Est-AggVAE/src/jl/gpVAEpreVv3.jl:315
  [8] aggvae_prev_betabin_v0
    @ ~/workspace/Lo2Hi-IV-Est-AggVAE/src/jl/gpVAEpreVv3.jl:299 [inlined]
  [9] _evaluate!!
    @ ~/.julia/packages/DynamicPPL/R1CV0/src/model.jl:921 [inlined]
 [10] evaluate_threadunsafe!!
    @ ~/.julia/packages/DynamicPPL/R1CV0/src/model.jl:887 [inlined]
 [11] evaluate!!
    @ ~/.julia/packages/DynamicPPL/R1CV0/src/model.jl:872 [inlined]
 [12] logdensity_at(x::Vector{…}, model::DynamicPPL.Model{…}, getlogdensity::typeof(DynamicPPL.getlogjoint_internal), varinfo::DynamicPPL.VarInfo{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/R1CV0/src/logdensityfunction.jl:242
 [13] LogDensityAt
    @ ~/.julia/packages/DynamicPPL/R1CV0/src/logdensityfunction.jl:262 [inlined]
 [14] chunk_mode_gradient!(result::DiffResults.MutableDiffResult{…}, f::DynamicPPL.LogDensityAt{…}, x::Vector{…}, cfg::ForwardDiff.GradientConfig{…})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/X74OO/src/gradient.jl:125
 [15] gradient!(result::DiffResults.MutableDiffResult{…}, f::DynamicPPL.LogDensityAt{…}, x::Vector{…}, cfg::ForwardDiff.GradientConfig{…}, ::Val{…})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/X74OO/src/gradient.jl:41
 [16] value_and_gradient(::DynamicPPL.LogDensityAt{…}, ::DifferentiationInterfaceForwardDiffExt.ForwardDiffGradientPrep{…}, ::AutoForwardDiff{…}, ::Vector{…})
    @ DifferentiationInterfaceForwardDiffExt ~/.julia/packages/DifferentiationInterface/L0TGS/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl:419
 [17] logdensity_and_gradient(f::LogDensityFunction{…}, x::Vector{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/R1CV0/src/logdensityfunction.jl:289
 [18] Fix
    @ ./operators.jl:1193 [inlined]
 [19] ∂H∂θ(h::AdvancedHMC.Hamiltonian{…}, θ::Vector{…})
    @ AdvancedHMC ~/.julia/packages/AdvancedHMC/B1wPY/src/hamiltonian.jl:46
 [20] phasepoint(h::AdvancedHMC.Hamiltonian{…}, θ::Vector{…}, r::Vector{…})
    @ AdvancedHMC ~/.julia/packages/AdvancedHMC/B1wPY/src/hamiltonian.jl:103
 [21] phasepoint
    @ ~/.julia/packages/AdvancedHMC/B1wPY/src/hamiltonian.jl:185 [inlined]
 [22] find_initial_params(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, varinfo::DynamicPPL.VarInfo{…}, hamiltonian::AdvancedHMC.Hamiltonian{…}; max_attempts::Int64)
    @ Turing.Inference ~/.julia/packages/Turing/gtb9I/src/mcmc/hmc.jl:157
 [23] find_initial_params(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, varinfo::DynamicPPL.VarInfo{…}, hamiltonian::AdvancedHMC.Hamiltonian{…})
    @ Turing.Inference ~/.julia/packages/Turing/gtb9I/src/mcmc/hmc.jl:146
 [24] initialstep(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, spl::DynamicPPL.Sampler{…}, vi_original::DynamicPPL.VarInfo{…}; initial_params::Nothing, nadapts::Int64, verbose::Bool, kwargs::@Kwargs{})
    @ Turing.Inference ~/.julia/packages/Turing/gtb9I/src/mcmc/hmc.jl:207
 [25] step(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, spl::DynamicPPL.Sampler{…}; initial_params::Nothing, kwargs::@Kwargs{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/R1CV0/src/sampler.jl:133
 [26] step
    @ ~/.julia/packages/DynamicPPL/R1CV0/src/sampler.jl:116 [inlined]
 [27] macro expansion
    @ ~/.julia/packages/AbstractMCMC/z4BsN/src/sample.jl:188 [inlined]
 [28] macro expansion
    @ ~/.julia/packages/AbstractMCMC/z4BsN/src/logging.jl:137 [inlined]
 [29] mcmcsample(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, N::Int64; progress::Bool, progressname::String, callback::Nothing, num_warmup::Int64, discard_initial::Int64, thinning::Int64, chain_type::Type, initial_state::Nothing, kwargs::@Kwargs{…})
    @ AbstractMCMC ~/.julia/packages/AbstractMCMC/z4BsN/src/sample.jl:168
 [30] sample(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, N::Int64; chain_type::Type, resume_from::Nothing, initial_state::Nothing, progress::Bool, nadapts::Int64, discard_adapt::Bool, discard_initial::Int64, kwargs::@Kwargs{})
    @ Turing.Inference ~/.julia/packages/Turing/gtb9I/src/mcmc/hmc.jl:117
 [31] sample
    @ ~/.julia/packages/Turing/gtb9I/src/mcmc/hmc.jl:86 [inlined]
 [32] #sample#112
    @ ~/.julia/packages/Turing/gtb9I/src/mcmc/abstractmcmc.jl:31 [inlined]
 [33] sample
    @ ~/.julia/packages/Turing/gtb9I/src/mcmc/abstractmcmc.jl:22 [inlined]
 [34] #sample#111
    @ ~/.julia/packages/Turing/gtb9I/src/mcmc/abstractmcmc.jl:19 [inlined]
 [35] sample(model::DynamicPPL.Model{…}, alg::NUTS{…}, N::Int64)
    @ Turing.Inference ~/.julia/packages/Turing/gtb9I/src/mcmc/abstractmcmc.jl:16
 [36] top-level scope
    @ ~/workspace/Lo2Hi-IV-Est-AggVAE/src/jl/gpVAEpreVv3.jl:366
Some type information was truncated. Use `show(err)` to see complete types.

Hello! Without going too much into detail about why converting to Float32 doesn’t work, have you tried using an alternative AD backend such as Mooncake or Enzyme? Please note, though, that both of these currently require Julia 1.11, as they haven’t yet been fully updated to work with Julia 1.12.

Example code for Mooncake:

using Turing, Mooncake

# define model exactly the same way...

posterior_prediction = sample(model, NUTS(; adtype=AutoMooncake()), 100)
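
If you want to try Enzyme instead, the call should look roughly like this (untested sketch — I'm assuming the AutoEnzyme constructor from ADTypes is available the same way AutoMooncake is; check the Turing docs on AD backends for the recommended Enzyme mode options):

using Turing, Enzyme

# define model exactly the same way...

posterior_prediction = sample(model, NUTS(; adtype=AutoEnzyme()), 100)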

Thanks! In order to do this, should I also train the VAE using Mooncake (since I am using a trained decoder)? For training the VAE I used Zygote AD.

I don’t think it should make a difference; the AD used inside Turing is separate from the AD you used to train the autoencoder. (That said, I don’t have experience with Lux.jl, so I could be wrong.)
