Zygote: mutating array error with Buffer during parameter gradient call

I’m trying to update the parameters of a multi-input, multi-output neural network using a loss that depends on the network's Jacobian. However, taking the gradient of this loss raises an array-mutation error, even though I build the Jacobian with `Zygote.Buffer`. Is the example below a correct way to do this? Tips appreciated; I’m fairly new to Julia!
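
For reference, `Zygote.Buffer` is designed for a single level of differentiation: you write into the buffer inside the function being differentiated, then call `copy` to freeze it into an ordinary array that Zygote can differentiate through. A minimal sketch of that case, using a hypothetical `squares` helper that is not part of the original code:

```julia
using Zygote

# Hypothetical helper: fills a Buffer elementwise, then freezes it with copy.
function squares(x)
    buf = Zygote.Buffer(x)       # uninitialised buffer with the same size and eltype as x
    for i in eachindex(x)
        buf[i] = x[i]^2          # writing into a Buffer is allowed under gradient
    end
    return copy(buf)             # freeze into an ordinary array
end

g = gradient(x -> sum(squares(x)), [1.0, 2.0, 3.0])[1]
```

Here `g` should equal `2 .* x`, i.e. `[2.0, 4.0, 6.0]`. The MWE below differs in that each row written into the buffer is itself the result of an inner `gradient` call, so the outer `gradient(ps)` also has to differentiate through Zygote's own reverse pass, not just through the buffer.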

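```julia
# MWE
using Flux, Zygote

nn = Chain(Dense(2, 2),
           Dense(2, 2))
p, re = Flux.destructure(nn)   # flat parameter vector + function that rebuilds the model
theta = copy(p)
ps = Flux.params(theta)        # implicit parameters to differentiate with respect to

# Jacobian of f at x, built row by row in a Zygote.Buffer
function jacobian(f, x)
    n = length(x)
    buf = Zygote.Buffer(ones(2, 2), n, n)   # n×n Float64 buffer (assumes as many outputs as inputs)
    for i in 1:n
        buf[i, :] = gradient(z -> f(z)[i], x)[1]   # row i = gradient of output i
    end
    return copy(buf)
end

function loss()
    # arbitrary function of the Jacobian of the network, at an arbitrary input;
    # the model is rebuilt from theta so that the loss actually depends on ps
    sum(jacobian(re(theta), [1.0, 1.0]))
end

# LoadError: Mutating arrays is not supported
gs = gradient(ps) do
    return loss()
end
```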
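
Since both `Dense` layers in this MWE use Flux's default `identity` activation, the network is affine, y = W2*(W1*x + b1) + b2, and its Jacobian is W2*W1 at every input. That means `sum(jacobian(nn, x))` reduces to `sum(W2 * W1)`, which gives a reference gradient that needs neither nested AD nor a Buffer. The sketch below uses plain random matrices as hypothetical stand-ins for the layer weights instead of Flux's weight field names, which differ between Flux versions:

```julia
using Zygote

# Hypothetical stand-ins for the two Dense layers' weights
# (the biases do not affect the Jacobian, so they are omitted).
W1 = randn(2, 2)
W2 = randn(2, 2)

# For identity activations, dy/dx = W2 * W1 for every x,
# so the Jacobian-based loss becomes an ordinary matrix expression.
jac_loss(W1, W2) = sum(W2 * W1)

gW1, gW2 = gradient(jac_loss, W1, W2)
```

Analytically, `gW1` should equal `W2' * ones(2, 2)` and `gW2` should equal `ones(2, 2) * W1'`. This shortcut only holds while every activation is `identity`; with nonlinear activations the Jacobian depends on `x` and has to be computed some other way.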
