Is there a way to parallelize portions of code in Flux?

Hello, I have some code where I want to execute a parallel loop within the portion responsible for automatic differentiation. This is because the cost function I want to compute relies on certain properties of the estimated and real functions that are independent of each other, so I could potentially evaluate them all in parallel. Is there a way to do this? Right now, when I use @distributed, threads, etc., I get a Zygote error.

What kind of Zygote error? Can you post a reproducible example?

Hello, I’m actually starting to feel embarrassed to ask you, I think you’re going to do my homework all by yourself :wink:

I’m not sure if you remember the example from last time, but in the end, I did manage to make it differentiate. The code would be something like this.

using Flux, Distributions, ProgressMeter   # needed by the snippets below (Normal, @showprogress, ...)

model = Chain(
    Dense(1 => 10, tanh),
    Dense(10 => 1),
) |> gpu

μ₁=-1; σ₁ = 0.5; 
μ₂=-2; σ₂ = 0.4;
realModel(ϵ) =  ϵ < 0.5 ? rand(Normal(μ₁, σ₁)) : rand(Normal(μ₂, σ₂))

#Learning with custom loss
μ = 0; stddev = 1
η = 0.001; num_epochs = 200; n_samples = 10000; K = 10
optim = Flux.setup(Flux.Adam(η), model)
losses = []
@showprogress for epoch in 1:num_epochs
    loss, grads = Flux.withgradient(model) do m
        aₖ = zeros(K+1)
        for _ in 1:n_samples
            x = rand(Normal(μ, stddev), K)
            yₖ = m(x')
            y = realModel(rand(Float64))
            aₖ += generate_aₖ(yₖ, y)
        end
        scalar_diff(aₖ ./ sum(aₖ))
    end
    Flux.update!(optim, model, grads[1])
    push!(losses, loss)
end;

In this case, I’m trying to learn a bimodal distribution in one dimension (I only have results showing that it works in one dimension, but it should hold in multiple dimensions as well, although I don’t have the proof yet). We are trying to transform a standard normal distribution N(0, 1) into a bimodal distribution (toy example). The inner loop is ‘morally parallelizable’ since the order of the simulations and of their summation doesn’t matter. Therefore, I would like to execute this in parallel, but when I use @distributed, threads, etc., I get the following Zygote error:

Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_cpu_wake), Nothing, svec(), 0, :(:ccall))). You might want to check the Zygote limitations documentation. 

Let me briefly remind you where each thing comes from.

I have defined a loss function that evaluates the convergence of randomly generated samples from the model towards a uniform distribution with respect to the real observation.

I generate K random hypothetical observations and count how many of them are smaller than the real data in the training set. In other words, the model generates K simulations, and we calculate the number of simulations in which the generated value is smaller than the actual data. If the model is well trained, this count should converge to a uniform distribution.
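To make the idea concrete, here is a tiny illustration of the (non-differentiable) counting I am describing; rank_statistic is just a hypothetical helper for this post, not part of my code:

using Distributions

# Number of simulated observations that fall below the real one.
# For a well-calibrated model this rank is ~uniform on 0:K.
rank_statistic(yₖ, y) = count(<(y), yₖ)

yₖ = rand(Normal(0, 1), 10)   # K = 10 simulations from the model
y  = rand(Normal(0, 1))       # one real observation
rank_statistic(yₖ, y)         # an integer in 0:10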

generate_aₖ is constructed the following way. However, to simplify, I will just say that the executions of generate_aₖ are independent of each other.

scalar_diff(aₖ) = sum((aₖ .- (1 ./ length(aₖ))) .^2)
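As a quick sanity check with toy values (not from the training code), a perfectly uniform normalized histogram gives zero loss:

aₖ = fill(1/11, 11)   # K = 10 ⇒ K + 1 = 11 bins
scalar_diff(aₖ)       # 0.0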

"""
    sigmoid(ŷ, y)

    Sigmoid function centered at y.
"""
function sigmoid(ŷ, y)
    return sigmoid_fast.((ŷ-y)*10)
end;
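(sigmoid_fast comes from NNlib and is available after using Flux.) With this definition, sigmoid behaves as a smooth indicator of ŷ > y, for example:

sigmoid(1.0, 0.0)   # ≈ 1.0
sigmoid(0.5, 0.5)   # 0.5
sigmoid(0.0, 1.0)   # ≈ 0.0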

"""
    ψₘ(y, m)

    Bump function centered at m. Implemented as a gaussian function.
"""
function ψₘ(y, m)
    stddev = 0.1
    return exp.((-0.5 .* ((y .- m) ./ stddev) .^ 2))
end
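For example, applied to a (soft) count y near the bin center m:

ψₘ(3.0, 3)   # 1.0
ψₘ(2.8, 3)   # ≈ 0.14
ψₘ(2.0, 3)   # ≈ 0.0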

"""
    ϕ(yₖ, yₙ)

    Sum of the sigmoid function centered at yₙ applied to the vector yₖ.
"""
function ϕ(yₖ, yₙ)
    return sum(sigmoid.(yₙ, yₖ))
end;
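So ϕ is a smooth version of “how many entries of yₖ lie below yₙ”, e.g. with toy values:

yₖ = [-1.0, 0.0, 2.0]
ϕ(yₖ, 1.0)   # ≈ 2.0 (two of the three entries are below 1.0)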

"""
    γ(yₖ, yₙ, m)
    
    Calculate the contribution of ψₘ ∘ ϕ(yₖ, yₙ) to the m-th bin of the histogram (Vector{Float}).
"""
function γ(yₖ, yₙ::Float64, m::Int64)
    eₘ(m) = [j == m ? 1.0 : 0.0 for j in 0:length(yₖ)]
    return eₘ(m) * ψₘ(ϕ(yₖ, yₙ), m)
end;
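Continuing the same toy values, the soft count of ≈ 2 ends up in bin m = 2:

γ([-1.0, 0.0, 2.0], 1.0, 2)   # ≈ [0.0, 0.0, 1.0, 0.0]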

"""
    γ_fast(yₖ, yₙ, m)

Apply the γ function to the given parameters. 
This function is faster than the original γ function because it uses StaticArrays.
However, because Zygote does not support StaticArrays, this function cannot be used in the training process.
"""
function γ_fast(yₖ, yₙ::Float64, m::Int64)
    eₘ(m) = SVector{length(yₖ)+1, Float64}(j == m ? 1.0 : 0.0 for j in 0:length(yₖ))
    return eₘ(m) * ψₘ(ϕ(yₖ, yₙ), m)
end;

"""
    generate_aₖ(ŷ, y)

    Generate a one-step histogram (Vector{Float}) from the given vector ŷ of K simulated observations and the real data y.
    generate_aₖ(ŷ, y) = ∑ₖ γ(ŷ, y, k)
"""
generate_aₖ(ŷ, y::Float64) = sum([γ(ŷ, y, k) for k in 0:length(ŷ)])
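With the same toy values as above, the one-step histogram puts (almost) all of its mass in the bin given by the soft count; in the training loop yₖ is the 1×K row returned by m(x'), which works the same way since only length and broadcasting are used:

generate_aₖ([-1.0, 0.0, 2.0], 1.0)   # ≈ [0.0, 0.0, 1.0, 0.0]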

We need to transform the concept of counting (a histogram) into a differentiable operation. To do this, I have done the following. First, I use a sigmoid to check whether a fictitious observation (out of the K generated) is smaller than the real observation. This process can be repeated for the K observations by simply summing the sigmoids (the ϕ function). After this, I want to generate differentiable histogram bins. To achieve this, it is sufficient to use a bump function centered at m that is nearly 1 when exactly m of the K fictitious observations are smaller than the real observation, and nearly zero otherwise. We have K+1 bump functions, centered at 0, 1, …, K. With the γ function, I simply associate the previous result with the m-th vector component, which represents my histogram bin. Finally, generate_aₖ is nothing more than an iteration over all the vector components to generate each of the bins.
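Putting the pieces together on a toy example (hypothetical stand-in data, not my training setup): when the simulations and the real observations come from the same distribution, the accumulated histogram is close to uniform and the loss is small.

using Distributions

K = 3
aₖ = sum(1:1000) do _
    yₖ = rand(Normal(0, 1), K)   # stand-in for the model's K simulations
    y  = rand(Normal(0, 1))      # stand-in for the real observation
    generate_aₖ(yₖ, y)
end
scalar_diff(aₖ ./ sum(aₖ))       # close to 0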


That’s what I get for answering many AD-related posts I guess :rofl:

Those two are actually very different; I’m not sure if you’re clear on the specific behavior of each one. I think in your case multithreading (with shared memory) is more appropriate, unless the inner iterations are really long and you’re working on a cluster of several machines.

True, but if you naively add a @threads to your code as written you’re gonna make it incorrect due to race conditions in aₖ += ....
In any case, the most straightforward fix is probably using ThreadsX.jl to parallelize the sum:

aₖ = ThreadsX.sum(1:n_samples) do _
    x = rand(Normal(μ, stddev), K)
    yₖ = m(x')
    y = realModel(rand(Float64))
    generate_aₖ(yₖ, y)
end

Not tested but that’s the gist

I would like to see the actual code you ran in parallel cause I’m a bit surprised by the “foreign call” error.


I understand that the difference is the same as between processes and threads in Linux. You’re right that launching threads should be lighter and have the advantage of shared resources.

True, I was thinking that since I don’t care about the order of the additions there shouldn’t be any race conditions, but you’re right that the accumulation needs to be atomic (or otherwise synchronized) to avoid problems.

I’ve tried what you mentioned. I assume you meant to say that:

aₖ = ThreadsX.sum(n_samples) do _
    x = rand(Normal(μ, stddev), K)
    yₖ = m(x')
    y = realModel(rand(Float64))
    generate_aₖ(yₖ, y)
end

Interestingly, I’m getting the following error.

nested task error: `llvmcall` must be compiled to be called
Stacktrace:

You copy-pasted the same code, but yeah inside the sum I meant 1:n_samples

Sorry, you are right

In theory, Zygote (the AD backend Flux uses by default) supports the lower level built-in threading constructs. See Zygote.jl/test/threads.jl at 2f4937096ee1db4b5a67c1c31fe3ebeab1c96c8c · FluxML/Zygote.jl · GitHub for a couple of examples. I do not believe it supports @threads because that adds a bunch of extra function calls, but it looks like you can get away with not using that for your model.
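For what it’s worth, here is an untested sketch (my own adaptation, not taken from those tests) of what using the built-in constructs directly could look like for the inner loop, with Threads.@spawn / fetch instead of @threads; whether Zygote actually differentiates through it for this particular model would need to be checked:

loss, grads = Flux.withgradient(model) do m
    chunk = n_samples ÷ Threads.nthreads()
    tasks = map(1:Threads.nthreads()) do _
        Threads.@spawn begin
            # each task accumulates its own partial histogram
            partial = zeros(K + 1)
            for _ in 1:chunk
                x = rand(Normal(μ, stddev), K)
                yₖ = m(x')
                y = realModel(rand(Float64))
                partial += generate_aₖ(yₖ, y)
            end
            partial
        end
    end
    aₖ = sum(fetch.(tasks))   # combine the per-task partial sums
    scalar_diff(aₖ ./ sum(aₖ))
end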
