This question is based on the previous question.
I found that it’s hard to incorporate Convex.jl and Flux.jl.
So, I tried to “approximate” the solution by untracking gradients of the solution of Convex.jl from Flux.jl (probably Zygote.jl).
I tried to use deepcopy
but it seems not work.
How to untrack the gradient of Convex.jl solution?
using Convex
using Flux
using Flux.Data: DataLoader
using Mosek, MosekTools
using Zygote
n = 3
m = 2
num = 10
x = rand(n, num)
u = rand(m, num)
data_train = (x, u)
dataloader = DataLoader(data_train..., batchsize=64, shuffle=true)
opt = ADAM(1e-3)
network = Chain(Dense(n+m, 1))
(network::Chain)(x_u::Convex.AbstractExpr) = network.layers[1].W * x_u + (network.layers[1].b .+ network.layers[1].W * zeros(size(x_u)...))
function test()
__network = deepcopy(network)
d = size(u)[2]
problems = [Convex.minimize(__network(vcat(x[:, i], Convex.Variable(size(u[:, i])...)))) for i in 1:d]
for i in 1:d
solve!(problems[i], Mosek.Optimizer())
end
optval = [problems[i].optval for i in 1:d]
sum(optval)
end
function main()
loss = function (x, u)
__network = deepcopy(network)
d = size(u)[2]
problems = [Convex.minimize(__network(vcat(x[:, i], Convex.Variable(size(u[:, i])...)))) for i in 1:d]
for i in 1:d
solve!(problems[i], Mosek.Optimizer())
end
optval = [problems[i].optval for i in 1:d]
sum(optval)
end
sqnorm(x) = sum(abs2, x)
loss_reg(args...) = loss(args...) + 1e-3 * sum(sqnorm, Flux.params(network))
Flux.train!(loss_reg, Flux.params(network), dataloader, opt)
end
julia> main()
ERROR: MethodError: no method matching zero(::Convex.AdditionAtom)
Closest candidates are:
zero(::Type{Dates.DateTime}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.5/Dates/src/types.jl:404
zero(::Type{Dates.Date}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.5/Dates/src/types.jl:405
zero(::Type{ModelingToolkit.TermCombination}) at /Users/jinrae/.julia/packages/ModelingToolkit/hkIWj/src/linearity.jl:67
...
Stacktrace:
[1] iszero(::Convex.AdditionAtom) at ./number.jl:40
[2] rrule(::typeof(sign), ::Convex.AdditionAtom) at /Users/jinrae/.julia/packages/ChainRules/wuTHR/src/rulesets/Base/fastmath_able.jl:197
[3] chain_rrule at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/chainrules.jl:87 [inlined]
[4] macro expansion at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0 [inlined]
[5] _pullback(::Zygote.Context, ::typeof(sign), ::Convex.AdditionAtom) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:12
[6] Problem at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:10 [inlined]
[7] _pullback(::Zygote.Context, ::Type{Problem{Float64}}, ::Symbol, ::Convex.AdditionAtom, ::Array{Constraint,1}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
[8] #minimize#4 at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:74 [inlined]
[9] _pullback(::Zygote.Context, ::Convex.var"##minimize#4", ::Type{Float64}, ::typeof(minimize), ::Convex.AdditionAtom, ::Array{Constraint,1}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
[10] minimize at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:74 [inlined]
[11] _pullback(::Zygote.Context, ::typeof(minimize), ::Convex.AdditionAtom, ::Array{Constraint,1}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
[12] minimize at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:74 [inlined]
[13] _pullback(::Zygote.Context, ::typeof(minimize), ::Convex.AdditionAtom) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
[14] #76 at ./none:0 [inlined]
[15] _pullback(::Zygote.Context, ::var"#76#79"{Array{Float64,2},Array{Float64,2},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}}}, ::Int64) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
[16] (::Zygote.var"#586#590"{Zygote.Context,var"#76#79"{Array{Float64,2},Array{Float64,2},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}}}})(::Int64) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/lib/array.jl:174
[17] iterate at ./generator.jl:47 [inlined]
[18] _collect at ./array.jl:699 [inlined]
[19] collect_similar at ./array.jl:628 [inlined]
[20] map at ./abstractarray.jl:2162 [inlined]
[21] ∇map at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/lib/array.jl:174 [inlined]
[22] _pullback(::Zygote.Context, ::typeof(collect), ::Base.Generator{UnitRange{Int64},var"#76#79"{Array{Float64,2},Array{Float64,2},Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}}}}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/lib/array.jl:193
[23] #75 at /Users/jinrae/.julia/dev/PartiallyConvexApproximator/test/cvx_flux.jl:35 [inlined]
[24] _pullback(::Zygote.Context, ::var"#75#78", ::Array{Float64,2}, ::Array{Float64,2}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
[25] adjoint at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/lib/lib.jl:175 [inlined]
[26] _pullback at /Users/jinrae/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
[27] loss_reg at /Users/jinrae/.julia/dev/PartiallyConvexApproximator/test/cvx_flux.jl:43 [inlined]
[28] _pullback(::Zygote.Context, ::var"#loss_reg#82"{var"#75#78",var"#sqnorm#81"}, ::Array{Float64,2}, ::Array{Float64,2}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
[29] adjoint at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/lib/lib.jl:175 [inlined]
[30] _pullback at /Users/jinrae/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
[31] #15 at /Users/jinrae/.julia/packages/Flux/sY3yx/src/optimise/train.jl:103 [inlined]
[32] _pullback(::Zygote.Context, ::Flux.Optimise.var"#15#21"{var"#loss_reg#82"{var"#75#78",var"#sqnorm#81"},Tuple{Array{Float64,2},Array{Float64,2}}}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
[33] pullback(::Function, ::Params) at /Applications/Julia-1.5.app/Contents/Resources/julia/lib/julia/sys.dylib:?
[34] gradient(::Function, ::Params) at /Applications/Julia-1.5.app/Contents/Resources/julia/lib/julia/sys.dylib:?
[35] macro expansion at /Users/jinrae/.julia/packages/Flux/sY3yx/src/optimise/train.jl:102 [inlined]
[36] macro expansion at /Users/jinrae/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
[37] train!(::Function, ::Params, ::DataLoader{Tuple{Array{Float64,2},Array{Float64,2}}}, ::ADAM; cb::Flux.Optimise.var"#16#22") at /Users/jinrae/.julia/packages/Flux/sY3yx/src/optimise/train.jl:100
[38] train!(::Function, ::Params, ::DataLoader{Tuple{Array{Float64,2},Array{Float64,2}}}, ::ADAM) at /Users/jinrae/.julia/packages/Flux/sY3yx/src/optimise/train.jl:98
[39] main() at /Users/jinrae/.julia/dev/PartiallyConvexApproximator/test/cvx_flux.jl:44
[40] top-level scope at REPL[17]:1
julia> test()
Problem
Name :
Objective sense : min
Type : LO (linear optimization problem)
Constraints : 1
Cones : 0
Scalar variables : 3
Matrix variables : 0
Integer variables : 0
Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries : 2 time : 0.00
Lin. dep. - tries : 0 time : 0.00
Lin. dep. - number : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.00
┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253
Problem
Name :
Objective sense : min
Type : LO (linear optimization problem)
Constraints : 1
Cones : 0
Scalar variables : 3
Matrix variables : 0
Integer variables : 0
Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries : 2 time : 0.00
Lin. dep. - tries : 0 time : 0.00
Lin. dep. - number : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.00
┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253
Problem
Name :
Objective sense : min
Type : LO (linear optimization problem)
Constraints : 1
Cones : 0
Scalar variables : 3
Matrix variables : 0
Integer variables : 0
Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries : 2 time : 0.00
Lin. dep. - tries : 0 time : 0.00
Lin. dep. - number : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.00
┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253
Problem
Name :
Objective sense : min
Type : LO (linear optimization problem)
Constraints : 1
Cones : 0
Scalar variables : 3
Matrix variables : 0
Integer variables : 0
Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries : 2 time : 0.00
Lin. dep. - tries : 0 time : 0.00
Lin. dep. - number : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.00
┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253
Problem
Name :
Objective sense : min
Type : LO (linear optimization problem)
Constraints : 1
Cones : 0
Scalar variables : 3
Matrix variables : 0
Integer variables : 0
Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries : 2 time : 0.00
Lin. dep. - tries : 0 time : 0.00
Lin. dep. - number : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.00
┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253
Problem
Name :
Objective sense : min
Type : LO (linear optimization problem)
Constraints : 1
Cones : 0
Scalar variables : 3
Matrix variables : 0
Integer variables : 0
Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries : 2 time : 0.00
Lin. dep. - tries : 0 time : 0.00
Lin. dep. - number : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.00
┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253
Problem
Name :
Objective sense : min
Type : LO (linear optimization problem)
Constraints : 1
Cones : 0
Scalar variables : 3
Matrix variables : 0
Integer variables : 0
Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries : 2 time : 0.00
Lin. dep. - tries : 0 time : 0.00
Lin. dep. - number : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.00
┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253
Problem
Name :
Objective sense : min
Type : LO (linear optimization problem)
Constraints : 1
Cones : 0
Scalar variables : 3
Matrix variables : 0
Integer variables : 0
Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries : 2 time : 0.00
Lin. dep. - tries : 0 time : 0.00
Lin. dep. - number : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.00
┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253
Problem
Name :
Objective sense : min
Type : LO (linear optimization problem)
Constraints : 1
Cones : 0
Scalar variables : 3
Matrix variables : 0
Integer variables : 0
Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries : 2 time : 0.00
Lin. dep. - tries : 0 time : 0.00
Lin. dep. - number : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.00
┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253
Problem
Name :
Objective sense : min
Type : LO (linear optimization problem)
Constraints : 1
Cones : 0
Scalar variables : 3
Matrix variables : 0
Integer variables : 0
Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries : 2 time : 0.00
Lin. dep. - tries : 0 time : 0.00
Lin. dep. - number : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.00
┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253
-10.0
Note: In the above code, the desired situation is that no gradient update is performed.