After a few weeks of thinking about this on and off, with a lot of help from the good folks in Slack’s #math-optimization channel, I’ve decided I don’t trust myself to come up with an optimal (no pun intended) solution to this. The problem is pretty key to a library I’m writing, which I’m hoping will beat out some decent R and Stata libraries, so I’m posting about it here.
I’m trying to construct a Synthetic Control estimator following the method laid out in Abadie (2021). The most time-consuming part is a nested optimization problem, as follows:
Find a set of weights $V^*$ given by

$$V^* = \operatorname{arg\,min}_{V} \; (Y_1 - Y_0 W(V))'(Y_1 - Y_0 W(V)) \quad \text{s.t.} \quad 0 \le v_i \le 1, \; \sum_{i=1}^{k} v_i = 1$$

where $Y_1$ is a $T \times 1$ vector, $Y_0$ is a $T \times J$ matrix, and $W(V)$ is a $J \times 1$ vector. $W(V)$ for a given $V$ is determined by an inner optimization problem as follows:

$$W(V) = \operatorname{arg\,min}_{W} \; (X_1 - X_0 W)' \, \mathrm{diag}(V) \, (X_1 - X_0 W) \quad \text{s.t.} \quad 0 \le w_j \le 1, \; \sum_{j=1}^{J} w_j = 1$$

Here $X_1$ is a $k \times 1$ vector, $X_0$ is a $k \times J$ matrix, and $V$ is a $k \times 1$ vector.
The simplest Optim-based implementation I could think of is this:
# inner: W(V) as box-constrained weighted least squares (bounds and starting values from the MWE below)
get_W(v, x₁, x₀) = optimize(w -> (x₁ .- x₀*w)'*Diagonal(v)*(x₁ .- x₀*w), w0s, w1s, equal_w).minimizer
# outer: choose V to minimize the pre-treatment outcome fit
get_V(y₁, y₀) = optimize(v -> sum((y₁ .- y₀*get_W(v, x₁, x₀)).^2), v0s, v1s, equal_v)
Which is great for teaching purposes but… never actually returns; even on a small problem it will run for hours and then fail with some Hager-Zhang linesearch error.
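One variant that should at least sidestep the linesearch machinery is to force a derivative-free search on both levels. A minimal sketch, assuming Optim’s Fminbox(NelderMead()) and the bound/start variables from the MWE below (I make no claims about whether it converges to anything sensible):

# Sketch: same nested structure, but with Fminbox(NelderMead()) so neither
# level relies on gradients of the non-smooth v ↦ W(v) mapping
get_W_nm(v, x₁, x₀) = optimize(w -> (x₁ .- x₀*w)'*Diagonal(v)*(x₁ .- x₀*w),
                               w0s, w1s, equal_w, Fminbox(NelderMead())).minimizer
get_V_nm(y₁, y₀) = optimize(v -> sum((y₁ .- y₀*get_W_nm(v, x₁, x₀)).^2),
                            v0s, v1s, equal_v, Fminbox(NelderMead()))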
With the help of the Slack community I cobbled together the following JuMP/Optim hybrid solution (I’ll put a bit of setup code below the fold to make this an MWE):
# Inner problem as a QP in JuMP; the model is built once and re-solved with updated objectives
m = Model(HiGHS.Optimizer); set_silent(m)
@variable(m, 0 <= w[1:J] <= 1)
@constraint(m, sum(w[i] for i ∈ 1:J) == 1)
@objective(m, Min, (x₁ .- x₀*w)'*Diagonal(v)*(x₁ .- x₀*w))
# Function to update v vector and re-solve model
function get_w_given_v(v_new; m = m)
    @objective(m, Min, (x₁ .- x₀*w)'*Diagonal(v_new)*(x₁ .- x₀*w))
    optimize!(m)
    return value.(w)
end
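As a sanity check on the inner solver in isolation (a hypothetical call, just using equal_v from the MWE):

# Solve the inner QP once at equal predictor weights; the model's
# constraints should guarantee J donor weights in [0, 1] summing to 1
w_check = get_w_given_v(equal_v)
@assert isapprox(sum(w_check), 1.0; atol = 1e-6)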
# Outer optimization with Optim
function outer_opt(v; m = m, y₁ = y₁, y₀ = y₀)
    w_star = get_w_given_v(v, m = m)
    # pre-treatment fit plus a quadratic penalty pushing v onto the unit simplex
    return sum((y₁ .- y₀*w_star).^2) +
           10e6*((sum(v) - 1.0)^2 + sum(max.(v .- 1, 0).^2) + sum(max.(-v, 0).^2))
end
Which seems to work but
- seems slower than competitor packages in R/Stata (although those end up calling C++ routines, so it’s a bit hard to say whether I’m comparing apples to apples); and
- is also sensitive to the choice of the scaling parameter for the penalty function that restricts the elements of $V$ to lie between 0 and 1 and sum to 1 (see the sketch after this list for a penalty-free alternative).
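On that second point, one penalty-free alternative would be to reparameterize onto the simplex. This softmax trick is my own assumption, not something from the Slack discussion, and note it can’t produce exact zeros in $V$:

# Sketch: optimize an unconstrained z ∈ ℝᵏ and map it onto the simplex
# via softmax, so no penalty scale parameter is needed
softmax(z) = (e = exp.(z .- maximum(z)); e ./ sum(e))
fit(v) = sum((y₁ .- y₀*get_w_given_v(v)).^2)  # outer objective, penalty dropped
res = optimize(z -> fit(softmax(z)), zeros(k), NelderMead())
v_star = softmax(res.minimizer)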
So I’m interested in people’s thoughts on:
- The general approach to solving this problem: is there a better way? (People on Slack recommended fancy things like Symbolics, Complementarity.jl or ImplicitDifferentiation, which are above my pay grade.)
- Is there a way to nest JuMP problems, i.e. have the outer problem in JuMP as well to directly specify the constraints?
- What’s the simplest “close to the math” way of solving this, close to my initial two one-liners above, that runs in reasonable time? (I’m trying to write some Pluto docs for the package, in which a simple one-liner solution would come in handy.)
Thanks for any pointers!
Additional code for MWE:
import Pkg; Pkg.activate(@__DIR__)
using JuMP, Optim, HiGHS, LinearAlgebra
## Setup code
x₁ = [7351.0, 46.0, 5.0, 44.0, 50.0]
x₀ = [7026.0 6747.0 6685.0 7231.0 6625.0 4936.0 6226.0 5829.0 7197.0 6446.0 5984.0 3475.0 4792.0 9673.0 5930.0 8330.0;
31.0 64.0 109.0 62.0 38.0 33.0 42.0 24.0 93.0 55.0 77.0 53.0 31.0 63.0 53.0 16.0;
10.0 6.0 7.0 10.0 10.0 15.0 14.0 9.0 7.0 13.0 8.0 18.0 15.0 5.0 14.0 8.0;
36.0 40.0 40.0 29.0 37.0 34.0 41.0 42.0 36.0 33.0 35.0 35.0 41.0 35.0 42.0 34.0;
47.0 54.0 31.0 44.0 31.0 16.0 24.0 39.0 44.0 52.0 37.0 10.0 11.0 45.0 32.0 53.0]
# Data for outer problem
y₁ = [12115.0, 12761.0, 13519.0, 14481.0, 15291.0, 15998.0, 16679.0, 17786.0, 18994.0, 20465.0]
y₀ = [11513.0 11242.0 11079.0 11106.0 10929.0 7870.0 10593.0 9986.0 11304.0 9736.0 10548.0 5812.0 7447.0 15338.0 9161.0 13533.0;
11537.0 12148.0 11827.0 12115.0 11869.0 8204.0 11275.0 10813.0 11784.0 10687.0 11205.0 6263.0 7957.0 15963.0 9917.0 13940.0;
12300.0 13048.0 12334.0 12822.0 12518.0 8388.0 11812.0 11346.0 12419.0 11262.0 12022.0 6479.0 8378.0 16611.0 10669.0 15008.0;
13120.0 13533.0 13113.0 13777.0 13145.0 8834.0 12543.0 12064.0 13234.0 12141.0 13220.0 6570.0 8812.0 17710.0 11336.0 16549.0;
14019.0 14296.0 13735.0 14698.0 13746.0 9297.0 13285.0 12978.0 13938.0 12556.0 14354.0 6959.0 9259.0 18812.0 12068.0 17600.0;
14537.0 14921.0 14292.0 15608.0 14308.0 9525.0 13896.0 13590.0 14613.0 13085.0 15184.0 7414.0 9744.0 19458.0 12795.0 18439.0;
15554.0 15549.0 15013.0 16024.0 14940.0 9548.0 14688.0 14425.0 15197.0 13402.0 15823.0 8126.0 10542.0 20120.0 13717.0 19407.0;
16524.0 16595.0 16209.0 16766.0 16040.0 10277.0 15784.0 15862.0 16082.0 13569.0 16299.0 9057.0 11434.0 21334.0 14864.0 20711.0;
17255.0 17768.0 17345.0 17418.0 17193.0 11036.0 16875.0 17269.0 17387.0 14124.0 17043.0 10042.0 12417.0 22965.0 15716.0 22047.0;
17322.0 19070.0 18526.0 18237.0 18244.0 11405.0 17946.0 18815.0 18665.0 14420.0 18004.0 10894.0 13365.0 24518.0 16397.0 23064.0]
k = size(x₀, 1); J = size(x₀, 2)
equal_v = fill(1/k, k); equal_w = fill(1/J, J)  # uniform starting values
v1s = ones(k); v0s = zeros(k); w1s = ones(J); w0s = zeros(J)  # box bounds
v = equal_v  # initial v referenced when the JuMP objective is first built