Unstack-stack-unstack-stack

Any ideas on how to make this more efficient?
I don’t know how to broadcast the functions more efficiently than with a for loop, and 3D matrix multiplication didn’t really work for me. Can someone help?

# input has size (Ni, B)
# solutions is an Array{Function} of length Np
b = model_one(input) # output is (Na, B)
w = model_two(input) # output is (Np+1, B)
B, Np = size(input, 2), size(w,1) - 1

qs = Zygote.ignore(() -> [hcat([s(input[:,i]) for s in solutions]...) for i=1:B]) # output is B matrices of size (Na, Np)
Flux.stack(qs .* Flux.unstack(w[1:Np, :], 2), 2) .+ w[Np+1:Np+1, :] .* b

I’m not sure I fully understand the context.

  • From the code, I guess that s(input[:,i]) returns a vector of length Na.
  • Are you after more speed or more readable code?

My recommendation is to use a plain for-loop for your task.
This makes the code easier to understand and probably even faster. In particular, there is no need to allocate an array of size (Na, Np, B) for the calculation.

However, I’m not sure whether the Zygote context allows this for-loop. If array indexing is not allowed in your context, we can try to find another solution.

Here is a minimal example:

Ni = 3
Na = 5
Np = 7
B = 11

input = rand(Ni,B)
solutions = [x -> x[1] .* rand(Na) for _ in 1:Np]

b = rand(Na,B)
w = rand(Np+1,B)

res = zeros(Na,B)
for j = 1:B
    # res is only mutated in place (never rebound), so no `global res` is needed here
    for i = 1:Np
        res[:,j] += w[i,j] .* solutions[i](input[:,j])
    end
    res[:,j] += w[end,j] .* b[:,j]   # this operation could also be done outside of the loop, like in your code
end
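
If the loop has to run inside a Zygote gradient, the in-place updates of res are likely to be a problem, because Zygote generally does not support mutating arrays. Below is a minimal sketch of a non-mutating variant that builds each column separately and concatenates them at the end. The deterministic coeffs, the helper column, and the names res_nomut and res_loop are made up for illustration, and I have not run this through Zygote, so treat it as a starting point rather than a tested solution.

```julia
# Hypothetical sizes and data, mirroring the minimal example above
Ni, Na, Np, B = 3, 5, 7, 11

input = rand(Ni, B)
# Deterministic stand-ins for `solutions` (assumption: each one maps a
# length-Ni vector to a length-Na vector)
coeffs = [rand(Na) for _ in 1:Np]
solutions = [x -> x[1] .* c for c in coeffs]

b = rand(Na, B)
w = rand(Np + 1, B)

# Column j of the result, built without writing into a preallocated array
function column(j)
    x = input[:, j]
    terms = [w[i, j] .* solutions[i](x) for i in 1:Np]  # Np vectors of length Na
    return sum(terms) .+ w[end, j] .* b[:, j]
end

res_nomut = reduce(hcat, [column(j) for j in 1:B])  # size (Na, B)

# Reference: the mutating loop from the minimal example
res_loop = zeros(Na, B)
for j in 1:B, i in 1:Np
    res_loop[:, j] += w[i, j] .* solutions[i](input[:, j])
end
res_loop .+= w[end:end, :] .* b  # bias term applied outside the loop
```

Both versions should agree up to floating-point rounding, which can be checked with res_nomut ≈ res_loop. The non-mutating version allocates more small temporaries than the plain loop, so it is only worth using where mutation is actually the obstacle.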