Best practice to combine Flux.jl and Convex.jl

I often use Convex.jl and Flux.jl, but combining them is really tough. In particular, I want to train a network with Flux.jl when the loss function includes optimal values obtained by solving optimisation problems with Convex.jl.
For example, with d data points, a Flux.jl network typically receives n x d data and produces m x d data, and the loss function produces 1 x d data (one value per sample).
In my experience, however, it is hard to take n x d data, solve an optimisation problem for each data point, and return a 1 x d loss that is compatible with Flux.jl's training loop.

+) I can combine the two packages for the 1-dimensional array input case by extending some existing methods, but that approach does not carry over to 2-dimensional array inputs, e.g. n x d data.
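To make the shapes concrete, here is a forward-only sketch (no gradients involved) of the per-column evaluation I have in mind: one Convex.jl solve per data column, collected into a 1 x d row. The function name and the unconstrained objective are only illustrative; like the test() example further down, this toy problem is unbounded, so the solver warns about DUAL_INFEASIBLE.

using Convex
using Flux
using Mosek, MosekTools

n, m, d = 3, 2, 10
network = Chain(Dense(n + m, 1))
x = rand(n, d)

# One scalar optimal value per data column (forward pass only).
function per_column_optvals(network, x)
    W = Float64.(network.layers[1].W)    # 1 x (n+m), converted for the solver
    b = Float64.(network.layers[1].b)    # length 1
    optvals = map(1:size(x, 2)) do i
        u = Convex.Variable(m)
        problem = Convex.minimize(W * vcat(x[:, i], u) + b[1])
        solve!(problem, Mosek.Optimizer())
        problem.optval
    end
    reshape(optvals, 1, :)               # 1 x d, like a per-sample loss
end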

Here is a full example of what I am trying to do with Convex.jl + Flux.jl.

  • Example
using Convex
using Flux
using Flux.Data: DataLoader
using Mosek, MosekTools


n = 3
m = 2
num = 10
x = rand(n, num)
u = rand(m, num)
data_train = (x, u)
dataloader = DataLoader(data_train..., batchsize=64, shuffle=true)
opt = ADAM(1e-3)
network = Chain(Dense(n+m, 1))

# Extend the Chain call so a single Dense layer accepts Convex expressions
# (the W * zeros(...) term is just a hack to give the bias the right shape).
(network::Chain)(x_u::Convex.AbstractExpr) = network.layers[1].W * x_u + (network.layers[1].b .+ network.layers[1].W * zeros(size(x_u)...))

function test()
    _u = Convex.Variable(m)
    problem = Convex.minimize(network(vcat(rand(n), _u)))
    solve!(problem, Mosek.Optimizer())
    optval = deepcopy(problem.optval)
end
function main()
    loss = function (x, u)
        _u = Convex.Variable(size(u)...)
        problem = Convex.minimize(network(vcat(x, _u)))
        solve!(problem, Mosek.Optimizer())
        optval = deepcopy(problem.optval)
        return optval  # use the optimal value of the inner problem as the loss
    end
    sqnorm(x) = sum(abs2, x)
    loss_reg(args...) = loss(args...) + 1e-3 * sum(sqnorm, Flux.params(network))
    Flux.train!(loss_reg, Flux.params(network), dataloader, opt)
end
  • When running test(): it works
julia> test()
Problem
  Name                   :
  Objective sense        : min
  Type                   : LO (linear optimization problem)
  Constraints            : 1
  Cones                  : 0
  Scalar variables       : 3
  Matrix variables       : 0
  Integer variables      : 0

Optimizer started.
Presolve started.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries                  : 2                 time                   : 0.00
Lin. dep.  - tries                  : 0                 time                   : 0.00
Lin. dep.  - number                 : 0
Presolve terminated. Time: 0.00
Optimizer terminated. Time: 0.01

┌ Warning: Problem status DUAL_INFEASIBLE; solution may be inaccurate.
└ @ Convex ~/.julia/packages/Convex/Zv1ch/src/solution.jl:253
-1.0

  • When running main(): it does not work
julia> main()
ERROR: MethodError: no method matching zero(::Convex.AdditionAtom)
Closest candidates are:
  zero(::Type{Dates.DateTime}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.5/Dates/src/types.jl:404
  zero(::Type{Dates.Date}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.5/Dates/src/types.jl:405
  zero(::Type{ModelingToolkit.TermCombination}) at /Users/jinrae/.julia/packages/ModelingToolkit/hkIWj/src/linearity.jl:67
  ...
Stacktrace:
 [1] iszero(::Convex.AdditionAtom) at ./number.jl:40
 [2] rrule(::typeof(sign), ::Convex.AdditionAtom) at /Users/jinrae/.julia/packages/ChainRules/wuTHR/src/rulesets/Base/fastmath_able.jl:197
 [3] chain_rrule at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/chainrules.jl:87 [inlined]
 [4] macro expansion at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0 [inlined]
 [5] _pullback(::Zygote.Context, ::typeof(sign), ::Convex.AdditionAtom) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:12
 [6] Problem at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:10 [inlined]
 [7] _pullback(::Zygote.Context, ::Type{Problem{Float64}}, ::Symbol, ::Convex.AdditionAtom, ::Array{Constraint,1}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
 [8] #minimize#4 at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:74 [inlined]
 [9] _pullback(::Zygote.Context, ::Convex.var"##minimize#4", ::Type{Float64}, ::typeof(minimize), ::Convex.AdditionAtom, ::Array{Constraint,1}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
 [10] minimize at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:74 [inlined]
 [11] _pullback(::Zygote.Context, ::typeof(minimize), ::Convex.AdditionAtom, ::Array{Constraint,1}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
 [12] minimize at /Users/jinrae/.julia/packages/Convex/Zv1ch/src/problems.jl:74 [inlined]
 [13] _pullback(::Zygote.Context, ::typeof(minimize), ::Convex.AdditionAtom) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
 [14] #149 at /Users/jinrae/.julia/dev/PartiallyConvexApproximator/test/cvx_flux.jl:28 [inlined]
 [15] _pullback(::Zygote.Context, ::var"#149#150", ::Array{Float64,2}, ::Array{Float64,2}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
 [16] adjoint at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/lib/lib.jl:175 [inlined]
 [17] _pullback at /Users/jinrae/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [18] loss_reg at /Users/jinrae/.julia/dev/PartiallyConvexApproximator/test/cvx_flux.jl:35 [inlined]
 [19] _pullback(::Zygote.Context, ::var"#loss_reg#152"{var"#149#150",var"#sqnorm#151"}, ::Array{Float64,2}, ::Array{Float64,2}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
 [20] adjoint at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/lib/lib.jl:175 [inlined]
 [21] _pullback at /Users/jinrae/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57 [inlined]
 [22] #15 at /Users/jinrae/.julia/packages/Flux/sY3yx/src/optimise/train.jl:103 [inlined]
 [23] _pullback(::Zygote.Context, ::Flux.Optimise.var"#15#21"{var"#loss_reg#152"{var"#149#150",var"#sqnorm#151"},Tuple{Array{Float64,2},Array{Float64,2}}}) at /Users/jinrae/.julia/packages/Zygote/EjVY4/src/compiler/interface2.jl:0
 [24] pullback(::Function, ::Zygote.Params) at /Applications/Julia-1.5.app/Contents/Resources/julia/lib/julia/sys.dylib:?
 [25] gradient(::Function, ::Zygote.Params) at /Applications/Julia-1.5.app/Contents/Resources/julia/lib/julia/sys.dylib:?
 [26] macro expansion at /Users/jinrae/.julia/packages/Flux/sY3yx/src/optimise/train.jl:102 [inlined]
 [27] macro expansion at /Users/jinrae/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
 [28] train!(::Function, ::Zygote.Params, ::DataLoader{Tuple{Array{Float64,2},Array{Float64,2}}}, ::ADAM; cb::Flux.Optimise.var"#16#22") at /Users/jinrae/.julia/packages/Flux/sY3yx/src/optimise/train.jl:100
 [29] train!(::Function, ::Zygote.Params, ::DataLoader{Tuple{Array{Float64,2},Array{Float64,2}}}, ::ADAM) at /Users/jinrae/.julia/packages/Flux/sY3yx/src/optimise/train.jl:98
 [30] main() at /Users/jinrae/.julia/dev/PartiallyConvexApproximator/test/cvx_flux.jl:36
 [31] top-level scope at REPL[103]:1


This problem is not specific to Dense; with custom networks it becomes even harder to deal with.

Does anyone have a good idea for combining Convex.jl and Flux.jl?

It looks like the issue here is that Flux is differentiating through the Convex.jl solve to do backpropagation, and Convex.jl’s problems are not differentiable yet. There is work in the Julia ecosystem on differentiating through optimization problems (DiffOpt.jl), but I don’t think it’s quite ready yet, and it isn’t hooked up to Convex.jl either.

Python’s cvxpylayers is more mature from what I’ve heard, so you might be able to PyCall out to use cvxpy in your loss function, but it will probably require a custom derivative rule to connect it to Julia’s autodiff ecosystem.
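Whatever ends up providing the inner gradient (cvxpylayers via PyCall, DiffOpt, or a hand-derived formula), the wiring on the Julia side would be roughly the same: wrap the solve in a function with its own ChainRulesCore rrule so Zygote never traces into the solver. A minimal sketch of that wiring is below; solve_layer_with_grad is a made-up name, and the trivial closed-form inner problem is only a stand-in for a real backend that returns both the optimal value and its gradient.

using ChainRulesCore

# Toy stand-in for the real backend: inner problem  min_u 0.5*u'u - θ'u,
# whose optimal value is -0.5*‖θ‖^2 with gradient -θ. In practice this function
# would call cvxpylayers through PyCall (or another differentiable solver) and
# return both the optimal value and its gradient w.r.t. the parameters θ.
solve_layer_with_grad(θ) = (-0.5 * sum(abs2, θ), -θ)

solve_layer(θ) = first(solve_layer_with_grad(θ))

# Custom rule: Zygote (via ChainRules) calls this instead of tracing the solver.
function ChainRulesCore.rrule(::typeof(solve_layer), θ)
    optval, dval_dθ = solve_layer_with_grad(θ)
    solve_layer_pullback(Δ) = (NoTangent(), Δ .* dval_dθ)
    return optval, solve_layer_pullback
end

# With a reasonably recent Zygote/ChainRulesCore this rule is picked up:
#   Zygote.gradient(solve_layer, [1.0, 2.0])   # returns ([-1.0, -2.0],), i.e. -θ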


OMG, it sounds like it would be impossible for me…
Bad news. Thanks.

Do you think it’s possible to differentiate Convex.jl problems by expanding some methods?

I don’t think so, or at least, I don’t know how.

You could maybe try @tjdiamandis’s ConeProgramDiff.jl (github.com/tjdiamandis/ConeProgramDiff.jl), but it’s not hooked up to Convex.jl. The optimization problem you have there looks fairly simple so there might be other approaches too, outside the formulation via Convex. I don’t know much about this area though, sorry.


Unfortunately there isn’t great support for differentiation through nested optimization problems in the Julia ecosystem right now as far as I know (I think this will change soon!). The ConeProgramDiff.jl package will give you a function that computes gradients wrt the parameters of a solved conic form problem.

If you want to include a linear program or quadratic program as a layer of your neural network, it may be easier to differentiate directly through the KKT conditions (see OptNet paper Section 3). I haven’t used Flux.jl, so I’m not sure what is involved in defining a new layer. But the forward pass would be a solve of your optimization problem by any solver that provides the dual variables, then the backward pass is given by eq (7) and eq (8) of the OptNet paper.
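To make the KKT route concrete, here is a minimal, untested sketch for the easiest case: an equality-constrained QP, min_z 0.5 z'Q z + q'z subject to A z = b, whose KKT conditions are the linear system [Q A'; A 0][z; ν] = [-q; b]. The backward pass is just one more solve with the same (symmetric) KKT matrix; inequality constraints need the full eq (7) and (8) machinery from the OptNet paper, and the struct and function names here are made up for illustration.

using LinearAlgebra
using ChainRulesCore

struct EqQP{TQ,TA}      # fixed problem data; only q and b are differentiated here
    Q::TQ
    A::TA
end

kkt_matrix(p::EqQP, meq) = [p.Q p.A'; p.A zeros(meq, meq)]

# Forward pass: solve the KKT system (any QP solver that returns duals would do).
function solve_qp(p::EqQP, q, b)
    sol = kkt_matrix(p, length(b)) \ vcat(-q, b)
    return sol[1:length(q)]              # primal solution z, used as the layer output
end

# Backward pass via implicit differentiation of the KKT conditions.
function ChainRulesCore.rrule(::typeof(solve_qp), p::EqQP, q, b)
    n, meq = length(q), length(b)
    K = kkt_matrix(p, meq)
    z = (K \ vcat(-q, b))[1:n]
    function solve_qp_pullback(z̄)
        d = K \ vcat(z̄, zeros(meq))      # K is symmetric, so no transpose needed
        dz, dν = d[1:n], d[n+1:end]
        # ∂L/∂q = -dz and ∂L/∂b = dν; gradients w.r.t. Q and A would follow the
        # outer-product formulas in OptNet eq (8) and are omitted in this sketch.
        return NoTangent(), NoTangent(), -dz, dν
    end
    return z, solve_qp_pullback
end

The dense backslash solve is only there to keep the sketch short; in practice you would reuse the factorization or the duals that your QP solver already produces.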
