Help With Automatic Differentiation

Good afternoon,

I wanted to ask for some advice about automatic differentiation (AD) in Julia. I am aware that mutating arrays is not supported in AD, but I find it difficult to write my code in a non-mutating fashion, so I wanted to ask whether anybody has advice on how to do that. For example:

N = 10
Y = rand(N)
W = rand(N)
A = rand(N)
for i in 1:N
    p = Y[i] * W[i]
    if p < 0.3
        A[i] = A[i] - p
    else
        A[i] = 0
    end
end

My naive approach was:

newA = zeros(N)
p = Y .* W
pIdx = findall(p .< 0.3)
newA[pIdx] = A[pIdx] - p[pIdx]
notIdx = findall(p .>= 0.3)
newA[notIdx] .= 0
A = newA

which of course does not work either, since it still writes into newA in place. Does anybody have any suggestions/advice or resources on the matter?

One of the sources I encountered suggests defining your own pullback for the mutating parts. Would anybody know where to start in this regard?

Thank you in advance!

This may be useful to you: rakeshvar/Zygote-Mutating-Arrays-WorkAround.jl on GitHub, a tutorial on how to work around the ‘Mutating arrays is not supported’ error when performing automatic differentiation (AD) with the Julia package Zygote.

Bear in mind that the mutation limitation is intrinsic to Zygote.jl. You could try using Enzyme.jl (not sure if it will work), or ReverseDiff.jl.
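Staying within Zygote.jl, one more workaround is Zygote.Buffer: an array-like container that Zygote can differentiate writes into, and which is frozen into an ordinary array by calling copy on it. A minimal sketch, assuming the update may return a new array instead of overwriting A; buffered is a hypothetical name and the fixed inputs are illustrative:

```julia
using Zygote

# Hypothetical helper: same logic as the original loop, but the writes go into a
# Zygote.Buffer instead of mutating A, so Zygote can differentiate through them.
function buffered(A, Y, W)
    out = Zygote.Buffer(A)               # uninitialised, like similar(A)
    for i in eachindex(A)
        p = Y[i] * W[i]
        out[i] = p < 0.3 ? A[i] - p : 0.0
    end
    return copy(out)                     # freezes the buffer and returns a plain array
end

A = [1.0, 2.0, 3.0]
Y = [0.1, 0.9, 0.5]
W = [0.5, 0.9, 0.2]

# Gradient of sum(updated A) with respect to Y
gY = Zygote.gradient(y -> sum(buffered(A, y, W)), Y)[1]
```

Scalar indexing such as A[i] and Y[i] is still inefficient under Zygote for long vectors, so a single broadcast expression is usually faster when the loop can be written as one; Buffer mainly helps when it cannot.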

Thank you. This was the resource I mentioned in the post! Say I wanted to differentiate the above with respect to Y: would the appropriate thing to do be to move the for loop into a separate function and then set the derivative of A[i] to -W[i] if p < 0.3 and 0 otherwise?
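Written out, the derivatives are diagonal, because entry i of the updated array A' depends only on Y_i, W_i and A_i. Using the example's threshold of 0.3 and ignoring the jump at p_i = 0.3 (where no derivative exists), the Jacobian entries and the resulting pullback for Y, given an incoming cotangent \bar{A}', are:

```latex
p_i = Y_i W_i, \qquad
A_i' = \begin{cases} A_i - p_i, & p_i < 0.3 \\ 0, & p_i \ge 0.3 \end{cases}

\frac{\partial A_i'}{\partial Y_j} = -W_i \, \delta_{ij} \, \mathbb{1}[p_i < 0.3], \qquad
\frac{\partial A_i'}{\partial W_j} = -Y_i \, \delta_{ij} \, \mathbb{1}[p_i < 0.3], \qquad
\frac{\partial A_i'}{\partial A_j} = \delta_{ij} \, \mathbb{1}[p_i < 0.3]

\bar{Y}_i = -W_i \, \mathbb{1}[p_i < 0.3] \, \bar{A}_i'
```

This agrees with the rule proposed in the question: -W[i] where p < 0.3 and 0 otherwise, scaled by the incoming cotangent of the updated A[i].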

The place to start is the (very thorough) documentation of ChainRules.jl, which Zygote.jl relies on.

If your code relies on mutation for efficiency, a custom chain rule is what you want anyway.
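For the custom-pullback route asked about in the original post, a minimal sketch using ChainRulesCore.rrule might look like this. It assumes the update may return a new array instead of overwriting A, so the mutation only ever touches a private copy; update and update_pullback are hypothetical names. The pullback hard-codes the diagonal derivatives: 1 for A, -W[i] for Y and -Y[i] for W wherever p < 0.3, and 0 elsewhere.

```julia
using ChainRulesCore

# Hypothetical non-mutating wrapper: the loop mutates only a private copy of A.
function update(A, Y, W)
    out = copy(A)
    for i in eachindex(out)
        p = Y[i] * W[i]
        out[i] = p < 0.3 ? out[i] - p : zero(out[i])
    end
    return out
end

# Hand-written reverse rule, so AD never has to trace the mutating loop.
function ChainRulesCore.rrule(::typeof(update), A, Y, W)
    out = update(A, Y, W)
    mask = (Y .* W) .< 0.3               # entries where the "A - p" branch was taken
    function update_pullback(ΔΩ)
        Δ = unthunk(ΔΩ)
        ΔA = mask .* Δ                   # d out_i / d A_i = 1 on the mask, 0 elsewhere
        ΔY = -(mask .* W) .* Δ           # d out_i / d Y_i = -W_i on the mask
        ΔW = -(mask .* Y) .* Δ           # d out_i / d W_i = -Y_i on the mask
        return NoTangent(), ΔA, ΔY, ΔW   # NoTangent() is for the function itself
    end
    return out, update_pullback
end
```

Zygote picks up rrules defined through ChainRulesCore, so Zygote.gradient(y -> sum(update(A, y, W)), Y) should call update_pullback instead of differentiating the loop.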

Enzyme works fine on this code:

wmoses@beast:~/git/Enzyme.jl ((HEAD detached from 0a742b9)) $ ./julia-1.9.0/bin/julia --project test.jl 
(dY, dW, dA) = ([0.0, 0.0, -0.36481667164909504, 0.0, -0.40760801938330093, 0.0, -0.29238901363265246, -0.7601463773026405, -0.10528419303498171, -0.21089297714276434], [0.0, 0.0, -0.17407141133367543, 0.0, -0.17535508982666037, 0.0, -0.5764836125183987, -0.12627017591563605, -0.13579863783844937, -0.1190684947298164], [0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0])
wmoses@beast:~/git/Enzyme.jl ((HEAD detached from 0a742b9)) $ cat test.jl 
using Enzyme

function f(N, Y, W, A)
    for i in 1:N
        p = Y[i] * W[i]
        if p < 0.3
            A[i] = A[i] - p
        else
            A[i] = 0
        end
    end
end


N = 10
Y = rand(N)
dY = zeros(N)
W = rand(N)
dW = zeros(N)

A = rand(N)
# Derivative we backpropagate
dA = ones(N)

Enzyme.autodiff(Reverse, f, Const(N), Duplicated(Y, dY), Duplicated(W, dW), Duplicated(A, dA))
@show dY, dW, dA

Comprehensions (or map) and broadcasting are ways of avoiding mutation. E.g.

julia> function orig!(A, Y, W) 
         for i in eachindex(A)
           p = Y[i] * W[i]
           if p < 0.3
               A[i] = A[i] - p
           else
               A[i] = 0
           end
         end
         A
       end;

julia> nonmut(A, Y, W) = map(eachindex(A)) do i
           p = Y[i] * W[i]
           (p < 0.3) * (A[i] - p)  # Bool * Float64 avoids mixing an Int 0 with Float64 results
       end;

julia> nonmut(A, Y, W) ≈ orig!(copy(A), Y, W)
true

julia> noindex(A, Y, W) = @. ((Y * W) < 0.3) * (A - Y * W);  # avoids scalar indexing, which Zygote also handles inefficiently

julia> noindex(A, Y, W) ≈ orig!(copy(A), Y, W)
true

julia> using Zygote, BenchmarkTools

julia> @btime gradient(a -> sum(abs2, nonmut(a,$Y,$W)), $A)
  min 1.625 μs, mean 2.413 μs (51 allocations, 11.69 KiB)
([-0.4549349972052379, 1.471142478710321, 0.927932031171403, 0.0, 1.055719617078303, 0.3149281759965271, 0.07592684856551069, 0.0, 0.0, 1.3759356933258697],)

julia> @btime gradient(a -> sum(abs2, noindex(a,$Y,$W)), $A)  # 1/4 the memory, probably saves more time at larger N
  min 1.271 μs, mean 1.415 μs (36 allocations, 3.16 KiB)
([-0.4549349972052379, 1.471142478710321, 0.927932031171403, -0.0, 1.055719617078303, 0.3149281759965271, 0.07592684856551069, 0.0, -0.0, 1.3759356933258697],)

Thanks to all for your different perspectives; all of them are super useful!

ForwardDiff.jl should work, right?
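A minimal sketch suggesting it does, assuming ForwardDiff.jl is installed: forward mode has no problem with mutation as such, as long as the array being mutated can store ForwardDiff's Dual numbers. orig! is the same in-place loop as in the REPL session earlier in the thread; the fixed inputs are illustrative.

```julia
using ForwardDiff

# The original in-place loop (same as orig! earlier in the thread).
function orig!(A, Y, W)
    for i in eachindex(A)
        p = Y[i] * W[i]
        if p < 0.3
            A[i] = A[i] - p
        else
            A[i] = 0
        end
    end
    return A
end

A = [1.0, 2.0, 3.0]
Y = [0.1, 0.9, 0.5]
W = [0.5, 0.9, 0.2]

# ForwardDiff pushes Dual numbers through y, so the mutated array must hold Duals
# too: copy A into a buffer with y's element type instead of mutating A itself.
J = ForwardDiff.jacobian(y -> orig!(copyto!(similar(y), A), y, W), Y)
```

The Jacobian comes out diagonal, with -W[i] where Y[i] * W[i] < 0.3 and 0 elsewhere. When only the gradient of a scalar loss is needed, forward mode needs roughly one pass per chunk of inputs, so reverse mode (for example Enzyme.jl) tends to scale better for long vectors.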