Hi,
In some JAX array code I am using `jax.lax.scan`, which in my example does something like this:
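```julia
# Fold `f` over the slices of `x` along dimension 1 and return only the final
# accumulator (the real jax.lax.scan also returns the stacked per-step outputs).
function jax_lax_scan(f, x; accumulator_init)
    acc = accumulator_init
    for i in axes(x, 1)
        acc = f(acc, selectdim(x, 1, i))
    end
    return acc
end
```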
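The loop above is a left fold over the slices along dimension 1. A minimal sketch, assuming Julia ≥ 1.1 (for `eachslice` with the `dims` keyword): `scan_fold` is the carry-only equivalent, and `scan` is a hypothetical helper (not a Base function) that mirrors the full `jax.lax.scan` contract, where `f(carry, x_i)` returns `(new_carry, y_i)` and the per-step outputs are collected.

```julia
# Carry-only version: `eachslice(x; dims=1)` yields the same slices as
# `selectdim(x, 1, i)` for i in axes(x, 1), in the same order.
scan_fold(f, x; accumulator_init) = foldl(f, eachslice(x; dims=1); init=accumulator_init)

# Hypothetical helper following jax.lax.scan semantics:
# `f(carry, x_i)` must return a tuple `(new_carry, y_i)`.
function scan(f, init, xs; dims=1)
    carry = init
    ys = []                          # per-step outputs (untyped for simplicity)
    for x_i in eachslice(xs; dims=dims)
        carry, y = f(carry, x_i)
        push!(ys, y)
    end
    return carry, ys                 # final carry and all intermediate outputs
end
```

On Julia ≥ 1.9 the collected `ys` can be turned into one array with `stack(ys)`.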
```julia-repl
julia> A = rand(2,2,3,1)
2×2×3×1 Array{Float64, 4}:
[:, :, 1, 1] =
 0.983886  0.301163
 0.142536  0.898853

[:, :, 2, 1] =
 0.502673  0.460876
 0.118309  0.675352

[:, :, 3, 1] =
 0.142753  0.386542
 0.858919  0.898126

julia> jax_lax_scan((acc, x) -> acc .+ x.^2, A; accumulator_init=zeros((2,3,1)))
2×3×1 Array{Float64, 3}:
[:, :, 1] =
 0.466399  0.598419  0.670064
 0.638399  0.307632  1.4969
```
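For this particular `f`, the fold is just a sum of squares along the first dimension, so its result can be cross-checked against a plain reduction. A minimal sketch with a fresh random `A` of the same shape:

```julia
# f(acc, x) = acc .+ x.^2, folded over dimension 1 ...
A = rand(2, 2, 3, 1)
folded = foldl((acc, x) -> acc .+ x.^2, eachslice(A; dims=1); init=zeros(2, 3, 1))

# ... equals the sum of squares over dimension 1, with that dimension dropped.
closed = dropdims(sum(abs2, A; dims=1); dims=1)   # size (2, 3, 1)
```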
What’s the Julia equivalent here? And is this efficiently supported by any automatic differentiation package (as it is in JAX)?
Yes, I could differentiate this with Zygote, but as far as I know it would keep a copy of every iteration in memory. My arrays have size (100, 512, 512, 100), so storing all of those copies is impossible.
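A back-of-the-envelope estimate of the storage involved, assuming Float64 elements, a scan over the first dimension as in the example, and that reverse mode keeps one full accumulator per step (an assumption about the AD tool, not a measurement):

```julia
# Assumption: one Float64 accumulator of size 512×512×100 is stored per step.
dims_x          = (100, 512, 512, 100)
n_steps         = dims_x[1]                   # scan runs over dimension 1
carry_elems     = prod(dims_x[2:end])         # elements in one accumulator
bytes_per_carry = carry_elems * sizeof(Float64)
total_bytes     = n_steps * bytes_per_carry

println(bytes_per_carry / 2^20, " MiB per carry")     # 200.0 MiB per carry
println(total_bytes / 2^30, " GiB for all carries")   # 19.53125 GiB for all carries
```

That is as large as the 100×512×512×100 Float64 input itself.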
I guess there is some Enzyme.jl + Reactant.jl way of doing that?
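Why the memory cost depends on `f`: a standard reverse-mode derivation for a scalar loss $L(c_n)$ of the final carry (not tied to any particular package) needs the carry $c_{i-1}$ at each step of the backward sweep, which is what forces a tool either to store every carry or to recompute them (checkpointing):

```latex
\begin{aligned}
c_0 &= c_{\text{init}}, \qquad c_i = f(c_{i-1}, x_i), \quad i = 1, \dots, n,\\
\bar c_n &= \frac{\partial L}{\partial c_n},\\
\bar x_i &= \Big(\frac{\partial f}{\partial x}(c_{i-1}, x_i)\Big)^{\!\top} \bar c_i,
\qquad
\bar c_{i-1} = \Big(\frac{\partial f}{\partial c}(c_{i-1}, x_i)\Big)^{\!\top} \bar c_i,
\quad i = n, \dots, 1.
\end{aligned}
```

For the example's `acc .+ x.^2`, i.e. $f(c, x) = c + x \odot x$, we have $\partial f/\partial c = I$ and $\partial f/\partial x = \operatorname{diag}(2x)$, so $\bar c_i = \bar c_n$ for every $i$ and $\bar x_i = 2\,x_i \odot \bar c_n$; no intermediate carries are needed in that special case.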
Thanks,
Felix