Best practice to combine Flux.jl and Convex.jl

I often use Convex.jl and Flux.jl together, and combining them is really tough. In particular, I want to train a network with Flux.jl when the loss function includes optimal values obtained by solving optimisation problems with Convex.jl.
For example, for d data points, Flux.jl usually provides a network that receives n x d data and produces m x d data, and the loss function likewise produces 1 x d data.
But in my experience, it is hard to receive n x d data, optimise some function for each data point, and produce a 1 x d loss that is compatible with Flux.jl.

+) I can combine the two packages for the 1-dimensional array input case by extending some existing methods. However, this does not carry over to 2-dimensional array inputs, e.g., n x d data.
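To make the shape convention concrete, here is a pure-array sketch of the batching described above. A Flux `Dense(n, m)` layer stores a weight matrix `W` and bias `b` and computes `W*X .+ b` column-wise, so no Flux is actually needed to illustrate the shapes:

```julia
n, m, d = 3, 2, 10
W, b = rand(m, n), rand(m)
X = rand(n, d)              # d data points, one per column
Y = W * X .+ b              # batched forward pass, like Dense(n, m): m x d
L = sum(abs2, Y; dims=1)    # a per-sample loss: 1 x d
```

The difficulty in the post is producing that 1 x d row when each entry requires its own Convex.jl solve.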

Here is an example illustrating Convex.jl + Flux.jl.

  • Example
using Convex
using Flux
using Flux.Data: DataLoader
using Mosek, MosekTools

n = 3
m = 2
num = 10
x = rand(n, num)
u = rand(m, num)
data_train = (x, u)
dataloader = DataLoader(data_train..., batchsize=64, shuffle=true)
opt = ADAM(1e-3)
network = Chain(Dense(n+m, 1))

# evaluate the Dense layer on a Convex.jl expression (affine in x_u)
(network::Chain)(x_u::Convex.AbstractExpr) = network.layers[1].W * x_u + (network.layers[1].b .+ network.layers[1].W * zeros(size(x_u)...))

function test()
    _u = Convex.Variable(m)
    problem = Convex.minimize(network(vcat(rand(n), _u)))
    solve!(problem, Mosek.Optimizer())
    optval = deepcopy(problem.optval)
    return optval
end
function main()
    loss = function (x, u)
        _u = Convex.Variable(size(u)...)
        problem = Convex.minimize(network(vcat(x, _u)))
        solve!(problem, Mosek.Optimizer())
        optval = deepcopy(problem.optval)
        # f = network(vcat(x, u))
        return optval
    end
    sqnorm(x) = sum(abs2, x)
    loss_reg(args...) = loss(args...) + 1e-3 * sum(sqnorm, Flux.params(network))
    Flux.train!(loss_reg, Flux.params(network), dataloader, opt)
end
  • when running test: it works
julia> test()
  Name                   :
  Objective sense        : min
  Type                   : LO (linear optimization problem)
  Constraints            : 1
  Cones                  : 0
  Scalar variables       : 3
  Matrix variables       : 0
  Integer variables      : 0

Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries                  : 2                 time                   : 0.00
Lin. dep.  - tries                  : 0                 time                   : 0.00
Lin. dep.  - number                 : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.01

┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253

  • when running main: it does not work
julia> main()
ERROR: MethodError: no method matching zero(::Convex.AdditionAtom)
Closest candidates are:
  zero(::Type{Dates.DateTime}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.5/Dates/src/types.jl:404
  zero(::Type{Dates.Date}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.5/Dates/src/types.jl:405
  zero(::Type{ModelingToolkit.TermCombination}) at /Users/jinrae/.julia/packages/ModelingToolkit/hkIWj/src/linearity.jl:67
 [1] iszero(::Convex.AdditionAtom) at ./number.jl:40
 [2] rrule(::typeof(sign), ::Convex.AdditionAtom) at /Users/jinrae/.julia/packages/ChainRules/wuTHR/src/rulesets/Base/fastmath_able.jl:197
 [3] chain_rrule at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/chainrules.jl:87 [inlined]
 [4] macro expansion at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0 [inlined]
 [5] _pullback(::Zygote.Context, ::typeof(sign), ::Convex.AdditionAtom) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:12
 [6] Problem at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:10 [inlined]
 [7] _pullback(::Zygote.Context, ::Type{Problem{Float64}}, ::Symbol, ::Convex.AdditionAtom, ::Array{Constraint,1}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
 [8] #minimize#4 at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:74 [inlined]
 [9] _pullback(::Zygote.Context, ::Convex.var"##minimize#4", ::Type{Float64}, ::typeof(minimize), ::Convex.AdditionAtom, ::Array{Constraint,1}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
 [10] minimize at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:74 [inlined]
 [11] _pullback(::Zygote.Context, ::typeof(minimize), ::Convex.AdditionAtom, ::Array{Constraint,1}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
 [12] minimize at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:74 [inlined]
 [13] _pullback(::Zygote.Context, ::typeof(minimize), ::Convex.AdditionAtom) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
 [14] #149 at /Users/jinrae/.julia/dev/PartiallyConvexApproximator/test/cvx_flux.jl:28 [inlined]
 [15] _pullback(::Zygote.Context, ::var"#149#150", ::Array{Float64,2}, ::Array{Float64,2}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
 [16] adjoint at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/lib/lib.jl:175 [inlined]
 [17] _pullback at /Users/jinrae/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [18] loss_reg at /Users/jinrae/.julia/dev/PartiallyConvexApproximator/test/cvx_flux.jl:35 [inlined]
 [19] _pullback(::Zygote.Context, ::var"#loss_reg#152"{var"#149#150",var"#sqnorm#151"}, ::Array{Float64,2}, ::Array{Float64,2}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
 [20] adjoint at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/lib/lib.jl:175 [inlined]
 [21] _pullback at /Users/jinrae/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [22] #15 at /Users/jinrae/.julia/packages/Flux/sY3yx/src/optimise/train.jl:103 [inlined]
 [23] _pullback(::Zygote.Context, ::Flux.Optimise.var"#15#21"{var"#loss_reg#152"{var"#149#150",var"#sqnorm#151"},Tuple{Array{Float64,2},Array{Float64,2}}}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
 [24] pullback(::Function, ::Zygote.Params) at /Applications/
 [25] gradient(::Function, ::Zygote.Params) at /Applications/
 [26] macro expansion at /Users/jinrae/.julia/packages/Flux/sY3yx/src/optimise/train.jl:102 [inlined]
 [27] macro expansion at /Users/jinrae/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
 [28] train!(::Function, ::Zygote.Params, ::DataLoader{Tuple{Array{Float64,2},Array{Float64,2}}}, ::ADAM; cb::Flux.Optimise.var"#16#22") at /Users/jinrae/.julia/packages/Flux/sY3yx/src/optimise/train.jl:100
 [29] train!(::Function, ::Zygote.Params, ::DataLoader{Tuple{Array{Float64,2},Array{Float64,2}}}, ::ADAM) at /Users/jinrae/.julia/packages/Flux/sY3yx/src/optimise/train.jl:98
 [30] main() at /Users/jinrae/.julia/dev/PartiallyConvexApproximator/test/cvx_flux.jl:36
 [31] top-level scope at REPL[103]:1

This problem is not limited to Dense; with custom networks it is even harder to deal with.

Does anyone have a good idea for combining Convex.jl and Flux.jl?

It looks like the issue here is that Flux is differentiating through the Convex.jl solve to do backpropagation. Convex.jl’s problems are not differentiable yet. There is work in the Julia ecosystem to differentiate through optimization problems (DiffOpt.jl) but I don’t think it’s quite ready yet and isn’t hooked up to Convex.jl yet either.

Python’s cvxpylayers is more mature from what I’ve heard, so you might be able to PyCall out to use cvxpy in your loss function, but it will probably require a custom derivative rule to connect it to Julia’s autodiff ecosystem.
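The "custom derivative rule" mentioned above can be sketched with ChainRulesCore. In this sketch, `py_forward` and `py_backward` are hypothetical stand-ins for a cvxpylayers forward solve and its vector-Jacobian product (e.g. reached via PyCall); here they are faked with a closed-form toy problem so the sketch runs on its own:

```julia
using ChainRulesCore  # Zygote picks up rules defined through this package

# Hypothetical stand-ins for the Python-side solve and its VJP.
py_forward(θ) = θ .^ 2            # pretend: optimal value as a function of parameters
py_backward(θ, ȳ) = 2 .* θ .* ȳ   # pretend: vector-Jacobian product from the Python side

solve_layer(θ) = py_forward(θ)

# Custom reverse rule: the backward pass uses the supplied VJP instead of
# letting the autodiff system trace into the solver.
function ChainRulesCore.rrule(::typeof(solve_layer), θ)
    y = py_forward(θ)
    solve_layer_pullback(ȳ) = (NoTangent(), py_backward(θ, ȳ))
    return y, solve_layer_pullback
end
```

With such a rule in place, `Zygote.gradient` over a loss containing `solve_layer` would call `py_backward` rather than attempting to differentiate through the solve, which is exactly what fails in the stack trace above.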


OMG, it sounds like it would be impossible for me…
Bad news, but thanks.

Do you think it’s possible to differentiate Convex.jl problems by expanding some methods?

I don’t think so, or at least, I don’t know how.

You could maybe try @tjdiamandis’s ConeProgramDiff.jl, but it’s not hooked up to Convex.jl. The optimization problem you have there looks fairly simple so there might be other approaches too, outside the formulation via Convex. I don’t know much about this area though, sorry.


Unfortunately there isn’t great support for differentiation through nested optimization problems in the Julia ecosystem right now as far as I know (I think this will change soon!). The ConeProgramDiff.jl package will give you a function that computes gradients wrt the parameters of a solved conic form problem.

If you want to include a linear program or quadratic program as a layer of your neural network, it may be easier to differentiate directly through the KKT conditions (see OptNet paper Section 3). I haven’t used Flux.jl, so I’m not sure what is involved in defining a new layer. But the forward pass would be a solve of your optimization problem by any solver that provides the dual variables, then the backward pass is given by eq (7) and eq (8) of the OptNet paper.
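For the equality-constrained case (no inequality constraints), the KKT approach above can be sketched in a few lines. This is a minimal illustration, not the full OptNet layer: the paper’s eq (7) and eq (8) additionally handle inequality constraints, and only the gradients with respect to `q` and `b` are returned here:

```julia
using LinearAlgebra

# Equality-constrained QP layer:  min_u  (1/2)u'Qu + q'u   s.t.  Au = b.
# Forward pass: solve the symmetric KKT system
#   [Q A'; A 0] [u; λ] = [-q; b].
function qp_forward(Q, q, A, b)
    nu = length(q)
    K = [Q A'; A zeros(size(A, 1), size(A, 1))]
    z = K \ [-q; b]
    return z[1:nu], z[nu+1:end]      # primal u*, dual λ*
end

# Backward pass: given g = ∂ℓ/∂u*, solve the same KKT system once more
# (implicit differentiation of the KKT conditions, as in OptNet Section 3).
function qp_backward(Q, q, A, b, g)
    nu = length(q)
    K = [Q A'; A zeros(size(A, 1), size(A, 1))]
    d = -(K \ [g; zeros(size(A, 1))])
    return d[1:nu], -d[nu+1:end]     # ∇_q ℓ, ∇_b ℓ
end
```

For example, with `Q = 2I`, `q = [1, 1]`, and the constraint `u₁ + u₂ = 1`, the forward pass gives `u* = [0.5, 0.5]`, and for `ℓ(u) = u₁ + u₂` the backward pass reports `∇_b ℓ = 1`, matching the closed-form solution. Wrapping this pair in a ChainRulesCore `rrule` would make it usable as a layer under Zygote/Flux.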