How to use Lux.apply for a composite model

I am following the CVAE tutorial in Lux (Convolutional VAE for MNIST using Reactant | Lux.jl Docs) to construct a VAE .

Everything was all good till I came to the loss function which looks like this

function loss_function(model, ps, st, X)
    (y, μ, logσ²), st = model(X, ps, st)
    reconstruction_loss = MSELoss(; agg=sum)(y, X)
    kldiv_loss = -sum(1 .+ logσ² .- μ .^ 2 .- exp.(logσ²)) / 2
    loss = reconstruction_loss + kldiv_loss
    return loss, st, (; y, μ, logσ², reconstruction_loss, kldiv_loss)
end

When I attempt to test the function like this

X = gp_aggr #(58,*)
vae = VAE(Random.default_rng(), 50, 40, 58, 58)
ps, st = Lux.setup(Random.default_rng(), vae)
loss_function(vae, ps, st, X)

I get the following error

{
	"name": "MethodError",
	"message": "MethodError: objects of type VAE{CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var\"#15#17\", Nothing, @NamedTuple{fc1::Dense{typeof(elu), Int64, Int64, Nothing, Nothing, Static.True}, fc_mu::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, fc_logvar::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{rng::Returns{TaskLocalRNG}}}, Tuple{Tuple{}, Tuple{}}}, CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var\"#19#21\", Nothing, @NamedTuple{fc1::Dense{typeof(elu), Int64, Int64, Nothing, Nothing, Static.True}, fc2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{}}, Tuple{Tuple{}, Tuple{}}}} are not callable
The object of type `VAE{CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var\"#15#17\", Nothing, @NamedTuple{fc1::Dense{typeof(elu), Int64, Int64, Nothing, Nothing, Static.True}, fc_mu::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, fc_logvar::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{rng::Returns{TaskLocalRNG}}}, Tuple{Tuple{}, Tuple{}}}, CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var\"#19#21\", Nothing, @NamedTuple{fc1::Dense{typeof(elu), Int64, Int64, Nothing, Nothing, Static.True}, fc2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{}}, Tuple{Tuple{}, Tuple{}}}}` exists, but no method is defined for this combination of argument types when trying to treat it as a callable object.",
	"stack": "MethodError: objects of type VAE{CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var\"#15#17\", Nothing, @NamedTuple{fc1::Dense{typeof(elu), Int64, Int64, Nothing, Nothing, Static.True}, fc_mu::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, fc_logvar::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{rng::Returns{TaskLocalRNG}}}, Tuple{Tuple{}, Tuple{}}}, CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var\"#19#21\", Nothing, @NamedTuple{fc1::Dense{typeof(elu), Int64, Int64, Nothing, Nothing, Static.True}, fc2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{}}, Tuple{Tuple{}, Tuple{}}}} are not callable
The object of type `VAE{CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var\"#15#17\", Nothing, @NamedTuple{fc1::Dense{typeof(elu), Int64, Int64, Nothing, Nothing, Static.True}, fc_mu::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, fc_logvar::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{rng::Returns{TaskLocalRNG}}}, Tuple{Tuple{}, Tuple{}}}, CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var\"#19#21\", Nothing, @NamedTuple{fc1::Dense{typeof(elu), Int64, Int64, Nothing, Nothing, Static.True}, fc2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{}}, Tuple{Tuple{}, Tuple{}}}}` exists, but no method is defined for this combination of argument types when trying to treat it as a callable object.

Stacktrace:
 [1] loss_function(model::VAE{CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var\"#15#17\", Nothing, @NamedTuple{fc1::Dense{typeof(elu), Int64, Int64, Nothing, Nothing, Static.True}, fc_mu::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, fc_logvar::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{rng::Returns{TaskLocalRNG}}}, Tuple{Tuple{}, Tuple{}}}, CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var\"#19#21\", Nothing, @NamedTuple{fc1::Dense{typeof(elu), Int64, Int64, Nothing, Nothing, Static.True}, fc2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{}}, Tuple{Tuple{}, Tuple{}}}}, ps::@NamedTuple{encoder::@NamedTuple{fc1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, fc_mu::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, fc_logvar::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, decoder::@NamedTuple{fc1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, fc2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}}, st::@NamedTuple{encoder::@NamedTuple{fc1::@NamedTuple{}, fc_mu::@NamedTuple{}, fc_logvar::@NamedTuple{}, rng::TaskLocalRNG}, decoder::@NamedTuple{fc1::@NamedTuple{}, fc2::@NamedTuple{}}}, X::Matrix{Float64})
   @ Main ~/workspace/Lo2Hi-IV-Est-AggVAE/julia/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X20sZmlsZQ==.jl:2
 [2] top-level scope
   @ ~/workspace/Lo2Hi-IV-Est-AggVAE/julia/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X20sZmlsZQ==.jl:12"
}

I am gonna guess thats this is because you cant do model(X, ps, st) - Not so sure but might be the old syntax. This needs to be likely replaced by Lux.apply(vae, ps,st,X), I also tried Lux.apply(vae,X, ps,st). But either way you get the following error

{
	"name": "MethodError",
	"message": "MethodError: objects of type VAE{CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var\"#15#17\", Nothing, @NamedTuple{fc1::Dense{typeof(elu), Int64, Int64, Nothing, Nothing, Static.True}, fc_mu::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, fc_logvar::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{rng::Returns{TaskLocalRNG}}}, Tuple{Tuple{}, Tuple{}}}, CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var\"#19#21\", Nothing, @NamedTuple{fc1::Dense{typeof(elu), Int64, Int64, Nothing, Nothing, Static.True}, fc2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{}}, Tuple{Tuple{}, Tuple{}}}} are not callable
The object of type `VAE{CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var\"#15#17\", Nothing, @NamedTuple{fc1::Dense{typeof(elu), Int64, Int64, Nothing, Nothing, Static.True}, fc_mu::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, fc_logvar::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{rng::Returns{TaskLocalRNG}}}, Tuple{Tuple{}, Tuple{}}}, CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var\"#19#21\", Nothing, @NamedTuple{fc1::Dense{typeof(elu), Int64, Int64, Nothing, Nothing, Static.True}, fc2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{}}, Tuple{Tuple{}, Tuple{}}}}` exists, but no method is defined for this combination of argument types when trying to treat it as a callable object.",
	"stack": "MethodError: objects of type VAE{CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var\"#15#17\", Nothing, @NamedTuple{fc1::Dense{typeof(elu), Int64, Int64, Nothing, Nothing, Static.True}, fc_mu::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, fc_logvar::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{rng::Returns{TaskLocalRNG}}}, Tuple{Tuple{}, Tuple{}}}, CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var\"#19#21\", Nothing, @NamedTuple{fc1::Dense{typeof(elu), Int64, Int64, Nothing, Nothing, Static.True}, fc2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{}}, Tuple{Tuple{}, Tuple{}}}} are not callable
The object of type `VAE{CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var\"#15#17\", Nothing, @NamedTuple{fc1::Dense{typeof(elu), Int64, Int64, Nothing, Nothing, Static.True}, fc_mu::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, fc_logvar::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{rng::Returns{TaskLocalRNG}}}, Tuple{Tuple{}, Tuple{}}}, CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var\"#19#21\", Nothing, @NamedTuple{fc1::Dense{typeof(elu), Int64, Int64, Nothing, Nothing, Static.True}, fc2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{}}, Tuple{Tuple{}, Tuple{}}}}` exists, but no method is defined for this combination of argument types when trying to treat it as a callable object.

Stacktrace:
 [1] apply(model::VAE{CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var\"#15#17\", Nothing, @NamedTuple{fc1::Dense{typeof(elu), Int64, Int64, Nothing, Nothing, Static.True}, fc_mu::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, fc_logvar::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{rng::Returns{TaskLocalRNG}}}, Tuple{Tuple{}, Tuple{}}}, CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var\"#19#21\", Nothing, @NamedTuple{fc1::Dense{typeof(elu), Int64, Int64, Nothing, Nothing, Static.True}, fc2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{}}, Tuple{Tuple{}, Tuple{}}}}, x::Matrix{Float64}, ps::@NamedTuple{encoder::@NamedTuple{fc1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, fc_mu::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, fc_logvar::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, decoder::@NamedTuple{fc1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, fc2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}}, st::@NamedTuple{encoder::@NamedTuple{fc1::@NamedTuple{}, fc_mu::@NamedTuple{}, fc_logvar::@NamedTuple{}, rng::TaskLocalRNG}, decoder::@NamedTuple{fc1::@NamedTuple{}, fc2::@NamedTuple{}}})
   @ LuxCore ~/.julia/packages/LuxCore/XUV80/src/LuxCore.jl:155
 [2] loss_function(model::VAE{CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var\"#15#17\", Nothing, @NamedTuple{fc1::Dense{typeof(elu), Int64, Int64, Nothing, Nothing, Static.True}, fc_mu::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, fc_logvar::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{rng::Returns{TaskLocalRNG}}}, Tuple{Tuple{}, Tuple{}}}, CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, var\"#19#21\", Nothing, @NamedTuple{fc1::Dense{typeof(elu), Int64, Int64, Nothing, Nothing, Static.True}, fc2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{}}, Tuple{Tuple{}, Tuple{}}}}, ps::@NamedTuple{encoder::@NamedTuple{fc1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, fc_mu::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, fc_logvar::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, decoder::@NamedTuple{fc1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, fc2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}}, st::@NamedTuple{encoder::@NamedTuple{fc1::@NamedTuple{}, fc_mu::@NamedTuple{}, fc_logvar::@NamedTuple{}, rng::TaskLocalRNG}, decoder::@NamedTuple{fc1::@NamedTuple{}, fc2::@NamedTuple{}}}, X::Matrix{Float64})
   @ Main ~/workspace/Lo2Hi-IV-Est-AggVAE/julia/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X20sZmlsZQ==.jl:2
 [3] top-level scope
   @ ~/workspace/Lo2Hi-IV-Est-AggVAE/julia/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X20sZmlsZQ==.jl:12"
}

I think this error is cause the VAE is composite of the encoder and the decoder and not just one model - I could be wrong here. Anyways any idea what I can do to fix this?

Here are the rest of the functions just for completeness

function vae_encoder(rng, inp_dims::Int,h_dims::Int, z_dims::Int)
    return @compact(;
        fc1 = Dense(inp_dims,h_dims, elu), #(inp_dims, *) -> (h_dims, *)
        fc_mu = Dense(h_dims, z_dims), #(h_dims, *) -> (z_dims, *)
        fc_logvar = Dense(h_dims, z_dims), #(h_dims, *) -> (z_dims, *)
        rng
    ) do x 
        h = fc1(x) #(h_dims, *)
        μ = fc_mu(h) #(z_dims, *)
        logσ² = fc_logvar(h) #(z_dims, *) 
        # Clamp log variance for numerical stability 
        T = eltype(logσ²)
        logσ² = clamp.(logσ², -T(20.0f0), T(10.0f0)) #(z_dims, *)
        σ = exp.(logσ² .* T(0.5)) #(z_dims, *)
        # Generate a tensor of random values from a normal distribution 
        ϵ = randn_like(Lux.replicate(rng), σ)
        # Reparameterization trick
        z = μ .+ σ .* ϵ
        return z, μ, logσ² 
    end
end
function vae_decoder(z_dim, h_dims, out_dims)
    return @compact(;
        fc1 = Dense(z_dim, h_dims, elu), #(z_dim, *) -> (h_dims, *)
        fc2 = Dense(h_dims, out_dims), #(h_dims, *) -> (out_dims, *)
    ) do Z 
        h = fc1(Z) #(h_dims, *)
        gp_recon = fc2(h)
    end
end
@concrete struct VAE <: AbstractLuxContainerLayer{(:encoder, :decoder)}
    encoder <: AbstractLuxLayer
    decoder <: AbstractLuxLayer 
end 

function VAE(rng, h_dims, z_dims, inp_dims, out_dims)
    decoder = vae_decoder(z_dims, h_dims, out_dims) 
    encoder = vae_encoder(rng, inp_dims, h_dims, z_dims)
    return VAE(encoder, decoder) 
end 

function encode(vae::VAE, x, ps, st)
    (z, _,_), st_enc = vae.encoder(x, ps.encoder, st.encoder)
    return z, (;encoder = st_enc, st.decoder) #you need to return decoder to update sttes
end

function decode(vae::VAE, z, ps, st)
    gp_rec, st_dec = vae.decoder(z, ps.decoder, st.decoder)
    return gp_rec, (;decoder = st_dec, st.encoder) #you need to return encoder to update its state
end

You are missing the (::VAE)(x, ps, st) dispatch. See the Lux Interface | Lux.jl Docs docs

1 Like