Challenges with Zygote differentiation when a multivariate mixture model is used as input noise for an implicit machine learning method

The following implicit machine learning method works well when the input noise is, for example, a multidimensional Gaussian:

```julia
using Flux, LinearAlgebra, ProgressMeter

# `sample_random_direction` and `scalar_diff` are not shown in this post;
# `generate_aₖ` is defined further down.
function sliced_invariant_statistical_loss_optimized_2(nn_model, loader, hparams)
    @assert loader.batchsize == hparams.samples
    @assert length(loader) == hparams.epochs
    losses = Vector{Float32}()
    optim = Flux.setup(Flux.Adam(hparams.η), nn_model)

    @showprogress for data in loader
        # m random projection directions for this batch
        Ω = map(_ -> sample_random_direction(size(data)[1]), 1:(hparams.m))
        loss, grads = Flux.withgradient(nn_model) do nn
            total = 0.0f0
            for ω in Ω
                aₖ = zeros(Float32, hparams.K + 1)  # Reset aₖ for each new ω

                # Generate all random numbers in one go
                # (note: this `rand` call runs inside the differentiated closure)
                x_batch = rand(hparams.noise_model, hparams.samples * hparams.K)

                # Process batch through nn_model
                yₖ_batch = nn(Float32.(x_batch))

                s = Matrix(ω' * yₖ_batch)

                # Pre-compute column indices for slicing
                start_cols = hparams.K * (1:(hparams.samples - 1))
                end_cols = hparams.K * (2:(hparams.samples)) .- 1

                # Create slices of 's' for all 'aₖ_slice'
                aₖ_slices = [
                    s[:, start_col:(end_col - 1)] for
                    (start_col, end_col) in zip(start_cols, end_cols)
                ]

                # Compute the dot products for all iterations at once
                ω_data_dot_products = [dot(ω, data[:, i]) for i in 2:(hparams.samples)]

                # Apply 'generate_aₖ' for each pair and sum the results
                aₖ = sum([
                    generate_aₖ(aₖ_slice, ω_data_dot_product) for
                    (aₖ_slice, ω_data_dot_product) in zip(aₖ_slices, ω_data_dot_products)
                ])
                total += scalar_diff(aₖ ./ sum(aₖ))
            end
            total / hparams.m
        end
        Flux.update!(optim, nn_model, grads[1])
        push!(losses, loss)
    end
    return losses
end
```

But when I introduce a Mixture Model in the following way:

```julia
# `device` (e.g. cpu/gpu) and `z_dim` come from the rest of the setup

# Component means (length z_dim): all zeros and all ones
mean_vector_1 = device(zeros(z_dim))
mean_vector_2 = device(ones(z_dim))

# Covariance matrices (identity, z_dim × z_dim)
cov_matrix_1 = device(Diagonal(ones(z_dim)))
cov_matrix_2 = device(Diagonal(ones(z_dim)))

# Single multivariate normal distribution (this version works)
noise_model = device(MvNormal(mean_vector_1, cov_matrix_1))

# Equal-weight mixture of the two normals (this version fails)
noise_model = device(
    MixtureModel([MvNormal(mean_vector_1, cov_matrix_1), MvNormal(mean_vector_2, cov_matrix_2)]),
)
```

Then I get the following error:

```
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g., setting values with x .= ...)
```

What could be the reason for this different behavior in one case and the other?

Can you provide a complete working example, including imports and setup code?


The imports are the following:

```julia
using ISL
using Flux
using LinearAlgebra
using Plots
using Distributions
```

And these are the helper functions the method uses:

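```julia
# Smooth step: ≈ 1 where ŷ < y, ≈ 0 where ŷ > y (`sigmoid_fast` is from NNlib)
function _sigmoid(ŷ::Matrix{T}, y::T) where {T<:AbstractFloat}
    return sigmoid_fast.((y .- ŷ) .* 10.0f0)
end

# Soft count of the simulated values in yₖ that lie below yₙ
function ϕ(yₖ::Matrix{T}, yₙ::T) where {T<:AbstractFloat}
    #return sum(_leaky_relu(yₖ, yₙ))
    return sum(_sigmoid(yₖ, yₙ))
end

# Gaussian bump centred at the integer m (soft indicator of y ≈ m)
function ψₘ(y::T, m::Int64) where {T<:AbstractFloat}
    stddev = 0.1f0
    return exp((-0.5f0 * ((y - m) / stddev)^2))
end

# Basis vector eₘ (indexed 0:length(yₖ)) weighted by ψₘ(ϕ(yₖ, yₙ), m)
function γ(yₖ::Matrix{T}, yₙ::T, m::Int64) where {T<:AbstractFloat}
    eₘ(m) = [j == m ? T(1.0) : T(0.0) for j in 0:length(yₖ)]
    return eₘ(m) * ψₘ(ϕ(yₖ, yₙ), m)
end

# Soft one-hot vector of length length(ŷ) + 1 marking the rank of y among ŷ
@inline function generate_aₖ(ŷ::Matrix{T}, y::T) where {T<:AbstractFloat}
    return sum([γ(ŷ, y, k) for k in 0:length(ŷ)])
end
```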
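To make the helper chain concrete, here is a small self-contained sketch that evaluates it on a toy input. It assumes a plain logistic function, `sigmoid_standin`, in place of NNlib's `sigmoid_fast` (which is meant to approximate the same curve), and otherwise mirrors the helpers above. With three simulated values `[0 1 2]` and an observation `1.5`, two of the three lie below the observation, so the result should be close to a one-hot vector at rank 2 (index 3).

```julia
# Stand-in for NNlib.sigmoid_fast (assumption: a plain logistic function is close enough here)
sigmoid_standin(x) = 1 / (1 + exp(-x))

_sigmoid(ŷ::Matrix{T}, y::T) where {T<:AbstractFloat} = sigmoid_standin.((y .- ŷ) .* 10.0f0)
ϕ(yₖ::Matrix{T}, yₙ::T) where {T<:AbstractFloat} = sum(_sigmoid(yₖ, yₙ))
ψₘ(y::T, m::Int64) where {T<:AbstractFloat} = exp(-0.5f0 * ((y - m) / 0.1f0)^2)

function γ(yₖ::Matrix{T}, yₙ::T, m::Int64) where {T<:AbstractFloat}
    eₘ = [j == m ? T(1.0) : T(0.0) for j in 0:length(yₖ)]   # one-hot at position m
    return eₘ * ψₘ(ϕ(yₖ, yₙ), m)
end

generate_aₖ(ŷ::Matrix{T}, y::T) where {T<:AbstractFloat} = sum([γ(ŷ, y, k) for k in 0:length(ŷ)])

ŷ = Float32[0.0 1.0 2.0]   # 1×3 matrix of simulated (projected) values
y = 1.5f0                  # one real (projected) observation
a = generate_aₖ(ŷ, y)      # 4-element Vector{Float32}, ≈ [0, 0, 1, 0]
```

Summing such vectors over many observations gives a soft histogram of ranks, which is what `aₖ` accumulates in the training loop.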

I’m sorry for not providing more details. It’s a bit cumbersome to explain in a few lines what each function does and why it’s there, and I hope it’s not necessary, but if you think it would help, I can comment on it. Thank you very much!

Could you also provide the beginning of whatever loop or function you display in the very first chunk of code?

Yes! There was a typo! It should be fixed now!

Does this mean that you have a working version with a non-mixture model?

Are you trying to differentiate through random number generation?

Yes and yes!

Okay, so it’s typically easier to help if you isolate a minimal working example (MWE) first. This appears to be the crux of your issue:

```
julia> using Zygote, Distributions

julia> f(μ) = rand(MvNormal(μ, I))
f (generic function with 1 method)

julia> g(μ) = rand(MixtureModel([MvNormal(μ, I), MvNormal(μ, I)]))
g (generic function with 1 method)

julia> f([1])
1-element Vector{Float64}:

julia> g([1])
1-element Vector{Float64}:

julia> Zygote.jacobian(f, [1])

julia> Zygote.jacobian(g, [1])
ERROR: Mutating arrays is not supported -- called setindex!(Vector{Float64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
```

First of all, this shows that differentiating through random number generation is not well-defined behavior (note the zero Jacobian), so I think it’s worth exploring why you feel the need to do that. There might be better alternatives, such as the reparametrization trick.
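A minimal sketch of the reparametrization trick, for a Gaussian with a fixed unit scale (an assumption made for brevity): the noise `ε` is drawn once, outside the function being differentiated, and the sample is written as a deterministic function of the parameter, `x = μ + σ ε`. Zygote then differentiates the arithmetic rather than the random number generator, so the Jacobian with respect to `μ` is the identity instead of zero.

```julia
using Zygote, LinearAlgebra, Random

Random.seed!(0)
ε = randn(3)                  # noise drawn once; it does not depend on μ
σ = 1.0                       # fixed scale (assumption for this sketch)

# Reparametrized sample: deterministic in μ once ε is fixed
reparam_sample(μ) = μ .+ σ .* ε

J = Zygote.jacobian(reparam_sample, [0.0, 1.0, 2.0])[1]   # 3×3 matrix
```

For a mixture, the choice of component is discrete, so this trick alone does not give gradients with respect to the mixture weights. In the loss above, though, the noise distribution has no trainable parameters.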

But setting that aside for a moment, it seems that sampling from the mixture model leads to mutating operations, while sampling from a single `MvNormal` doesn’t. I’m honestly not sure why.
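One way to sidestep the error, sketched under the assumption that the noise distribution has no trainable parameters (as with `hparams.noise_model` above): draw the noise outside the differentiated function, or wrap the `rand` call in `ChainRulesCore.ignore_derivatives` so Zygote does not trace into the sampler. Gradients are still taken with respect to the model weights, which is all the training loop needs. The toy linear "model" `W`, the helper `loss_sampling_inside`, and the sample size are illustrative only.

```julia
using Zygote, Distributions, LinearAlgebra, ChainRulesCore, Random

Random.seed!(1)
# Fixed (non-trainable) two-component noise distribution, as in the question
noise = MixtureModel([MvNormal(zeros(2), I), MvNormal(ones(2), I)])
W = randn(3, 2)               # stand-in for the trainable weights

# Option 1: draw the noise before differentiating; Zygote only sees W * x
x = rand(noise, 16)           # 2×16 matrix, one sample per column
gW = Zygote.gradient(W -> sum(abs2, W * x), W)[1]

# Option 2: keep the draw inside the loss, but hide it from AD
function loss_sampling_inside(W)
    x_inner = ChainRulesCore.ignore_derivatives(() -> rand(noise, 16))
    return sum(abs2, W * x_inner)
end
gW2 = Zygote.gradient(loss_sampling_inside, W)[1]
```

In the training loop above, Option 1 would mean generating one `x_batch` per `ω` before calling `Flux.withgradient`, and Option 2 would mean wrapping the existing `rand(hparams.noise_model, ...)` line.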


Perfect, I think the last thing you mentioned is the problem! I suppose it can be bypassed by defining the distribution as a Bernoulli choice between two normal distributions.
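If sampling the mixture has to stay inside the differentiated code, the "Bernoulli of two normals" idea can be sketched as below. `sample_two_gaussian_mixture` is a hypothetical helper, not part of Distributions.jl: it draws candidates from both unit-covariance components and picks one per column with a Bernoulli(`p`) mask, using only non-mutating broadcasts instead of the indexed writes (`setindex!`) that the error message points to. Treat it as a sketch; sampling outside the gradient closure remains the simpler option.

```julia
using Random, Statistics

"""
    sample_two_gaussian_mixture(μ1, μ2, n; p = 0.5, rng = Random.default_rng())

Hypothetical helper: returns a `length(μ1) × n` matrix whose columns are
draws from p⋅N(μ1, I) + (1 - p)⋅N(μ2, I).
"""
function sample_two_gaussian_mixture(μ1::AbstractVector, μ2::AbstractVector, n::Integer;
                                     p::Real = 0.5, rng::AbstractRNG = Random.default_rng())
    d = length(μ1)
    x1 = μ1 .+ randn(rng, d, n)     # candidate draws from component 1
    x2 = μ2 .+ randn(rng, d, n)     # candidate draws from component 2
    mask = rand(rng, 1, n) .< p     # 1×n Bernoulli(p) component labels
    return ifelse.(mask, x1, x2)    # column-wise choice, mask broadcast over rows
end

rng = MersenneTwister(1)
X = sample_two_gaussian_mixture(zeros(2), fill(4.0, 2), 100_000; rng = rng)
mean(X; dims = 2)                   # each entry close to 0.5 * 0 + 0.5 * 4 = 2
```

The cost is drawing twice as many normals as needed, which is usually negligible next to the network's forward pass.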

Either way, it seems a bit strange that the behavior differs between the case with a mixture and the case without one. I don’t know whether it would be worth reporting this to Distributions.jl.

Out of curiosity, can we come back to this part?

The method we propose feeds random noise into the implicit model to see how it behaves, and checks whether the data the model generates from that noise satisfy a certain statistic with respect to the real data.
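For reference, the statistic that the helper functions smooth out appears to be a rank statistic: for each real observation, count how many of `K` model samples fall below it; if the model matches the data distribution, that rank is uniform on `0:K`. The sketch below computes the hard, non-differentiable version with a hypothetical `rank_histogram` function and a standard-normal stand-in for the model. The noise only enters as an input here: training needs gradients with respect to the network weights, not with respect to the random number generator.

```julia
using Random

"""
    rank_histogram(sample_model, data, K)

Hypothetical, non-differentiable illustration: for each observation `y`, draw `K`
model samples and count how many fall below `y`. Returns counts for ranks 0 to K.
"""
function rank_histogram(sample_model, data::AbstractVector{<:Real}, K::Integer)
    counts = zeros(Int, K + 1)       # one bin per possible rank 0, 1, ..., K
    for y in data
        ŷ = sample_model(K)          # K fresh draws from the model
        r = count(<(y), ŷ)           # rank of y among the draws
        counts[r + 1] += 1
    end
    return counts
end

rng = MersenneTwister(42)
data = randn(rng, 110_000)                                     # "real" data: N(0, 1)
matched = rank_histogram(K -> randn(rng, K), data, 10)         # model equals the data law
shifted = rank_histogram(K -> randn(rng, K) .+ 1.0, data, 10)  # model shifted by +1
```

With the matching model every bin holds roughly 10,000 counts, while the shifted model piles up in the low ranks. `ϕ`, `ψₘ` and `γ` replace the hard `count` and bin increment with sigmoid and Gaussian-bump approximations so the histogram becomes differentiable in the model output.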

I wouldn’t know how to do this any other way.