Understanding and Overcoming Zygote's Functional Limitations with Distributed

Hello everyone, I am trying to parallelize a cost function, using Distributed for the parallelism and Zygote as the automatic differentiator. The example is as follows:

using Distributed, Flux, LinearAlgebra
using ProgressMeter: @showprogress

function sliced_invariant_statistical_loss_distributed(nn_model, loader, hparams::HyperParamsSlicedISL)
    @assert loader.batchsize == hparams.samples
    @assert length(loader) == hparams.epochs
    losses = []
    optim = Flux.setup(Flux.Adam(hparams.η), nn_model)

    @showprogress for data in loader
        loss, grads = Flux.withgradient(nn_model) do nn
            Ω = [sample_random_direction(size(data)[1]) for _ in 1:(hparams.m)]

            # Distributed computation for the outer loop
            total = @distributed (+) for ω in Ω
                compute_loss_for_ω(nn, ω, data, hparams)
            end

            total / hparams.m
        end

        Flux.update!(optim, nn_model, grads[1])
        push!(losses, loss)
    end
    return losses
end;

# Helper function to compute loss for a single ω
function compute_loss_for_ω(nn, ω, data, hparams)
    aₖ = zeros(hparams.K + 1)

    for i in 1:hparams.samples
        x = rand(hparams.noise_model, hparams.K)
        yₖ = nn(x)
        s = collect(reshape(ω' * yₖ, 1, hparams.K))
        aₖ += generate_aₖ(s, ω ⋅ data[:, i])
    end

    scalar_diff(aₖ ./ sum(aₖ))
end

The error that I am encountering is the following:

nested task error: Compiling Tuple{typeof(Distributed.run_work_thunk),
  Distributed.var"#153#154"{ISL.var"#48#51"{Chain{Tuple{
    Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, typeof(elu),
    Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, typeof(elu),
    Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, typeof(elu),
    Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}},
  HyperParamsSlicedISL, Matrix{Float32}},
  Tuple{typeof(+), Vector{Vector{Float32}}, Int64, Int64},
  Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Bool}:
try/catch is not supported. Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations

This seems to be a limitation of Zygote. Can someone explain what the actual limitation is in this case, and whether there is any way to circumvent it?

The actual limitation is visible in your stack trace: Zygote is trying to differentiate through Distributed's own remote-call machinery (run_work_thunk), which contains try/catch blocks, and Zygote cannot generate pullbacks for try/catch. I believe the only part of the Distributed API Zygote supports is pmap, and even then only in a limited capacity. If you can’t make do with that, then consider writing custom rules for this, or figuring out a way to move the distributed part outside of the (with)gradient call.
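
To make the custom-rule route concrete, here is a minimal sketch (not from the thread, just one possible shape): wrap the distributed loss behind a ChainRulesCore.rrule so Zygote never traces into Distributed. It assumes the model is flattened with Flux.destructure so per-worker gradients are plain vectors, that compute_loss_for_ω and the ISL helpers are defined on every worker with @everywhere, and the wrapper name loss_over_Ω is made up for illustration.

using Distributed, Flux, Zygote, ChainRulesCore

# Hypothetical wrapper: total loss over the projection directions Ω, computed on
# the workers. θ is the flat parameter vector and re rebuilds the model from it.
function loss_over_Ω(θ, re, Ω, data, hparams)
    total = @distributed (+) for ω in Ω
        compute_loss_for_ω(re(θ), ω, data, hparams)
    end
    return total / length(Ω)
end

# Custom rule: differentiate each ω-term on the workers ourselves and hand Zygote
# the combined gradient, so it never has to look inside @distributed or pmap.
function ChainRulesCore.rrule(::typeof(loss_over_Ω), θ, re, Ω, data, hparams)
    m = length(Ω)
    results = pmap(Ω) do ω
        Zygote.withgradient(p -> compute_loss_for_ω(re(p), ω, data, hparams), θ)
    end
    y  = sum(r.val for r in results) / m
    ∇θ = sum(r.grad[1] for r in results) ./ m
    loss_pullback(ȳ) = (NoTangent(), ȳ .* ∇θ, NoTangent(), NoTangent(), NoTangent(), NoTangent())
    return y, loss_pullback
end

# Usage sketch:
# θ, re = Flux.destructure(nn_model)
# loss, grads = Zygote.withgradient(p -> loss_over_Ω(p, re, Ω, data, hparams), θ)

The trade-off is that each worker runs its own reverse pass, so the parameters and data travel to the workers once per step, but the distributed work is completely invisible to Zygote.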

I’ve tried what you mentioned, and it doesn’t seem to really work. It doesn’t raise errors, but the automatic differentiator returns nothing as the gradient. I have also tried Threads.@spawn, since I have seen that is what they use in the Zygote project's own tests, but with the same result as before. Do you know of another automatic differentiator, even an experimental one, that is compatible with multithreaded or multiprocess parallelization?

I would need to know exactly what you tried in order to help here, because I mentioned more than one thing!

I have tried pmap:

function sliced_invariant_statistical_loss_distributed_pmap(nn_model, loader, hparams::HyperParamsSlicedISL)
    @assert loader.batchsize == hparams.samples
    @assert length(loader) == hparams.epochs
    losses = []
    optim = Flux.setup(Flux.Adam(hparams.η), nn_model)

    @showprogress for data in loader
        loss, grads = Flux.withgradient(nn_model) do nn
            Ω = [sample_random_direction(size(data)[1]) for _ in 1:(hparams.m)]

            # Using pmap for parallel computation
            total = sum(pmap(ω -> compute_loss_for_ω(nn, ω, data, hparams), Ω))

            total / hparams.m
        end

        Flux.update!(optim, nn_model, grads[1])
        push!(losses, loss)
    end
    return losses
end

# Helper function to compute loss for a single ω, remains unchanged
function compute_loss_for_ω(nn, ω, data, hparams)
    aₖ = zeros(hparams.K + 1)

    # Inner loop over the samples (runs sequentially on whichever worker received this ω)
    for i in 1:hparams.samples
        x = rand(hparams.noise_model, hparams.K)
        yₖ = nn(x)
        s = collect(reshape(ω' * yₖ, 1, hparams.K))
        aₖ += generate_aₖ(s, ω ⋅ data[:, i])
    end

    scalar_diff(aₖ ./ sum(aₖ))
end

without success; grads[1] is equal to nothing. I also tried:

function compute_forward_pass(nn, ω, data, hparams)
    aₖ = zeros(hparams.K + 1)
    for i in 1:hparams.samples
        x = Float32.(rand(hparams.noise_model, hparams.K))
        yₖ = nn(x)
        s = Matrix(reshape(ω' * yₖ, 1, hparams.K))  # Convert to Matrix
        aₖ += generate_aₖ(s, ω ⋅ data[:, i])
    end
    return aₖ
end
function sliced_invariant_statistical_loss_multithreaded_2(nn_model, loader, hparams::HyperParamsSlicedISL)
    @assert loader.batchsize == hparams.samples
    @assert length(loader) == hparams.epochs
    losses = Vector{Float32}()
    optim = Flux.setup(Flux.Adam(hparams.η), nn_model)

    @showprogress for data in loader
        Ω = [sample_random_direction(size(data)[1]) for _ in 1:hparams.m]

        # Perform the forward pass in parallel
        forward_pass_results = [Threads.@spawn compute_forward_pass(nn_model, ω, data, hparams) for ω in Ω]
        aₖ_results = fetch.(forward_pass_results)

        # Compute gradients sequentially
        loss, grads = Flux.withgradient(nn_model) do nn
            total_loss = sum([scalar_diff(aₖ_result ./ sum(aₖ_result)) for aₖ_result in aₖ_results]) / hparams.m
            total_loss
        end

        Flux.update!(optim, nn_model, grads[1])
        push!(losses, loss)
    end
    return losses
end

with the same result.

For the first example, have you checked that non-nothing gradients are returned when using normal map instead of pmap? Unfortunately, in the absence of an MWE I can’t check anything myself.
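
As a rough illustration of that check (not from the thread): inside your existing training loop, keep everything identical and only swap pmap for map in the withgradient block. If grads[1] is still nothing here, the problem is in the loss closure itself rather than in the distribution.

loss, grads = Flux.withgradient(nn_model) do nn
    Ω = [sample_random_direction(size(data)[1]) for _ in 1:hparams.m]
    # Same computation as before, but serial: pmap replaced by map.
    sum(map(ω -> compute_loss_for_ω(nn, ω, data, hparams), Ω)) / hparams.m
end
grads[1] === nothing && @warn "gradient is nothing even without pmap"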

For the second example, you have to run the forward pass inside the (with)gradient callback for AD to see it. When I said "figuring out a way to move the distributed part outside of the (with)gradient call", I was thinking of something more along the lines of whether you could rewrite your algorithm to do pmap(... -> gradient(...)) (or @distributed + gradient(...), this is all pseudo-logic) instead of gradient(... -> pmap(...)) as you have it now.
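
For concreteness, a minimal sketch of that restructuring (illustrative only, assuming compute_loss_for_ω and the ISL helpers are available on all workers via @everywhere): each worker differentiates its own ω-term with Zygote, and the per-ω gradients are merged on the main process, so nothing distributed ever happens inside a gradient call. This would replace the body of your training loop:

@showprogress for data in loader
    Ω = [sample_random_direction(size(data)[1]) for _ in 1:hparams.m]

    # Differentiate per ω on the workers; dividing by m inside the per-ω loss means
    # the accumulated gradient is already the gradient of the mean over Ω.
    results = pmap(Ω) do ω
        Zygote.withgradient(nn -> compute_loss_for_ω(nn, ω, data, hparams) / hparams.m, nn_model)
    end

    loss = sum(r.val for r in results)
    # Zygote.accum merges the per-ω gradients (nested NamedTuples matching the model).
    grads = reduce(Zygote.accum, (r.grad[1] for r in results))

    Flux.update!(optim, nn_model, grads)
    push!(losses, loss)
end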