The following implicit machine learning method works well when the input noise is, for example, a multidimensional Gaussian
function sliced_invariant_statistical_loss_optimized_2(nn_model, loader, hparams)
@assert loader.batchsize == hparams.samples
@assert length(loader) == hparams.epochs
losses = Vector{Float32}()
optim = Flux.setup(Flux.Adam(hparams.η), nn_model)
@showprogress for data in loader
Ω = ThreadsX.map(_ -> sample_random_direction(size(data)[1]), 1:(hparams.m))
loss, grads = Flux.withgradient(nn_model) do nn
total = 0.0f0
for ω in Ω
aₖ = zeros(Float32, hparams.K + 1) # Reset aₖ for each new ω
# Generate all random numbers in one go
x_batch = rand(hparams.noise_model, hparams.samples * hparams.K)
# Process batch through nn_model
yₖ_batch = nn(Float32.(x_batch))
s = Matrix(ω' * yₖ_batch)
# Pre-compute column indices for slicing
start_cols = hparams.K * (1:(hparams.samples - 1))
end_cols = hparams.K * (2:(hparams.samples)) .- 1
# Create slices of 's' for all 'aₖ_slice'
aₖ_slices = [
s[:, start_col:(end_col - 1)] for
(start_col, end_col) in zip(start_cols, end_cols)
]
# Compute the dot products for all iterations at once
ω_data_dot_products = [dot(ω, data[:, i]) for i in 2:(hparams.samples)]
# Apply 'generate_aₖ' for each pair and sum the results
aₖ = sum([
generate_aₖ(aₖ_slice, ω_data_dot_product) for
(aₖ_slice, ω_data_dot_product) in zip(aₖ_slices, ω_data_dot_products)
])
total += scalar_diff(aₖ ./ sum(aₖ))
end
total / hparams.m
end
Flux.update!(optim, nn_model, grads[1])
push!(losses, loss)
end
return losses
end
But when I introduce a Mixture Model in the following way:
# Mean vector (zero vector of length dim)
mean_vector_1 = device(zeros(z_dim))
mean_vector_2 = device(ones(z_dim))
# Covariance matrix (identity matrix of size dim x dim)
cov_matrix_1 = device(Diagonal(ones(z_dim)))
cov_matrix_2 = device(Diagonal(ones(z_dim)))
# Create the multivariate normal distribution
noise_model = device(MvNormal(mean_vector_1, cov_matrix_1))
noise_model = device(
MixtureModel([
MvNormal(mean_vector_1, cov_matrix_1), MvNormal(mean_vector_2, cov_matrix_2)
]),
)
Then I get the following error:
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g., setting values with x .= ...)```
What could be the reason for this different behavior in one case and the other?