How to efficiently build AD-compatible matrices line by line

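OK, this is much messier, and I’m not sure there is a really nice answer. It can be written with accumulate (f2 below), which is a bit faster, BUT Zygote then drops the gradient with respect to the init keyword. Since init=xi here, the gradient for xi comes out wrong, which is bad if you need it.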

All this concatenation and indexing uses a lot of memory. You could instead write an efficient gradient for this by hand: allocate an output of the right size once in the forward pass, and accumulate the gradients in the reverse pass.
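A rough sketch of what such a hand-written rule could look like, assuming ChainRulesCore is available and that xi is a vector or an n×1 matrix. The name f5 is hypothetical (it is not one of the functions benchmarked in this post), and the rule has not been timed against them:

```julia
using ChainRulesCore, LinearAlgebra

# Hypothetical f5: columns are xi, K*xi, K^2*xi, ... (d columns in total),
# written into one preallocated matrix instead of concatenating.
function f5(K::AbstractMatrix, xi::AbstractVecOrMat, d::Int)
    T = promote_type(eltype(K), eltype(xi))
    out = similar(xi, T, (length(xi), d))
    out[:, 1] .= vec(xi)
    for i in 2:d
        @views mul!(out[:, i], K, out[:, i-1])  # out[:, i] = K * out[:, i-1]
    end
    return out
end

# Reverse pass: walk the columns backwards, accumulating into dK and
# carrying the running cotangent g of the current column.
function ChainRulesCore.rrule(::typeof(f5), K::AbstractMatrix, xi::AbstractVecOrMat, d::Int)
    out = f5(K, xi, d)
    function f5_pullback(dy)
        dout = unthunk(dy)
        dK = zeros(eltype(out), size(K))
        g = collect(dout[:, d])
        for i in d:-1:2
            @views dK .+= g .* out[:, i-1]'  # outer product g * out[:, i-1]'
            g = K' * g .+ dout[:, i-1]       # push back through K, add direct term
        end
        return NoTangent(), dK, reshape(g, size(xi)), NoTangent()
    end
    return out, f5_pullback
end

# Small check by calling the rule directly (no AD package needed):
K = [1.0 2.0; 3.0 4.0]
xi = reshape([1.0, 1.0], 2, 1)
out, back = ChainRulesCore.rrule(f5, K, xi, 2)
_, dK, dxi, _ = back(ones(2, 2))  # cotangent of sum(out) is all ones
```

With Zygote loaded, gradient(sum∘f5, K, xi, 10) should pick up this rule, since Zygote consults ChainRules rrules. Only dK is allocated once per reverse pass; g is still reallocated each step, which could be tightened further with mul!.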

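function f2(K, xi, d::Int)
    # i is unused; each step just applies K to the previous column.
    # NB Zygote does not propagate a gradient back to init.
    xs = accumulate(1:d-1; init=xi) do x, i
        K * x
    end
    hcat(xi, reduce(hcat, xs))
end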
end

function f3(K, xi, d::Int)
    xs = accumulate(vcat([xi], 1:d-1)) do x, i  # avoiding init, type-unstable
        K * x
    end
    reduce(hcat, xs)
end

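function f4(K, xi, d)
    xs = [xi]  # vector of columns, grown by one entry per iteration
    for i = 2:d
        xs = vcat(xs, [K*xs[i-1]])  # non-mutating append, so Zygote can differentiate it
    end
    reduce(hcat, xs)
end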

f4(K, xi, 50) ≈ f3(K, xi, 50) ≈ f2(K, xi, 50) ≈ f(K, xi, 50)

using BenchmarkTools, Zygote
@btime f($K, $xi, 50);
@btime f2($K, $xi, 50);  # twice as quick
@btime f3($K, $xi, 50);  # a bit slower
@btime f4($K, $xi, 50);

julia> gradient(sum∘f, K, xi, 10)
([63.45016309970954 50.40609159573776 101.36588271461751; 23.572874731387856 18.379315265377535 35.224999619160954; 31.033286457367566 24.176359057416636 46.03455941092244], [48.455853839178204; 14.765466919614408; 18.845362109436827;;], nothing)

julia> gradient(sum∘f2, K, xi, 10) # NB the gradient for init=xi is missing!
([63.45016309970953 50.40609159573775 101.3658827146175; 23.572874731387852 18.379315265377535 35.22499961916095; 31.033286457367552 24.17635905741663 46.03455941092242], Fill(1.0, 3, 1), nothing)

julia> gradient(sum∘f3, K, xi, 10)
([63.45016309970953 50.40609159573775 101.3658827146175; 23.572874731387852 18.379315265377535 35.22499961916095; 31.033286457367552 24.17635905741663 46.03455941092242], [48.455853839178204; 14.765466919614404; 18.845362109436827;;], nothing)

julia> gradient(sum∘f4, K, xi, 10)
([63.45016309970953 50.40609159573775 101.3658827146175; 23.572874731387852 18.379315265377535 35.22499961916095; 31.033286457367552 24.17635905741663 46.03455941092242], [48.455853839178204; 14.765466919614404; 18.845362109436827;;], nothing)

julia> @btime gradient(sum∘f, $K, $xi, $10);
  min 30.291 μs, mean 34.062 μs (369 allocations, 24.08 KiB)

julia> @btime gradient(sum∘f2, $K, $xi, $10);
  min 21.583 μs, mean 28.247 μs (275 allocations, 36.14 KiB)

julia> @btime gradient(sum∘f3, $K, $xi, $10);
  min 49.917 μs, mean 58.936 μs (544 allocations, 48.08 KiB)

julia> @btime gradient(sum∘f4, $K, $xi, $10);
  min 76.375 μs, mean 89.648 μs (894 allocations, 63.39 KiB)
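
To double-check which xi gradient is right without trusting any AD rule, a central finite difference is enough. A minimal sketch using only Base; fd_grad is a hypothetical helper, and the small K and xi are made up for illustration rather than taken from the timings above:

```julia
# f2 as defined earlier in this post, repeated so the snippet runs on its own.
function f2(K, xi, d::Int)
    xs = accumulate(1:d-1; init=xi) do x, i
        K * x
    end
    hcat(xi, reduce(hcat, xs))
end

# Hypothetical helper: central differences of a scalar function g at x,
# one entry of x at a time.
function fd_grad(g, x; h=1e-6)
    map(eachindex(x)) do j
        e = zero(x)
        e[j] = h
        (g(x + e) - g(x - e)) / (2 * h)
    end
end

K = [1.0 2.0; 3.0 4.0]
xi = reshape([1.0, 1.0], 2, 1)

# sum(f2(K, x, 2)) = sum(x) + sum(K*x), so the exact gradient is [5, 7].
gx = fd_grad(x -> sum(f2(K, x, 2)), xi)
```

Here gx comes out close to [5, 7] rather than all ones, so a Fill(1.0, ...) result from gradient(sum∘f2, ...) only accounts for the explicit xi in hcat and misses the path through init.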