OK, this is much messier, and I’m not sure there is a really nice answer. It can be written with `accumulate`, which is a bit faster, BUT it drops the gradient with respect to the `init` keyword, which might be bad.
You could write an efficient gradient for this by hand: allocate an output of the right size once, and accumulate in the reverse pass. All these `cat`s and indexing use a lot of memory.
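A rough sketch of what such a hand-written gradient might look like, assuming the function is meant to return the columns `xi, K*xi, ..., K^(d-1)*xi`. The name `f_prealloc` is made up for illustration, and the rule is written against `ChainRulesCore.rrule`, which Zygote picks up. Treat it as a starting point, not a tuned implementation:

```julia
using LinearAlgebra, ChainRulesCore

# Hypothetical name: fills one preallocated matrix, column i is K^(i-1) * xi.
function f_prealloc(K::AbstractMatrix, xi::AbstractVecOrMat, d::Int)
    T = promote_type(eltype(K), eltype(xi))
    out = similar(xi, T, size(xi, 1), d)
    out[:, 1] .= vec(xi)
    for i in 2:d
        # write K * (previous column) directly into column i
        mul!(view(out, :, i), K, view(out, :, i - 1))
    end
    return out
end

function ChainRulesCore.rrule(::typeof(f_prealloc), K, xi, d::Int)
    out = f_prealloc(K, xi, d)
    function f_prealloc_pullback(Δraw)
        Δ = unthunk(Δraw)
        dK = zero(K)             # allocated once, accumulated in place
        g = collect(Δ[:, d])     # sensitivity of the last column
        for i in d:-1:2
            dK .+= g .* view(out, :, i - 1)'   # outer product g * out[:, i-1]'
            g = Δ[:, i - 1] .+ K' * g          # push sensitivity back one column
        end
        # g now holds the gradient with respect to xi
        return NoTangent(), dK, reshape(g, size(xi)), NoTangent()
    end
    return out, f_prealloc_pullback
end
```

With this rule loaded, `gradient(sum∘f_prealloc, K, xi, 10)` should go through the pullback above instead of differentiating the loop, and the `xi` gradient is computed explicitly, so nothing is dropped.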
function f2(K, xi, d::Int)
    xs = accumulate(1:d-1; init=xi) do x, i
        K * x
    end
    hcat(xi, reduce(hcat, xs))
end

function f3(K, xi, d::Int)
    xs = accumulate(vcat([xi], 1:d-1)) do x, i  # avoiding init, type-unstable
        K * x
    end
    reduce(hcat, xs)
end

function f4(K, xi, d)
    xs = [xi]
    for i = 2:d
        xs = vcat(xs, [K*xs[i-1]])
    end
    reduce(hcat, xs)
end
f4(K, xi, 50) ≈ f3(K, xi, 50) ≈ f2(K, xi, 50) ≈ f(K, xi, 50)
using BenchmarkTools, Zygote
@btime f($K, $xi, 50);
@btime f2($K, $xi, 50); # twice as quick
@btime f3($K, $xi, 50); # a bit slower
@btime f4($K, $xi, 50);
julia> gradient(sum∘f, K, xi, 10)
([63.45016309970954 50.40609159573776 101.36588271461751; 23.572874731387856 18.379315265377535 35.224999619160954; 31.033286457367566 24.176359057416636 46.03455941092244], [48.455853839178204; 14.765466919614408; 18.845362109436827;;], nothing)
julia> gradient(sum∘f2, K, xi, 10) # NB the gradient for init=xi is missing!
([63.45016309970953 50.40609159573775 101.3658827146175; 23.572874731387852 18.379315265377535 35.22499961916095; 31.033286457367552 24.17635905741663 46.03455941092242], Fill(1.0, 3, 1), nothing)
julia> gradient(sum∘f3, K, xi, 10)
([63.45016309970953 50.40609159573775 101.3658827146175; 23.572874731387852 18.379315265377535 35.22499961916095; 31.033286457367552 24.17635905741663 46.03455941092242], [48.455853839178204; 14.765466919614404; 18.845362109436827;;], nothing)
julia> gradient(sum∘f4, K, xi, 10)
([63.45016309970953 50.40609159573775 101.3658827146175; 23.572874731387852 18.379315265377535 35.22499961916095; 31.033286457367552 24.17635905741663 46.03455941092242], [48.455853839178204; 14.765466919614404; 18.845362109436827;;], nothing)
julia> @btime gradient(sum∘f, $K, $xi, $10);
min 30.291 μs, mean 34.062 μs (369 allocations, 24.08 KiB)
julia> @btime gradient(sum∘f2, $K, $xi, $10);
min 21.583 μs, mean 28.247 μs (275 allocations, 36.14 KiB)
julia> @btime gradient(sum∘f3, $K, $xi, $10);
min 49.917 μs, mean 58.936 μs (544 allocations, 48.08 KiB)
julia> @btime gradient(sum∘f4, $K, $xi, $10);
min 76.375 μs, mean 89.648 μs (894 allocations, 63.39 KiB)
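If you want to double-check which `xi` gradient is right without trusting any AD rule, a central finite difference is one option. This is only a sketch: `powers_sum` and `fd_gradient` are made-up helper names, and `powers_sum` assumes the function being differentiated is the sum over the columns `xi, K*xi, ..., K^(d-1)*xi`.

```julia
# Hypothetical reference: sum of all columns xi, K*xi, ..., K^(d-1)*xi, no AD involved.
function powers_sum(K, xi, d)
    x = xi
    total = sum(x)
    for _ in 2:d
        x = K * x
        total += sum(x)
    end
    return total
end

# Central finite-difference gradient of the scalar function g at x
# (x is assumed to be a floating-point array).
function fd_gradient(g, x; h=1e-6)
    grad = similar(x, Float64)
    for j in eachindex(x)
        xp = copy(x); xp[j] += h
        xm = copy(x); xm[j] -= h
        grad[j] = (g(xp) - g(xm)) / (2h)
    end
    return grad
end
```

Something like `fd_gradient(x -> powers_sum(K, x, 10), xi)` can then be compared against the second entry of each `gradient(...)` result above.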