Help with working around a mutation error in an optimization problem

Hi everyone,

I’m running into an issue with an optimization problem that uses gradient-based updates with Zygote, where an in-place mutation appears to be causing problems. Specifically, when I update my parameter vector x and assign the result to x_new, Zygote complains that mutation isn’t supported.

Below is a simplified version of my code:

using Convex
using Random
using Zygote
using JuMP
using MathOptInterface
using Clarabel

# Optimizer setup for the full problem (not actually used in this reduced example)
optimizer = Convex.MOI.OptimizerWithAttributes(Clarabel.Optimizer, "verbose" => 0)
Random.seed!(0)

function f(x)
    return x[1]*x[1]
end

# Constraint function; the log-barrier below needs g_u(x, t, j) < 0
function g_u(x, t, j)
    return x[1] * t[j] + 3.0
end

# Log-barrier objective: f(x) - sum(log(-g_u))
function barrier(x, t)
    constraint_values = [-g_u(x, t, j) for j in 1:length(t)]
    return f(x) - sum(log.(constraint_values))
end

# Take one gradient step on the barrier in x, then evaluate f at the new point
function inner_min(u)
    lr_x = 1e-10   # inner step size
    x = rand(1)

    grad_x = Zygote.gradient(x -> barrier(x, u), x)[1]
    x_new = x .- lr_x .* grad_x

    return f(x_new)
end

u = rand(1)
lr_u = 1e-1               # outer step size
grad_u = Zygote.gradient(u -> inner_min(u), u)[1]
u .= u .+ lr_u .* grad_u  # outer update of u

The error points at the update that produces x_new (the x .- lr_x .* grad_x line). I suspect this is because Zygote doesn’t support in-place mutation of arrays during gradient computation.
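
For reference, here is a separate toy example (not my actual code) of the kind of mutation that, as far as I understand, Zygote rejects:

function h(x)
    x[1] = 2.0 * x[1]   # in-place assignment into the input array
    return sum(x)
end

Zygote.gradient(h, [1.0, 2.0])
# fails with an error along the lines of "Mutating arrays is not supported"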

Has anyone encountered this issue or found a workaround? Is there an alternative AD approach or package that supports such operations? Any insights or suggestions would be greatly appreciated.

Thanks in advance!

Hi, welcome to the forum!
When I run your code, the first error I get is a DomainError, because log is applied to a negative value. Could you fix the initial values in a way that is coherent with your problem, to make sure that g_u always outputs something negative?
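
For example (just a sketch with made-up numbers, since I don’t know your actual problem), you could start from a strictly feasible point, i.e. one where g_u(x, u, j) < 0 for every j, so that the log terms are well defined:

u0 = [-10.0]     # made-up value
x0 = [1.0]       # g_u(x0, u0, 1) = 1.0 * (-10.0) + 3.0 = -7.0 < 0
barrier(x0, u0)  # ≈ 1.0 - log(7.0), no DomainError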
Before we talk about AD specifics, can you maybe explain why you want to differentiate through inner_min? There are smarter ways to do this than differentiating through each gradient step, and the right recipe depends on the structure of your problem. Is it convex for instance? If so, you can use DiffOpt.jl to automatically differentiate through the JuMP.jl formulation.
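
Just to make "smarter ways" concrete: if the inner problem is solved to (approximate) optimality, the implicit function theorem gives you dx*/du from two second-order derivatives, with no need to differentiate through any gradient-descent iterations. Below is a rough sketch using nested ForwardDiff.jl (chosen only because it nests easily) with made-up values for u_star and x_star; in practice x_star would come from your inner solver, and for a convex conic formulation DiffOpt.jl automates this kind of differentiation:

using ForwardDiff

u_star = [-10.0]   # made-up outer parameter
x_star = [0.87]    # roughly a stationary point of x -> barrier(x, u_star) for this toy value

# At the inner optimum, the x-gradient of the barrier vanishes:
#   d barrier / dx (x*, u) = 0
# Differentiating this identity with respect to u gives
#   dx*/du = - (d²barrier/dx²)⁻¹ (d²barrier/dxdu)
Hxx = ForwardDiff.hessian(x -> barrier(x, u_star), x_star)
Hxu = ForwardDiff.jacobian(u -> ForwardDiff.gradient(x -> barrier(x, u), x_star), u_star)
dxdu = -(Hxx \ Hxu)

# Chain rule for the outer objective f(x*(u))
grad_u = dxdu' * ForwardDiff.gradient(f, x_star)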