# Understanding and Overcoming Zygote's Functional Limitations for Distributed Computation

Hello everyone, I am trying to parallelize a cost function, using Distributed.jl for parallelism and Zygote.jl for automatic differentiation. The example is as follows:

```julia
function sliced_invariant_statistical_loss_distributed(nn_model, loader, hparams::HyperParamsSlicedISL)
    losses = []
    for data in loader
        loss, grads = withgradient(nn_model) do nn
            Ω = [sample_random_direction(size(data)[1]) for _ in 1:(hparams.m)]

            # Distributed computation for the outer loop
            total = @distributed (+) for ω in Ω
                compute_loss_for_ω(nn, ω, data, hparams)
            end

            total / hparams.m
        end
        push!(losses, loss)
    end
    return losses
end;

# Helper function to compute loss for a single ω
function compute_loss_for_ω(nn, ω, data, hparams)
    aₖ = zeros(hparams.K + 1)

    for i in 1:hparams.samples
        x = rand(hparams.noise_model, hparams.K)
        yₖ = nn(x)
        s = collect(reshape(ω' * yₖ, 1, hparams.K))
        aₖ += generate_aₖ(s, ω ⋅ data[:, i])
    end

    scalar_diff(aₖ ./ sum(aₖ))
end
```

The error that I am encountering is the following:

```text
nested task error: Compiling Tuple{typeof(Distributed.run_work_thunk),
Distributed.var"#153#154"{ISL.var"#48#51"{Chain{Tuple{Dense{typeof(identity),
Matrix{Float32}, Vector{Float32}}, typeof(elu), Dense{typeof(identity), Matrix{Float32},
Vector{Float32}}, typeof(elu), Dense{typeof(identity), Matrix{Float32},
Vector{Float32}}, typeof(elu), Dense{typeof(identity), Matrix{Float32},
Vector{Float32}}}}, HyperParamsSlicedISL, Matrix{Float32}}, Tuple{typeof(+),
Vector{Vector{Float32}}, Int64, Int64}, Base.Pairs{Symbol, Union{}, Tuple{},
NamedTuple{(), Tuple{}}}}, Bool}: try/catch is not supported. Refer to the Zygote
documentation for fixes. https://fluxml.ai/Zygote.jl/latest/limitations
```

It seems to be a limitation of Zygote. Can someone explain what the actual limitation is in this case, and whether there is any way to circumvent it?

I believe the only part of the Distributed API Zygote supports is `pmap`, and even then only in a limited capacity. If you can’t make do with that, then consider writing custom rules for this or figuring out a way to move the distributed part outside of the (with)gradient call.
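The limitation itself is visible in the stack trace: `@distributed` lowers to Distributed's remote-call machinery (`run_work_thunk` above), whose implementation contains `try`/`catch` blocks for remote error handling, and Zygote's reverse-mode compiler cannot differentiate through `try`/`catch`. A minimal sketch of the same failure (illustrative names, on Zygote versions affected by this limitation):

```julia
using Zygote

# Any try/catch on the differentiated code path triggers the same error,
# even when the math itself is trivially differentiable.
f(x) = try
    x^2
catch
    zero(x)
end

Zygote.gradient(f, 3.0)  # errors: "try/catch is not supported"
```

So the problem is not the parallelism per se, but that the gradient call tries to compile through library internals that use constructs Zygote does not support.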

I’ve tried what you mentioned, and it doesn’t seem to work. It doesn’t give errors, but the autodifferentiator returns `nothing` as the gradient. I have also tried `Threads.@spawn`, since that is what the Zygote project uses in its own tests, but with the same result as before. Do you know of another automatic differentiator, even an experimental one, that is compatible with multithreaded or multiprocess parallelization?

I would need to know what exactly you tried to help here, because I mentioned more than one thing!

I have tried `pmap`:

```julia
function sliced_invariant_statistical_loss_distributed_pmap(nn_model, loader, hparams::HyperParamsSlicedISL)
    losses = []
    for data in loader
        loss, grads = withgradient(nn_model) do nn
            Ω = [sample_random_direction(size(data)[1]) for _ in 1:(hparams.m)]

            # Using pmap for parallel computation
            total = sum(pmap(ω -> compute_loss_for_ω(nn, ω, data, hparams), Ω))

            total / hparams.m
        end
        push!(losses, loss)
    end
    return losses
end

# Helper function to compute loss for a single ω, remains unchanged
function compute_loss_for_ω(nn, ω, data, hparams)
    aₖ = zeros(hparams.K + 1)

    for i in 1:hparams.samples
        x = rand(hparams.noise_model, hparams.K)
        yₖ = nn(x)
        s = collect(reshape(ω' * yₖ, 1, hparams.K))
        aₖ += generate_aₖ(s, ω ⋅ data[:, i])
    end

    scalar_diff(aₖ ./ sum(aₖ))
end
```

without success; `grads[1]` is equal to `nothing`. I also tried:

```julia
function compute_forward_pass(nn, ω, data, hparams)
    aₖ = zeros(hparams.K + 1)
    for i in 1:hparams.samples
        x = Float32.(rand(hparams.noise_model, hparams.K))
        yₖ = nn(x)
        s = Matrix(reshape(ω' * yₖ, 1, hparams.K))  # Convert to Matrix
        aₖ += generate_aₖ(s, ω ⋅ data[:, i])
    end
    return aₖ
end

losses = Vector{Float32}()

Ω = [sample_random_direction(size(data)[1]) for _ in 1:hparams.m]

# Perform the forward pass in parallel
forward_pass_results = [Threads.@spawn compute_forward_pass(nn_model, ω, data, hparams) for ω in Ω]
aₖ_results = fetch.(forward_pass_results)

total_loss = sum([scalar_diff(aₖ_result ./ sum(aₖ_result)) for aₖ_result in aₖ_results]) / hparams.m
total_loss
```

For the first example, have you checked that non-`nothing` gradients are returned when using a normal `map` instead of `pmap`? Unfortunately, in the absence of a MWE I can’t check anything myself.

For the second example, you have to run the forward pass inside of the (with)gradient callback for AD to see it. When I said "figuring out a way to move the distributed part outside of the (with)gradient call", I was thinking of something more along the lines of rewriting your algorithm to do `pmap(... -> gradient(...))` (or `@distributed` + `gradient(...)`; this is all pseudo-logic) instead of `gradient(... -> pmap(...))` as you have it now.
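To make that restructuring concrete, here is a toy sketch; `loss_term` is a hypothetical scalar stand-in for `compute_loss_for_ω`, not code from this thread. Because the gradient is linear, the gradient of the averaged loss equals the average of the per-ω gradients, so each `gradient` call can run independently on a worker and Zygote never sees the Distributed machinery:

```julia
using Distributed, Zygote

# Hypothetical stand-in for compute_loss_for_ω: a scalar loss per direction ω.
loss_term(θ, ω) = sin(ω * θ)
Ω = [0.5, 1.0, 1.5]
θ = 2.0

# pmap(... -> gradient(...)): each worker differentiates its own term.
grads = pmap(ω -> Zygote.gradient(θ -> loss_term(θ, ω), θ)[1], Ω)
total_grad = sum(grads) / length(Ω)

# Matches differentiating the averaged loss serially, by linearity:
ref = Zygote.gradient(θ -> sum(loss_term(θ, ω) for ω in Ω) / length(Ω), θ)[1]
total_grad ≈ ref  # true
```

For a Flux model the same idea applies, except the per-worker gradients are parameter structures rather than scalars, so they have to be accumulated field by field (e.g. with `Functors`/`fmap`) rather than with a plain `sum`.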