I often use Convex.jl and Flux.jl.

Combining them, however, is really tough. In particular, I want to train a network with Flux.jl whose loss function includes optimal values obtained by solving optimisation problems with Convex.jl.

For example, with d data points, a Flux.jl network usually receives n x d data and returns m x d data, and the loss function correspondingly returns 1 x d data.
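To make this shape convention concrete, here is a small sketch that uses only plain Flux.jl (no Convex.jl). The sizes are illustrative, and the per-column squared norm is just an example of a loss that yields one value per data point:

```julia
using Flux

n, m, d = 3, 2, 10                    # input size, output size, number of data points
model = Dense(n, m)                   # maps each n-vector to an m-vector
X = rand(Float32, n, d)               # each column is one data point
Y = model(X)                          # m × d: the layer is applied to every column at once
per_sample = sum(abs2, Y; dims = 1)   # 1 × d: one loss value per data point
```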

In my experience, however, it is hard to receive n x d data, solve an optimisation problem for each data point, and return a 1 x d loss in a way that is compatible with Flux.jl.

+) By extending existing methods, I can combine the two packages, but only for 1-dimensional array inputs; this does not carry over to 2-dimensional array inputs, e.g., n x d data.
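A minimal sketch of the forward pass for 2-dimensional inputs is to solve one small problem per column and collect the optimal values into a 1 x d row. It assumes a single affine layer passed as plain arrays `W` (1 × (n+m)) and `b` (length 1), adds a box constraint `-1 <= u <= 1` (without some constraint the linear objective is unbounded below), uses the open-source SCS solver instead of Mosek, and assumes a recent Convex.jl whose `solve!` accepts an optimizer type. `optimal_value` and `optimal_values` are hypothetical helper names. This only covers the forward computation; Zygote still cannot differentiate through `solve!`.

```julia
using Convex, SCS

# Hypothetical helper: for one data point x (length n), minimise W * [x; u] + b
# over u in the (assumed) box -1 <= u <= 1 and return the optimal value.
function optimal_value(W::AbstractMatrix, b::AbstractVector, x::AbstractVector, m::Integer)
    n = length(x)
    u = Convex.Variable(m)
    c = (W[:, 1:n] * x)[1] + b[1]                    # part of the objective independent of u
    problem = Convex.minimize(W[:, n+1:n+m] * u + c, [u >= -1, u <= 1])
    Convex.solve!(problem, SCS.Optimizer)
    return problem.optval
end

# Hypothetical helper: apply `optimal_value` to every column of X (n × d)
# and return the results as a 1 × d row, matching the shape Flux expects.
function optimal_values(W, b, X::AbstractMatrix, m::Integer)
    d = size(X, 2)
    return reshape([optimal_value(W, b, X[:, j], m) for j in 1:d], 1, d)
end
```

Because the objective is linear in `u` and the box is symmetric, each optimal value should equal `W[:, 1:n] * x + b - sum(abs, W[1, n+1:end])`, which gives a simple sanity check.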

Here is an example combining Convex.jl and Flux.jl.

- Example

```julia
using Convex
using Flux
using Flux.Data: DataLoader
using Mosek, MosekTools

n = 3     # dimension of x
m = 2     # dimension of u
num = 10  # number of data points (d)
x = rand(n, num)
u = rand(m, num)
data_train = (x, u)
dataloader = DataLoader(data_train..., batchsize=64, shuffle=true)
opt = ADAM(1e-3)
network = Chain(Dense(n+m, 1))

# Let the network accept a Convex.jl expression: W * x_u + b, where the bias is
# broadcast to the output shape via `W * zeros(size(x_u)...)`.
(network::Chain)(x_u::Convex.AbstractExpr) = network.layers[1].W * x_u + (network.layers[1].b .+ network.layers[1].W * zeros(size(x_u)...))

# Single data point (1-dimensional input)
function test()
    _u = Convex.Variable(m)
    problem = Convex.minimize(network(vcat(rand(n), _u)))
    solve!(problem, Mosek.Optimizer())
    optval = deepcopy(problem.optval)
end

# Batch of data points (n x d input) inside Flux.train!
function main()
    loss = function (x, u)
        _u = Convex.Variable(size(u)...)
        problem = Convex.minimize(network(vcat(x, _u)))
        solve!(problem, Mosek.Optimizer())
        optval = deepcopy(problem.optval)
        # f = network(vcat(x, u))  # plain Flux loss, for comparison
        sum(optval)
    end
    sqnorm(x) = sum(abs2, x)
    loss_reg(args...) = loss(args...) + 1e-3 * sum(sqnorm, Flux.params(network))
    Flux.train!(loss_reg, Flux.params(network), dataloader, opt)
end
```

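- When running `test`: it runs without error (the solver warns DUAL_INFEASIBLE because `_u` is unconstrained, so the linear objective is unbounded below, but no exception is thrown)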

```
julia> test()
Problem
Name :
Objective sense : min
Type : LO (linear optimization problem)
Constraints : 1
Cones : 0
Scalar variables : 3
Matrix variables : 0
Integer variables : 0
Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries : 2 time : 0.00
Lin. dep. - tries : 0 time : 0.00
Lin. dep. - number : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.01
┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253
-1.0
```

- When running `main`: it does not work

```
julia> main()
ERROR: MethodError: no method matching zero(::Convex.AdditionAtom)
Closest candidates are:
zero(::Type{Dates.DateTime}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.5/Dates/src/types.jl:404
zero(::Type{Dates.Date}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.5/Dates/src/types.jl:405
zero(::Type{ModelingToolkit.TermCombination}) at /Users/jinrae/.julia/packages/ModelingToolkit/hkIWj/src/linearity.jl:67
...
Stacktrace:
[1] iszero(::Convex.AdditionAtom) at ./number.jl:40
[2] rrule(::typeof(sign), ::Convex.AdditionAtom) at /Users/jinrae/.julia/packages/ChainRules/wuTHR/src/rulesets/Base/fastmath_able.jl:197
[3] chain_rrule at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/chainrules.jl:87 [inlined]
[4] macro expansion at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0 [inlined]
[5] _pullback(::Zygote.Context, ::typeof(sign), ::Convex.AdditionAtom) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:12
[6] Problem at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:10 [inlined]
[7] _pullback(::Zygote.Context, ::Type{Problem{Float64}}, ::Symbol, ::Convex.AdditionAtom, ::Array{Constraint,1}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
[8] #minimize#4 at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:74 [inlined]
[9] _pullback(::Zygote.Context, ::Convex.var"##minimize#4", ::Type{Float64}, ::typeof(minimize), ::Convex.AdditionAtom, ::Array{Constraint,1}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
[10] minimize at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:74 [inlined]
[11] _pullback(::Zygote.Context, ::typeof(minimize), ::Convex.AdditionAtom, ::Array{Constraint,1}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
[12] minimize at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:74 [inlined]
[13] _pullback(::Zygote.Context, ::typeof(minimize), ::Convex.AdditionAtom) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
[14] #149 at /Users/jinrae/.julia/dev/PartiallyConvexApproximator/test/cvx_flux.jl:28 [inlined]
[15] _pullback(::Zygote.Context, ::var"#149#150", ::Array{Float64,2}, ::Array{Float64,2}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
[16] adjoint at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/lib/lib.jl:175 [inlined]
[17] _pullback at /Users/jinrae/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
[18] loss_reg at /Users/jinrae/.julia/dev/PartiallyConvexApproximator/test/cvx_flux.jl:35 [inlined]
[19] _pullback(::Zygote.Context, ::var"#loss_reg#152"{var"#149#150",var"#sqnorm#151"}, ::Array{Float64,2}, ::Array{Float64,2}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
[20] adjoint at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/lib/lib.jl:175 [inlined]
[21] _pullback at /Users/jinrae/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
[22] #15 at /Users/jinrae/.julia/packages/Flux/sY3yx/src/optimise/train.jl:103 [inlined]
[23] _pullback(::Zygote.Context, ::Flux.Optimise.var"#15#21"{var"#loss_reg#152"{var"#149#150",var"#sqnorm#151"},Tuple{Array{Float64,2},Array{Float64,2}}}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
[24] pullback(::Function, ::Zygote.Params) at /Applications/Julia-1.5.app/Contents/Resources/julia/lib/julia/sys.dylib:?
[25] gradient(::Function, ::Zygote.Params) at /Applications/Julia-1.5.app/Contents/Resources/julia/lib/julia/sys.dylib:?
[26] macro expansion at /Users/jinrae/.julia/packages/Flux/sY3yx/src/optimise/train.jl:102 [inlined]
[27] macro expansion at /Users/jinrae/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
[28] train!(::Function, ::Zygote.Params, ::DataLoader{Tuple{Array{Float64,2},Array{Float64,2}}}, ::ADAM; cb::Flux.Optimise.var"#16#22") at /Users/jinrae/.julia/packages/Flux/sY3yx/src/optimise/train.jl:100
[29] train!(::Function, ::Zygote.Params, ::DataLoader{Tuple{Array{Float64,2},Array{Float64,2}}}, ::ADAM) at /Users/jinrae/.julia/packages/Flux/sY3yx/src/optimise/train.jl:98
[30] main() at /Users/jinrae/.julia/dev/PartiallyConvexApproximator/test/cvx_flux.jl:36
[31] top-level scope at REPL[103]:1
```

This is not a problem specific to `Dense`; with custom networks it is even more difficult to deal with.

Is there a good way to combine Convex.jl and Flux.jl?
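The stack trace suggests that the failure occurs because Zygote tries to differentiate into `Convex.minimize` and the `Problem` constructor themselves (frames [6] to [13]), where `sign` ends up being called on a Convex expression. One possible workaround, sketched here under stated assumptions, is to compute the minimiser `u*` outside automatic differentiation and let Zygote see only the network evaluated at the fixed `u*`. This relies on Danskin's (envelope) theorem: if the feasible set of `u` does not depend on the network parameters and `u*` is unique, the gradient of the optimal value with respect to the parameters equals the gradient of the objective at `u*`. The sketch assumes a single affine layer passed as plain arrays `W` and `b`, a box constraint `-1 <= u <= 1`, the SCS solver, a recent Convex.jl whose `solve!` accepts an optimizer type, and `ChainRulesCore.ignore_derivatives`; `minimiser` and `optval_loss` are hypothetical names.

```julia
using Convex, SCS, Zygote
using ChainRulesCore: ignore_derivatives

# Hypothetical helper: minimiser u* of W * [x; u] + b over the (assumed) box -1 <= u <= 1.
function minimiser(W::AbstractMatrix, b::AbstractVector, x::AbstractVector, m::Integer)
    n = length(x)
    u = Convex.Variable(m)
    c = (W[:, 1:n] * x)[1] + b[1]
    problem = Convex.minimize(W[:, n+1:n+m] * u + c, [u >= -1, u <= 1])
    Convex.solve!(problem, SCS.Optimizer)
    return vec(Convex.evaluate(u))   # assumes m >= 2, so `evaluate` returns an array
end

# Sum of the d optimal values for the columns of X (n × d). The minimisers are
# computed inside `ignore_derivatives`, so Zygote treats them as constants and
# only differentiates the affine evaluation W * [X; U*] .+ b.
function optval_loss(W, b, X, m)
    d = size(X, 2)
    Ustar = ignore_derivatives() do
        reduce(hcat, [minimiser(W, b, X[:, j], m) for j in 1:d])   # m × d
    end
    return sum(W * vcat(X, Ustar) .+ b)
end

W = [0.5 -0.2 0.1 0.3 -0.4]   # 1 × (n + m) with n = 3, m = 2
b = [0.1]
X = rand(3, 4)
gW, gb = Zygote.gradient((W, b) -> optval_loss(W, b, X, 2), W, b)
```

For this linear case the box minimiser is `u* = -sign.(W_u)`, so `gW` should consist of the row sums of `X` for the `x`-part of `W` and `d .* (-sign.(W_u))` for the `u`-part, and `gb` should sum to `d`. For a network that is convex in `u` but nonlinear in its parameters, the same pattern should carry over as long as the constraints on `u` do not depend on the parameters.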