# Best practice for Flux.jl: how to untrack gradient?

This question is based on the previous question.

I found it hard to combine Convex.jl with Flux.jl.
So I tried to “approximate” the solution by keeping Flux.jl (more precisely, Zygote.jl) from tracking gradients through the solution returned by Convex.jl.

I tried `deepcopy`, but it does not seem to work.
How can I stop gradients from being tracked through the Convex.jl solution?

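```julia
using Convex
using Flux
using Mosek, MosekTools
using Zygote

n = 3
m = 2
num = 10
x = rand(n, num)
u = rand(m, num)
data_train = (x, u)
network = Chain(Dense(n+m, 1))

# Evaluate the network on a Convex.jl expression: W * [x; u] + b
# (`W * zeros(...)` is a zero matrix; adding it gives the bias the same matrix shape as `W * x_u`)
(network::Chain)(x_u::Convex.AbstractExpr) = network.layers[1].W * x_u + (network.layers[1].b .+ network.layers[1].W * zeros(size(x_u)...))

# Sum of the optimal values, minimising over u for each sample (no gradient involved)
function test()
    __network = deepcopy(network)
    d = size(u)[2]
    problems = [Convex.minimize(__network(vcat(x[:, i], Convex.Variable(size(u[:, i])...)))) for i in 1:d]
    for i in 1:d
        solve!(problems[i], Mosek.Optimizer())
    end
    optval = [problems[i].optval for i in 1:d]
    sum(optval)
end

function main()
    # Same computation as `test`, now used as the training loss
    loss = function (x, u)
        __network = deepcopy(network)
        d = size(u)[2]
        problems = [Convex.minimize(__network(vcat(x[:, i], Convex.Variable(size(u[:, i])...)))) for i in 1:d]
        for i in 1:d
            solve!(problems[i], Mosek.Optimizer())
        end
        optval = [problems[i].optval for i in 1:d]
        sum(optval)
    end
    sqnorm(x) = sum(abs2, x)
    loss_reg(args...) = loss(args...) + 1e-3 * sum(sqnorm, Flux.params(network))
    # This call is not shown in the original snippet, but the stacktrace goes through
    # `Flux.train!` here; the optimiser below is an assumption.
    Flux.train!(loss_reg, Flux.params(network), [data_train], ADAM())
end
```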
```
julia> main()
ERROR: MethodError: no method matching zero(::Convex.AdditionAtom)
Closest candidates are:
zero(::Type{Dates.DateTime}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.5/Dates/src/types.jl:404
zero(::Type{Dates.Date}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.5/Dates/src/types.jl:405
zero(::Type{ModelingToolkit.TermCombination}) at /Users/jinrae/.julia/packages/ModelingToolkit/hkIWj/src/linearity.jl:67
...
Stacktrace:
[3] chain_rrule at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/chainrules.jl:87 [inlined]
[4] macro expansion at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0 [inlined]
[5] _pullback(::Zygote.Context, ::typeof(sign), ::Convex.AdditionAtom) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:12
[6] Problem at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:10 [inlined]
[7] _pullback(::Zygote.Context, ::Type{Problem{Float64}}, ::Symbol, ::Convex.AdditionAtom, ::Array{Constraint,1}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
[8] #minimize#4 at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:74 [inlined]
[9] _pullback(::Zygote.Context, ::Convex.var"##minimize#4", ::Type{Float64}, ::typeof(minimize), ::Convex.AdditionAtom, ::Array{Constraint,1}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
[10] minimize at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:74 [inlined]
[11] _pullback(::Zygote.Context, ::typeof(minimize), ::Convex.AdditionAtom, ::Array{Constraint,1}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
[12] minimize at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:74 [inlined]
[13] _pullback(::Zygote.Context, ::typeof(minimize), ::Convex.AdditionAtom) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
[14] #76 at ./none:0 [inlined]
[15] _pullback(::Zygote.Context, ::var"#76#79"{Array{Float64,2},Array{Float64,2},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}}}, ::Int64) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
[16] (::Zygote.var"#586#590"{Zygote.Context,var"#76#79"{Array{Float64,2},Array{Float64,2},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}}}})(::Int64) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/lib/array.jl:174
[17] iterate at ./generator.jl:47 [inlined]
[18] _collect at ./array.jl:699 [inlined]
[19] collect_similar at ./array.jl:628 [inlined]
[20] map at ./abstractarray.jl:2162 [inlined]
[21] ∇map at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/lib/array.jl:174 [inlined]
[22] _pullback(::Zygote.Context, ::typeof(collect), ::Base.Generator{UnitRange{Int64},var"#76#79"{Array{Float64,2},Array{Float64,2},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}}}}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/lib/array.jl:193
[23] #75 at /Users/jinrae/.julia/dev/PartiallyConvexApproximator/test/cvx_flux.jl:35 [inlined]
[24] _pullback(::Zygote.Context, ::var"#75#78", ::Array{Float64,2}, ::Array{Float64,2}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
[27] loss_reg at /Users/jinrae/.julia/dev/PartiallyConvexApproximator/test/cvx_flux.jl:43 [inlined]
[28] _pullback(::Zygote.Context, ::var"#loss_reg#82"{var"#75#78",var"#sqnorm#81"}, ::Array{Float64,2}, ::Array{Float64,2}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
[31] #15 at /Users/jinrae/.julia/packages/Flux/sY3yx/src/optimise/train.jl:103 [inlined]
[32] _pullback(::Zygote.Context, ::Flux.Optimise.var"#15#21"{var"#loss_reg#82"{var"#75#78",var"#sqnorm#81"},Tuple{Array{Float64,2},Array{Float64,2}}}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
[33] pullback(::Function, ::Params) at /Applications/Julia-1.5.app/Contents/Resources/julia/lib/julia/sys.dylib:?
[35] macro expansion at /Users/jinrae/.julia/packages/Flux/sY3yx/src/optimise/train.jl:102 [inlined]
[36] macro expansion at /Users/jinrae/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
[39] main() at /Users/jinrae/.julia/dev/PartiallyConvexApproximator/test/cvx_flux.jl:44
[40] top-level scope at REPL[17]:1

``````
```
julia> test()
Problem
Name                   :
Objective sense        : min
Type                   : LO (linear optimization problem)
Constraints            : 1
Cones                  : 0
Scalar variables       : 3
Matrix variables       : 0
Integer variables      : 0

Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries                  : 2                 time                   : 0.00
Lin. dep.  - tries                  : 0                 time                   : 0.00
Lin. dep.  - number                 : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.00

┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253
⋮ (the same Mosek log and DUAL_INFEASIBLE warning are printed for each of the remaining 9 problems)
-10.0

``````

Note: in the code above, the desired behaviour is that no gradient is propagated through the Convex.jl part, i.e. it should not contribute to the gradient update.
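
A minimal sketch of one common way to get this behaviour, assuming a Zygote version that provides `Zygote.ignore` (newer releases point to `ChainRulesCore.ignore_derivatives` for the same purpose): the wrapped code runs normally on the forward pass, but Zygote gives it no pullback and never traces into it, so its result acts as a constant. `solve_externally` is a hypothetical placeholder for the Convex.jl/Mosek code, not part of any library.

```julia
using Zygote

# Hypothetical placeholder for building and solving the Convex.jl problems;
# only its returned value (the summed optimal values) matters here.
solve_externally(W) = sum(W) - 10.0

W = [1.0 2.0 3.0]

# Mirrors `loss_reg`: solver term plus a small L2 penalty. `Zygote.ignore`
# runs the closure as usual but tells Zygote not to differentiate it.
loss_reg(W) = Zygote.ignore(() -> solve_externally(W)) + 1e-3 * sum(abs2, W)

# Only the penalty is differentiated, so g ≈ 2e-3 .* W.
g = gradient(loss_reg, W)[1]
println(g)
```

Applied to the question's code, everything inside `loss` that builds and solves the Convex.jl problems would go inside the `Zygote.ignore` closure.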

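The answer turned out to be very simple, though I hadn't realised it: solve the Convex.jl problems before, i.e. outside, the block passed to `gradient`, so Zygote never traces them and their optimal values enter the loss as plain constants.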
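
A self-contained toy sketch of computing solver outputs outside `gradient`, using only Zygote.jl. Everything here is an assumption for illustration: the linear `Q` stands in for the networks, the hypothetical `target_min` stands in for the Convex.jl minimisation over `u`, and the constants are arbitrary. The target is built before `gradient` is called, so only the squared error of the trained weights is differentiated; the last step is an LPF-like soft update that moves a second parameter copy a fraction `τ` toward the trained one (the usual target-network convention). It is an illustration, not a transcription of the answer's code.

```julia
using Zygote

# Toy linear Q-function (stand-in for the networks): Q(θ, x) = θ ⋅ x
Q(θ, x) = sum(θ .* x)

# Hypothetical placeholder for the Convex.jl minimisation over `u`;
# it is only ever called outside `gradient`.
target_min(θ_target, x_next) = minimum(θ_target .* x_next)

θ = [0.5, -0.2, 0.1]           # weights trained by gradient descent
θ_target = copy(θ)             # second copy, updated only by the LPF rule
γ, τ, η = 1.0, 0.1, 0.01       # discount, LPF coefficient, step size (assumed)
x, x_next, Δr = [1.0, 2.0, 3.0], [0.5, 0.5, 0.5], 1.0

# 1) Build the target OUTSIDE `gradient`; Zygote never sees this computation.
y = Δr + γ * target_min(θ_target, x_next)

# 2) Differentiate only the squared error of the trained weights.
gθ = gradient(θ -> (Q(θ, x) - y)^2, θ)[1]

# 3) Gradient-descent step on the trained weights.
θ .-= η .* gθ

# 4) LPF-like (soft) update: θ_target ← τ θ + (1 - τ) θ_target
θ_target .= τ .* θ .+ (1 - τ) .* θ_target

println(θ, " ", θ_target)
```
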
Below is an example of Q-learning with two separate Q-networks and a low-pass-filter (LPF)-like update.
Unfortunately, the following code cannot be run by simple copy-and-paste, since it relies on `PartiallyConvexApproximator`; please refer only to its structure. For more details, see the custom training loop example in the Flux.jl documentation.

```julia
using PartiallyConvexApproximator
using Test

using Flux
using Flux.Optimise: update!
using Transducers
using Random

using Convex
using Optim, Mosek, MosekTools

function test(; seed=2021)
    Random.seed!(seed)
    # network
    n = 3
    m = 2
    d = 1000
    i_max = 20
    h_array = [16, 16, 16]
    act = Flux.relu
    Q̂ = pMA(n, m, i_max, h_array; act=act)
    # data
    xs = 1:d |> Map(i -> rand(n)) |> collect
    us = 1:d |> Map(i -> rand(m)) |> collect
    rs = 1:d |> collect
    x_train = hcat(xs[1:end-1]...)
    u_train = hcat(us[1:end-1]...)
    x_next_train = hcat(xs[2:end]...)
    Δr_train = hcat(diff(rs)...)
    data_train = (x_train, u_train, x_next_train, Δr_train)
    ps = params(Q̂)
    _ps = deepcopy(ps)
    opt = ADAM()  # assumed optimiser; not shown in the original post
    epochs = 10
    for epoch in 1:epochs
        # assumed call (not shown in the original); presumes `custom_train!` returns the loss
        training_loss = custom_train!(Q̂, [data_train], opt)
        @show training_loss
    end
    @test _ps != params(Q̂)  # the parameters should have changed
end

function custom_train!(Q̂, data, opt)
    local training_loss
    γ = 1.00  # discount factor
    τ = 0.1   # LPF coefficient
    _Q̂ = deepcopy(Q̂)  # separate copy of the Q-network
    _ps = params(_Q̂)
    for d in data
        x_train, u_train, x_next_train, Δr_train = d
        num_data = size(x_train, 2)
        # Solve the Convex.jl problems BEFORE (outside) `gradient`:
        # Zygote never sees them, so `Q_min_train` is just a constant array.
        Q_min_train = zeros(1, num_data)
        for i in 1:num_data
            u = Convex.Variable(size(u_train, 1))
            problem = Convex.minimize(_Q̂(x_next_train[:, i], u))
            solve!(problem, Mosek.Optimizer(); silent_solver=true)
            Q_min_train[1, i] = problem.optval
        end
        # Only the code inside this block is differentiated
        gs = gradient(_ps) do
            training_loss = Flux.Losses.mse(_Q̂(x_train, u_train), Δr_train + γ*Q_min_train)
            return training_loss
        end
        update!(opt, _ps, gs)
    end