Efficient automatic differentiation for a Julia version of `jax.scan`?


Hi,

In some JAX array code I am using jax.lax.scan, which in my example implements something like this:

function jax_lax_scan(f, x; accumulator_init)
    acc = accumulator_init
    for i in axes(x, 1)
        acc = f(acc, selectdim(x, 1, i))
    end
    return acc
end

julia> A = rand(2,2,3,1)
2×2×3×1 Array{Float64, 4}:
[:, :, 1, 1] =
 0.983886  0.301163
 0.142536  0.898853

[:, :, 2, 1] =
 0.502673  0.460876
 0.118309  0.675352

[:, :, 3, 1] =
 0.142753  0.386542
 0.858919  0.898126


julia> jax_lax_scan((acc, x) -> acc .+ x.^2, A; accumulator_init=zeros((2,3,1)))
2×3×1 Array{Float64, 3}:
[:, :, 1] =
 0.988349  0.266677  0.758121
 0.898635  0.668507  0.956045

What’s the Julia equivalent here? And is this efficiently supported in any automatic differentiation package (as it is in JAX)?

Yes, I could differentiate this with Zygote, but AFAIK it would keep copies of each iteration in memory. My real arrays have size (100, 512, 512, 100), so that is impossible to store.

I guess there is some Enzyme.jl + Reactant.jl way of doing that?

Thanks,

Felix

OK, apparently this works:

julia> f = (acc, x) -> acc .+ x.^2;
 
julia> result = foldl(f, eachslice(A, dims=1); init=zeros((2,3,1)))
2×3×1 Array{Float64, 3}:
[:, :, 1] =
 0.988349  0.266677  0.758121
 0.898635  0.668507  0.956045

Does e.g. Enzyme work efficiently with foldl, reduce, or mapreduce?
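For concreteness, this is roughly the call I have in mind (an untested sketch; I'm assuming Enzyme.gradient with Reverse mode is the right entry point):

using Enzyme

loss(A) = sum(foldl((acc, x) -> acc .+ x.^2, eachslice(A, dims=1); init=zeros(2, 3, 1)))

dA = Enzyme.gradient(Reverse, loss, A)[1]

The question is whether the reverse pass through foldl stores every intermediate acc, as Zygote would.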

I’m interested in this as well. Do Julia’s autodiff packages have dedicated, performant rules for scan-like operations? Or would such a rule perhaps not improve performance?

If you add a @trace before the for loop and use Reactant, it will use some really nice AD tricks for differentiating the loop (and it handles a broader set of possibilities than lax.scan).
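A minimal sketch of what that could look like (untested; I'm assuming @trace handles a loop-carried acc and indexing with the loop variable, so the loop body may need adjusting):

using Reactant

function scan_traced(x, acc)
    # @trace lifts the loop into Reactant's IR instead of unrolling it
    @trace for i in 1:size(x, 1)
        acc = acc .+ x[i, :, :, :] .^ 2
    end
    return acc
end

x_ra = Reactant.to_rarray(rand(2, 2, 3, 1))
acc_ra = Reactant.to_rarray(zeros(2, 3, 1))
scan_compiled = @compile scan_traced(x_ra, acc_ra)
scan_compiled(x_ra, acc_ra)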

The quick answer is yes:

julia> using Enzyme

julia> function jax_lax_scan(f, x, accumulator_init)
           acc = accumulator_init
           for i in axes(x, 1)
               acc = f(acc, selectdim(x, 1, i))
           end
           return acc
       end
jax_lax_scan (generic function with 1 method)

julia> A = rand(2,2,3,1);

julia> accumulator_init=zeros((2,3,1));

julia> f(acc, x)  = acc .+ x.^2;

julia> Enzyme.jacobian(Enzyme.set_runtime_activity(Reverse),jax_lax_scan,Const(f),A,Const(accumulator_init))
(nothing, [1.6257113308613 0.0 0.0; 0.0 0.0 0.0;;;; 1.080987298281143 0.0 0.0; 0.0 0.0 0.0;;;;; 0.0 0.0 0.0; 0.9727596195934325 0.0 0.0;;;; 0.0 0.0 0.0; 0.2399608038078349 0.0 0.0;;;;;; 0.0 0.41852879546292643 0.0; 0.0 0.0 0.0;;;; 0.0 0.8543467156284372 0.0; 0.0 0.0 0.0;;;;; 0.0 0.0 0.0; 0.0 0.9046578942453445 0.0;;;; 0.0 0.0 0.0; 0.0 0.5680239161728668 0.0;;;;;; 0.0 0.0 1.2498992746934297; 0.0 0.0 0.0;;;; 0.0 0.0 0.7830888594771388; 0.0 0.0 0.0;;;;; 0.0 0.0 0.0; 0.0 0.0 1.5680193443731774;;;; 0.0 0.0 0.0; 0.0 0.0 1.8554432985549147;;;;;;;], nothing)

However, if you want a really efficient version, you may indeed need Reactant.jl; for instance, it will get rid of all the buffers used in the loop. For small cases like this, plain Enzyme is fine:

julia> @btime Enzyme.jacobian(Enzyme.set_runtime_activity(Reverse),jax_lax_scan,Const($f),$A,Const($accumulator_init))
  9.500 μs (216 allocations: 15.39 KiB)
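For reference, the Reactant route would look roughly like the sketch below (assuming the @trace'd scan_traced from earlier in the thread and Reactant's @jit macro; Reactant compiles the whole reverse pass through XLA, which is where the loop buffers get optimized away):

using Reactant, Enzyme

loss(x, acc) = sum(scan_traced(x, acc))

x_ra = Reactant.to_rarray(A)
acc_ra = Reactant.to_rarray(accumulator_init)

# compile and run the gradient of the traced scan in one go
dx, _ = @jit Enzyme.gradient(Reverse, loss, x_ra, Const(acc_ra))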

The other way is to write a better-suited function on the Julia side, using an in-place accumulator:

function jax_lax_scan2(f!, x, accumulator_init)
    acc = copy(accumulator_init)
    for i in axes(x, 1)
        f!(acc, selectdim(x, 1, i))
    end
    return acc
end
function f!(acc, x)
    acc .+= x.^2   # update the accumulator in place, no per-iteration copies
    return nothing
end
g2 = Enzyme.jacobian(Enzyme.set_runtime_activity(Reverse),jax_lax_scan2,Const(f!),A,accumulator_init)[2]

leading to

@btime Enzyme.jacobian(Enzyme.set_runtime_activity(Reverse),jax_lax_scan2,Const($f!),$A,$accumulator_init)[2]

6.560 μs (123 allocations: 13.86 KiB)
