Hello everyone, I am trying to parallelize a cost function using Distributed, with Zygote as the automatic differentiation backend. The example is as follows:
using Distributed, Flux, LinearAlgebra, ProgressMeter

function sliced_invariant_statistical_loss_distributed(nn_model, loader, hparams::HyperParamsSlicedISL)
    @assert loader.batchsize == hparams.samples
    @assert length(loader) == hparams.epochs
    losses = []
    optim = Flux.setup(Flux.Adam(hparams.η), nn_model)
    @showprogress for data in loader
        loss, grads = Flux.withgradient(nn_model) do nn
            Ω = [sample_random_direction(size(data)[1]) for _ in 1:(hparams.m)]
            # Distributed computation for the outer loop
            total = @distributed (+) for ω in Ω
                compute_loss_for_ω(nn, ω, data, hparams)
            end
            total / hparams.m
        end
        Flux.update!(optim, nn_model, grads[1])
        push!(losses, loss)
    end
    return losses
end;
# Helper function to compute the loss for a single ω
function compute_loss_for_ω(nn, ω, data, hparams)
    aₖ = zeros(hparams.K + 1)
    for i in 1:(hparams.samples)
        x = rand(hparams.noise_model, hparams.K)
        yₖ = nn(x)
        s = collect(reshape(ω' * yₖ, 1, hparams.K))
        aₖ += generate_aₖ(s, ω ⋅ data[:, i])
    end
    return scalar_diff(aₖ ./ sum(aₖ))
end
The error that I am encountering is the following:
nested task error: Compiling Tuple{typeof(Distributed.run_work_thunk),
Distributed.var"#153#154"{ISL.var"#48#51"{Chain{Tuple{Dense{typeof(identity),
Matrix{Float32}, Vector{Float32}}, typeof(elu), Dense{typeof(identity), Matrix{Float32},
Vector{Float32}}, typeof(elu), Dense{typeof(identity), Matrix{Float32},
Vector{Float32}}, typeof(elu), Dense{typeof(identity), Matrix{Float32},
Vector{Float32}}}}, HyperParamsSlicedISL, Matrix{Float32}}, Tuple{typeof(+),
Vector{Vector{Float32}}, Int64, Int64}, Base.Pairs{Symbol, Union{}, Tuple{},
NamedTuple{(), Tuple{}}}}, Bool}: try/catch is not supported. Refer to the Zygote
documentation for fixes. https://fluxml.ai/Zygote.jl/latest/limitations
It seems to be a limitation of Zygote. Can someone explain what the actual limitation is in this case, and whether there is any way to work around it?
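From reading the limitations page, my understanding is that Zygote cannot differentiate through the Distributed machinery, because @distributed / remotecall internally use try/catch. One workaround I am considering is to keep the parallelism outside the AD call: differentiate each slice separately on the workers and combine the per-slice gradients afterwards, since the gradient is linear in the loss terms. Below is a minimal, untested sketch of that idea (sliced_loss_and_grad is just a name I made up), assuming compute_loss_for_ω, sample_random_direction, and the ISL/Flux code are loaded @everywhere on the workers:

# Sketch (untested): differentiate each slice on a worker and sum the gradients,
# so Zygote never sees the Distributed internals.
using Distributed, Flux, Zygote

function sliced_loss_and_grad(nn_model, data, hparams)
    Ω = [sample_random_direction(size(data, 1)) for _ in 1:(hparams.m)]
    # pmap runs outside the AD call; each worker differentiates its own
    # (already scaled by 1/m) loss term.
    results = pmap(Ω) do ω
        Flux.withgradient(nn -> compute_loss_for_ω(nn, ω, data, hparams) / hparams.m, nn_model)
    end
    loss = sum(r.val for r in results)
    # Zygote.accum adds two gradient structures (nested NamedTuples) field by field
    grads = reduce(Zygote.accum, (r.grad[1] for r in results))
    return loss, grads
end

The training loop would then call loss, grads = sliced_loss_and_grad(nn_model, data, hparams) followed by Flux.update!(optim, nn_model, grads). Alternatively, replacing the @distributed (+) reduction with a plain sum(ω -> compute_loss_for_ω(nn, ω, data, hparams), Ω) makes the loss differentiable again (at the cost of the parallelism), which at least confirms the rest of the pipeline works. Is the per-worker-gradient approach above reasonable, or is there a more idiomatic way?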