Best practice for Flux.jl: how to untrack gradient?

This question is based on the previous question.

I found it hard to combine Convex.jl with Flux.jl.
So I tried to “approximate” the solution by excluding the Convex.jl solution from gradient tracking in Flux.jl (which uses Zygote.jl for automatic differentiation).

I tried using `deepcopy`, but it does not seem to work.
How can I stop the gradient from being tracked through the Convex.jl solution?

using Convex
using Flux
using Flux.Data: DataLoader
using Mosek, MosekTools
using Zygote


# Problem dimensions: `n` state entries, `m` input entries, `num` samples.
n, m, num = 3, 2, 10
# Random training data, one sample per column. `x` is drawn before `u` so the
# global RNG stream is consumed in the same order as before.
x = rand(n, num)
u = rand(m, num)
data_train = (x, u)
dataloader = DataLoader(x, u; batchsize=64, shuffle=true)
opt = ADAM(1e-3)
# Single affine layer acting on the stacked vector [x; u].
network = Chain(Dense(n + m, 1))

# Evaluate a one-layer `Chain(Dense(...))` on a Convex.jl expression, so the network can be
# used as the objective of `Convex.minimize`.
#
# Only the affine map `W * x_u + b` is reproduced: the layer's activation and any further
# layers are NOT applied. The method therefore refuses chains it cannot represent
# faithfully, instead of silently returning a wrong expression.
#
# NOTE(review): this adds a method to Flux's `Chain` for Convex's `AbstractExpr`; neither
# type is owned by this script (type piracy). A small local wrapper type would avoid that.
function (network::Chain)(x_u::Convex.AbstractExpr)
    nlayers = length(network.layers)
    nlayers == 1 || throw(ArgumentError(
        "expected a Chain with exactly one Dense layer, got $nlayers layers"))
    layer = network.layers[1]
    layer.σ === identity || throw(ArgumentError(
        "expected a Dense layer with identity activation, got $(layer.σ)"))
    # `W * zeros(size(x_u)...)` is all zeros; adding it to `b` presumably gives the bias
    # the same shape as `W * x_u` (Convex expressions are 2-D) — TODO confirm.
    bias = layer.b .+ layer.W * zeros(size(x_u)...)
    return layer.W * x_u + bias
end

"""
    test()

For every sample (column) `j` of the global data, minimise a copy of `network` over a fresh
Convex.jl input variable with `x[:, j]` fixed, solve each problem with Mosek, and return the
sum of the optimal values.

NOTE(review): the objective is affine in an unconstrained variable, so these problems are
generally unbounded.
"""
function test()
    net_copy = deepcopy(network)
    num_samples = size(u, 2)
    # Build all problems first, then solve them one by one in the same order.
    problems = map(1:num_samples) do j
        u_var = Convex.Variable(size(u[:, j])...)
        Convex.minimize(net_copy(vcat(x[:, j], u_var)))
    end
    for prob in problems
        solve!(prob, Mosek.Optimizer())
    end
    return sum(prob -> prob.optval, problems)
end

"""
    main()

Train the global `network` on `dataloader` with `opt`. The loss is the sum of the optimal
values of the per-sample Convex.jl problems plus `1e-3` times the squared L2 norm of the
network parameters.

The Convex.jl part is evaluated inside `Zygote.ignore`. Zygote therefore does not trace
problem construction or solving, which otherwise fails with
`MethodError: no method matching zero(::Convex.AdditionAtom)`. That term contributes no
gradient, so only the regulariser drives the parameter update.
"""
function main()
    loss = function (x, u)
        # `deepcopy` does not stop Zygote from tracing these calls; `Zygote.ignore` does:
        # its body runs normally on the forward pass and its pullback returns `nothing`.
        return Zygote.ignore() do
            num_samples = size(u, 2)
            problems = [
                Convex.minimize(network(vcat(x[:, i], Convex.Variable(size(u[:, i])...))))
                for i in 1:num_samples
            ]
            for problem in problems
                solve!(problem, Mosek.Optimizer())
            end
            sum(problem.optval for problem in problems)
        end
    end
    sqnorm(x) = sum(abs2, x)
    loss_reg(args...) = loss(args...) + 1e-3 * sum(sqnorm, Flux.params(network))
    Flux.train!(loss_reg, Flux.params(network), dataloader, opt)
end
julia> main()
ERROR: MethodError: no method matching zero(::Convex.AdditionAtom)
Closest candidates are:
  zero(::Type{Dates.DateTime}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.5/Dates/src/types.jl:404
  zero(::Type{Dates.Date}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.5/Dates/src/types.jl:405
  zero(::Type{ModelingToolkit.TermCombination}) at /Users/jinrae/.julia/packages/ModelingToolkit/hkIWj/src/linearity.jl:67
  ...
Stacktrace:
 [1] iszero(::Convex.AdditionAtom) at ./number.jl:40
 [2] rrule(::typeof(sign), ::Convex.AdditionAtom) at /Users/jinrae/.julia/packages/ChainRules/wuTHR/src/rulesets/Base/fastmath_able.jl:197
 [3] chain_rrule at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/chainrules.jl:87 [inlined]
 [4] macro expansion at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0 [inlined]
 [5] _pullback(::Zygote.Context, ::typeof(sign), ::Convex.AdditionAtom) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:12
 [6] Problem at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:10 [inlined]
 [7] _pullback(::Zygote.Context, ::Type{Problem{Float64}}, ::Symbol, ::Convex.AdditionAtom, ::Array{Constraint,1}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
 [8] #minimize#4 at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:74 [inlined]
 [9] _pullback(::Zygote.Context, ::Convex.var"##minimize#4", ::Type{Float64}, ::typeof(minimize), ::Convex.AdditionAtom, ::Array{Constraint,1}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
 [10] minimize at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:74 [inlined]
 [11] _pullback(::Zygote.Context, ::typeof(minimize), ::Convex.AdditionAtom, ::Array{Constraint,1}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
 [12] minimize at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:74 [inlined]
 [13] _pullback(::Zygote.Context, ::typeof(minimize), ::Convex.AdditionAtom) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
 [14] #76 at ./none:0 [inlined]
 [15] _pullback(::Zygote.Context, ::var"#76#79"{Array{Float64,2},Array{Float64,2},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}}}, ::Int64) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
 [16] (::Zygote.var"#586#590"{Zygote.Context,var"#76#79"{Array{Float64,2},Array{Float64,2},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}}}})(::Int64) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/lib/array.jl:174
 [17] iterate at ./generator.jl:47 [inlined]
 [18] _collect at ./array.jl:699 [inlined]
 [19] collect_similar at ./array.jl:628 [inlined]
 [20] map at ./abstractarray.jl:2162 [inlined]
 [21] ∇map at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/lib/array.jl:174 [inlined]
 [22] _pullback(::Zygote.Context, ::typeof(collect), ::Base.Generator{UnitRange{Int64},var"#76#79"{Array{Float64,2},Array{Float64,2},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}}}}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/lib/array.jl:193
 [23] #75 at /Users/jinrae/.julia/dev/PartiallyConvexApproximator/test/cvx_flux.jl:35 [inlined]
 [24] _pullback(::Zygote.Context, ::var"#75#78", ::Array{Float64,2}, ::Array{Float64,2}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
 [25] adjoint at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/lib/lib.jl:175 [inlined]
 [26] _pullback at /Users/jinrae/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [27] loss_reg at /Users/jinrae/.julia/dev/PartiallyConvexApproximator/test/cvx_flux.jl:43 [inlined]
 [28] _pullback(::Zygote.Context, ::var"#loss_reg#82"{var"#75#78",var"#sqnorm#81"}, ::Array{Float64,2}, ::Array{Float64,2}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
 [29] adjoint at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/lib/lib.jl:175 [inlined]
 [30] _pullback at /Users/jinrae/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [31] #15 at /Users/jinrae/.julia/packages/Flux/sY3yx/src/optimise/train.jl:103 [inlined]
 [32] _pullback(::Zygote.Context, ::Flux.Optimise.var"#15#21"{var"#loss_reg#82"{var"#75#78",var"#sqnorm#81"},Tuple{Array{Float64,2},Array{Float64,2}}}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
 [33] pullback(::Function, ::Params) at /Applications/Julia-1.5.app/Contents/Resources/julia/lib/julia/sys.dylib:?
 [34] gradient(::Function, ::Params) at /Applications/Julia-1.5.app/Contents/Resources/julia/lib/julia/sys.dylib:?
 [35] macro expansion at /Users/jinrae/.julia/packages/Flux/sY3yx/src/optimise/train.jl:102 [inlined]
 [36] macro expansion at /Users/jinrae/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
 [37] train!(::Function, ::Params, ::DataLoader{Tuple{Array{Float64,2},Array{Float64,2}}}, ::ADAM; cb::Flux.Optimise.var"#16#22") at /Users/jinrae/.julia/packages/Flux/sY3yx/src/optimise/train.jl:100
 [38] train!(::Function, ::Params, ::DataLoader{Tuple{Array{Float64,2},Array{Float64,2}}}, ::ADAM) at /Users/jinrae/.julia/packages/Flux/sY3yx/src/optimise/train.jl:98
 [39] main() at /Users/jinrae/.julia/dev/PartiallyConvexApproximator/test/cvx_flux.jl:44
 [40] top-level scope at REPL[17]:1

julia> test()
Problem
  Name                   :
  Objective sense        : min
  Type                   : LO (linear optimization problem)
  Constraints            : 1
  Cones                  : 0
  Scalar variables       : 3
  Matrix variables       : 0
  Integer variables      : 0

Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries                  : 2                 time                   : 0.00
Lin. dep.  - tries                  : 0                 time                   : 0.00
Lin. dep.  - number                 : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.00

┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253
Problem
  Name                   :
  Objective sense        : min
  Type                   : LO (linear optimization problem)
  Constraints            : 1
  Cones                  : 0
  Scalar variables       : 3
  Matrix variables       : 0
  Integer variables      : 0

Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries                  : 2                 time                   : 0.00
Lin. dep.  - tries                  : 0                 time                   : 0.00
Lin. dep.  - number                 : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.00

┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253
Problem
  Name                   :
  Objective sense        : min
  Type                   : LO (linear optimization problem)
  Constraints            : 1
  Cones                  : 0
  Scalar variables       : 3
  Matrix variables       : 0
  Integer variables      : 0

Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries                  : 2                 time                   : 0.00
Lin. dep.  - tries                  : 0                 time                   : 0.00
Lin. dep.  - number                 : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.00

┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253
Problem
  Name                   :
  Objective sense        : min
  Type                   : LO (linear optimization problem)
  Constraints            : 1
  Cones                  : 0
  Scalar variables       : 3
  Matrix variables       : 0
  Integer variables      : 0

Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries                  : 2                 time                   : 0.00
Lin. dep.  - tries                  : 0                 time                   : 0.00
Lin. dep.  - number                 : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.00

┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253
Problem
  Name                   :
  Objective sense        : min
  Type                   : LO (linear optimization problem)
  Constraints            : 1
  Cones                  : 0
  Scalar variables       : 3
  Matrix variables       : 0
  Integer variables      : 0

Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries                  : 2                 time                   : 0.00
Lin. dep.  - tries                  : 0                 time                   : 0.00
Lin. dep.  - number                 : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.00

┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253
Problem
  Name                   :
  Objective sense        : min
  Type                   : LO (linear optimization problem)
  Constraints            : 1
  Cones                  : 0
  Scalar variables       : 3
  Matrix variables       : 0
  Integer variables      : 0

Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries                  : 2                 time                   : 0.00
Lin. dep.  - tries                  : 0                 time                   : 0.00
Lin. dep.  - number                 : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.00

┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253
Problem
  Name                   :
  Objective sense        : min
  Type                   : LO (linear optimization problem)
  Constraints            : 1
  Cones                  : 0
  Scalar variables       : 3
  Matrix variables       : 0
  Integer variables      : 0

Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries                  : 2                 time                   : 0.00
Lin. dep.  - tries                  : 0                 time                   : 0.00
Lin. dep.  - number                 : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.00

┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253
Problem
  Name                   :
  Objective sense        : min
  Type                   : LO (linear optimization problem)
  Constraints            : 1
  Cones                  : 0
  Scalar variables       : 3
  Matrix variables       : 0
  Integer variables      : 0

Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries                  : 2                 time                   : 0.00
Lin. dep.  - tries                  : 0                 time                   : 0.00
Lin. dep.  - number                 : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.00

┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253
Problem
  Name                   :
  Objective sense        : min
  Type                   : LO (linear optimization problem)
  Constraints            : 1
  Cones                  : 0
  Scalar variables       : 3
  Matrix variables       : 0
  Integer variables      : 0

Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries                  : 2                 time                   : 0.00
Lin. dep.  - tries                  : 0                 time                   : 0.00
Lin. dep.  - number                 : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.00

┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253
Problem
  Name                   :
  Objective sense        : min
  Type                   : LO (linear optimization problem)
  Constraints            : 1
  Cones                  : 0
  Scalar variables       : 3
  Matrix variables       : 0
  Integer variables      : 0

Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries                  : 2                 time                   : 0.00
Lin. dep.  - tries                  : 0                 time                   : 0.00
Lin. dep.  - number                 : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.00

┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253
-10.0

Note: in the code above, the desired outcome is that no gradient update is performed.

The answer was actually very simple, though I hadn't realised it. Solve the Convex.jl problems *outside* the closure passed to `gradient`, so Zygote never traces them. Then use the resulting optimal values in the loss as constants.
Below is an example of Q-learning with two separate Q-networks and a low-pass-filter (LPF)-like update.
Unfortunately, the following code is not copy-and-paste executable; just refer to its structure. For more details, see the Flux.jl examples.

using PartiallyConvexApproximator
using Test

using Flux
using Flux.Optimise: update!
using Flux.Data: DataLoader
using Transducers
using Random

using Convex
using Optim, Mosek, MosekTools


"""
    test(; seed=2021)

Smoke test for `custom_train!`. It builds a `pMA` Q-network and generates `d` random
samples, which give `d - 1` transitions `(x_k, u_k, x_{k+1}, Δr_k)`. It then trains for a
few epochs and checks that at least one parameter array changed.
"""
function test(; seed=2021)
    Random.seed!(seed)
    # network
    n = 3
    m = 2
    d = 1000
    i_max = 20
    h_array = [16, 16, 16]
    act = Flux.relu
    Q̂ = pMA(n, m, i_max, h_array; act=act)
    # data (xs drawn before us, as before, so the seeded RNG stream is unchanged)
    xs = 1:d |> Map(i -> rand(n)) |> collect
    us = 1:d |> Map(i -> rand(m)) |> collect
    rs = 1:d |> collect
    # `reduce(hcat, v)` builds the same matrix as `hcat(v...)` without splatting ~d args.
    x_train = reduce(hcat, xs[1:end-1])
    u_train = reduce(hcat, us[1:end-1])
    x_next_train = reduce(hcat, xs[2:end])
    Δr_train = permutedims(diff(rs))  # 1×(d-1) row matrix, same as `hcat(diff(rs)...)`
    data_train = (x_train, u_train, x_next_train, Δr_train)
    dataloader = DataLoader(data_train..., batchsize=64, shuffle=true)
    opt = ADAM(1e-3)
    ps = params(Q̂)
    _ps = deepcopy(ps)
    epochs = 10
    for epoch in 1:epochs
        training_loss = custom_train!(Q̂, dataloader, opt)
        @show training_loss
    end
    # Compare the parameter arrays themselves. Whether `!=` on two `Params` containers
    # compares contents depends on the Zygote version; an identity-based fallback would make
    # the check pass unconditionally.
    @test any(p_old != p_new for (p_old, p_new) in zip(_ps, params(Q̂)))
end

"""
    custom_train!(Q̂, data, opt)

Run one pass over `data` and softly update `Q̂` towards a trained copy of itself.

A deep copy `_Q̂` of `Q̂` is trained. For each minibatch `(x, u, x_next, Δr)`, the target is
`Δr + γ * min_u _Q̂(x_next, u)`. The minimum is obtained by solving one Convex.jl problem per
sample *outside* the `gradient` call, so Zygote never traces Convex.jl and the target is a
constant. After the pass, `Q̂` is set to `τ * _Q̂ + (1 - τ) * Q̂`, a low-pass-filter-like
update.

Returns the loss of the last minibatch. Throws `ArgumentError` if `data` yields no batches.
"""
function custom_train!(Q̂, data, opt)
    local training_loss
    γ = 1.00
    τ = 0.1
    _Q̂ = deepcopy(Q̂)
    _ps = params(_Q̂)
    for (x_train, u_train, x_next_train, Δr_train) in data
        num_data = size(x_train, 2)
        dim_u = size(u_train, 1)
        Q_min_train = zeros(1, num_data)
        # NOTE(review): targets are computed with the network being trained (`_Q̂`). A classic
        # target-network setup would evaluate the slowly updated `Q̂` here — confirm intent.
        for i in 1:num_data
            u = Convex.Variable(dim_u)
            problem = Convex.minimize(_Q̂(x_next_train[:, i], u))
            solve!(problem, Mosek.Optimizer(); silent_solver=true)
            Q_min_train[1, i] = problem.optval
        end
        # The target does not depend on `_ps`, so compute it outside the differentiated code.
        target = Δr_train + γ * Q_min_train
        gs = gradient(_ps) do
            training_loss = Flux.Losses.mse(_Q̂(x_train, u_train), target)
            return training_loss
        end
        update!(opt, _ps, gs)
    end
    Flux.loadparams!(Q̂, τ .* params(_Q̂) .+ (1 - τ) .* params(Q̂))
    # Without any batch `training_loss` was never assigned; fail with a clear message.
    @isdefined(training_loss) ||
        throw(ArgumentError("`data` yielded no minibatches; nothing was trained"))
    return training_loss
end