How to efficiently build AD-compatible matrices line by line

I find myself frequently needing a construct like the following:

a = Array{Float32}(undef, 100, 5)
for i=1:100
    a[i, :] = <a vector>
end

which works fine in plain Julia, and is highly efficient, but gives Flux fits when used inside a layer or a loss function, because mutating the array breaks the AD.

What I am currently doing instead is the following:

a = <first vector>
for i=1:100
    a = hcat(a, <next vector>)
end

which is terribly inefficient even in vanilla Julia because of all the allocations it has to perform.

Is there any way to get the efficiency of memory preallocation using Flux?

Use a = Array{T}(undef, 100, 5) where T is a type parameter based on the input type of the function. This way, when AD runs on the code, T will be a Dual{Float32} and everything will work.

For reverse mode, I think you will want to use reduce(hcat, all_the_vectors), which should be efficient.
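
For the original 100×5 example, that pattern might look something like this (a sketch; rand is just a stand-in for whatever produces each row vector):

vs = [rand(Float32, 5) for _ in 1:100]  # stand-in for "<a vector>" above
a = permutedims(reduce(hcat, vs))       # concatenate once (5×100), then transpose to 100×5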


In trying this, I get an error I’m not able to diagnose:

Cannot `convert` an object of type Float32 to an object of type CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}

when I try to write to a column of a.

Here’s my code:

function (m::MyLayer)(x::T) where T
    xi = Array{T}(undef, <sizex>, <sizey>)
    for i=1:size(x,1)
        xi[i, :] = m.W * x[i, :] .+ m.b
    end
    return xi
end

Ah, got it. eltype(T) gives the element type of the array.

Keep in mind that Julia uses column major memory layout, so it might be more efficient to store your data so you can slice it like
x[:, i] instead of x[i, :].
Also have a look at @views
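
As a small illustration (sizes here are arbitrary):

x = rand(Float32, 5, 100)   # samples stored as columns, so each sample is contiguous in memory
col = @views x[:, 1]        # a view into x; no copy of the column is made
s = sum(col)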

Also, note that this loop is just xi = x * m.W' .+ m.b', which will be quicker. But perhaps it’s a toy example.


OK, so… here’s the code as it currently stands. It doesn’t work.

function (m::MyLayer)(x::T) where T
    xi = Array{eltype(T)}(undef, <sizex>, <sizey>)
    for i=1:size(x,1)
        xi[i, :] = m.W * x[i, :] .+ m.b
    end
    return xi
end

The reason it doesn’t work is (I think) that x is a CuArray{Float32}, but xi is just a Matrix{Float32}. So the layer gets executed, but everything gets pulled back onto the CPU, which screws up the next layer down, which is still on the GPU.

Sigh. OK, I got the routine to run… and I’m back to the same AD error. Here’s the code:

function (m::MyLayer)(x::T) where T
    xi = similar(T, (<sizex>, <sizey>))
    for i=1:size(x,1)
        xi[i, :] = m.W * x[i, :] .+ m.b
    end
    return xi
end

and the error I get is the same one I was getting originally:

ERROR: LoadError: Mutating arrays is not supported -- called setindex!(::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, _...)

Note that T here is CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}. There is not a Dual{Float32} in sight.

For future reference, the correct answer appears to be using the Zygote.Buffer data structure, which lets you create something that acts like an array but is mutable:

using Zygote
function (m::MyLayer)(x::AbstractArray)
    xi = Zygote.Buffer(x, (<sizex>, <sizey>))
    for i=1:size(x,1)
        xi[i, :] = m.W * x[i, :] .+ m.b
    end
    return copy(xi)
end
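
For what it’s worth, a quick way to check that the Buffer version actually differentiates is something like this sketch (the sizes and the stand-in weights are mine):

using Zygote
W, b = rand(Float32, 4, 5), rand(Float32, 4)   # stand-ins for m.W and m.b
x = rand(Float32, 10, 5)
grads = gradient(W, b) do W, b
    xi = Zygote.Buffer(x, 10, 4)               # mutable array-like container that Zygote knows how to track
    for i in 1:size(x, 1)
        xi[i, :] = W * x[i, :] .+ b
    end
    sum(copy(xi))                              # copy turns the Buffer back into an ordinary array
end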

It is not yet clear to me that this is actually more efficient than the hcat solution I was originally using, but it does in fact run. It’s kind of hackish that it requires explicit coding for AD…

The correct solution is *, as above.

Unless this is a warm-up problem for something harder. In which case the correct solution is probably something like reduce(hcat, map(f, xs)).

I’m not sure what you mean by this?

Sorry, perhaps too compressed. But this function does matrix multiplication, by hand. This will be slow, about 20x slower on my computer at the size below, and even worse under AD.

function loop(W, x::T, b) where T
    xi = Array{eltype(T)}(undef, size(x,1), size(W,1))
    for i=1:size(x,1)
        xi[i, :] = W * x[i, :] .+ b
    end
    return xi
end
blas(W, x, b) = x * W' .+ b'
blas2(W, x, b) = muladd(x, W', b')

x = rand(50,40); W = rand(30,40); b = rand(30);
loop(W,x,b) ≈ blas(W, x, b)  ≈ blas2(W, x, b)  # true
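
The "20x" figure is machine-dependent, but it can be checked with BenchmarkTools along these lines:

using BenchmarkTools
@btime loop($W, $x, $b);    # row-by-row loop
@btime blas($W, $x, $b);    # one matrix multiplication
@btime blas2($W, $x, $b);   # muladd form of the same product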

I actually have two different use cases. One does indeed fit into this framework, and I have adopted it.

The other I’m not so sure. What I need to do is the following:

x = xi # an nx1 vector
for i=1:l
    x = vcat(x, K*x[end-n:end, :])
end

i.e. I’m propagating a linear difference equation for l steps and recording the trajectory in x. I can’t figure out how to fit a recursive relationship like this into a reduce() framework.

Well, maybe this is messier. Can you make a function that runs, with some sample data? end-n:end is out of bounds if x is an n-vector; what’s the intended behaviour?

Yeah, sorry, I should have been more precise. Here’s a running example:

function f(K, xi, d)
    x = xi
    for i = 2:d
        x = hcat(x, K*x[:, i-1])
    end
    return x
end

K = rand(3,3)
xi = rand(3,1)
f(K, xi, 50)

OK, this is much messier, I’m not sure there is a really nice answer. It can be written as accumulate, which is a bit faster, BUT it forgets the gradient of the init keyword, which might be bad.

You could write an efficient gradient for this by hand, in which you just allocate the right-sized output once and accumulate into it in the reverse pass (a sketch of such a rule is included after the benchmarks below). All these cats and indexing use a lot of memory.

function f2(K, xi, d::Int)
    xs = accumulate(1:d-1; init=xi) do x, i
        K * x
    end
    hcat(xi, reduce(hcat, xs))
end

function f3(K, xi, d::Int)
    xs = accumulate(vcat([xi], 1:d-1)) do x, i  # avoiding init, type-unstable
        K * x
    end
    reduce(hcat, xs)
end

function f4(K, xi, d)
    xs = [xi]
    for i = 2:d
        xs = vcat(xs, [K*xs[i-1]])
    end
    reduce(hcat, xs)
end

f4(K, xi, 50) ≈ f3(K, xi, 50) ≈ f2(K, xi, 50) ≈ f(K, xi, 50)

using BenchmarkTools, Zygote
@btime f($K, $xi, 50);
@btime f2($K, $xi, 50);  # twice as quick
@btime f3($K, $xi, 50);  # a bit slower
@btime f4($K, $xi, 50);
julia> gradient(sum∘f, K, xi, 10)
([63.45016309970954 50.40609159573776 101.36588271461751; 23.572874731387856 18.379315265377535 35.224999619160954; 31.033286457367566 24.176359057416636 46.03455941092244], [48.455853839178204; 14.765466919614408; 18.845362109436827;;], nothing)

julia> gradient(sum∘f2, K, xi, 10) # NB the gradient for init=xi is missing!
([63.45016309970953 50.40609159573775 101.3658827146175; 23.572874731387852 18.379315265377535 35.22499961916095; 31.033286457367552 24.17635905741663 46.03455941092242], Fill(1.0, 3, 1), nothing)

julia> gradient(sum∘f3, K, xi, 10)
([63.45016309970953 50.40609159573775 101.3658827146175; 23.572874731387852 18.379315265377535 35.22499961916095; 31.033286457367552 24.17635905741663 46.03455941092242], [48.455853839178204; 14.765466919614404; 18.845362109436827;;], nothing)

julia> gradient(sum∘f4, K, xi, 10)
([63.45016309970953 50.40609159573775 101.3658827146175; 23.572874731387852 18.379315265377535 35.22499961916095; 31.033286457367552 24.17635905741663 46.03455941092242], [48.455853839178204; 14.765466919614404; 18.845362109436827;;], nothing)

julia> @btime gradient(sum∘f, $K, $xi, $10);
  min 30.291 μs, mean 34.062 μs (369 allocations, 24.08 KiB)

julia> @btime gradient(sum∘f2, $K, $xi, $10);
  min 21.583 μs, mean 28.247 μs (275 allocations, 36.14 KiB)

julia> @btime gradient(sum∘f3, $K, $xi, $10);
  min 49.917 μs, mean 58.936 μs (544 allocations, 48.08 KiB)

julia> @btime gradient(sum∘f4, $K, $xi, $10);
  min 76.375 μs, mean 89.648 μs (894 allocations, 63.39 KiB)
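
For completeness, here is roughly what the hand-written rule mentioned above could look like (just a sketch, assuming ChainRulesCore; the name f_hand and the rule body are illustrative and haven’t been benchmarked against the versions above):

using ChainRulesCore

function f_hand(K, xi, d)
    # mutation is fine here because AD will use the hand-written rrule below
    X = similar(xi, size(xi, 1), d)   # allocate the whole trajectory once
    X[:, 1] = vec(xi)
    for i = 2:d
        X[:, i] = K * X[:, i-1]
    end
    return X
end

function ChainRulesCore.rrule(::typeof(f_hand), K, xi, d)
    X = f_hand(K, xi, d)
    function f_hand_pullback(X̄)
        X̄ = unthunk(X̄)
        K̄ = zero(K)
        g = X̄[:, d]                    # cotangent of the last column
        for i = d:-1:2                 # walk the recurrence backwards
            K̄ += g * X[:, i-1]'        # from X[:, i] = K * X[:, i-1]
            g = X̄[:, i-1] + K' * g
        end
        return NoTangent(), K̄, reshape(g, size(xi)), NoTangent()
    end
    return X, f_hand_pullback
end

Zygote should pick this rule up via ChainRules, so gradient(sum∘f_hand, K, xi, 10) can be checked against the versions above.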

This is amazing, thanks very much for the help!